mct-nightly 1.11.0.20240320.400__py3-none-any.whl → 1.11.0.20240322.404__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.11.0.20240320.400.dist-info → mct_nightly-1.11.0.20240322.404.dist-info}/METADATA +17 -9
- {mct_nightly-1.11.0.20240320.400.dist-info → mct_nightly-1.11.0.20240322.404.dist-info}/RECORD +152 -152
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/constants.py +1 -1
- model_compression_toolkit/core/__init__.py +3 -3
- model_compression_toolkit/core/common/collectors/base_collector.py +2 -2
- model_compression_toolkit/core/common/data_loader.py +3 -3
- model_compression_toolkit/core/common/graph/base_graph.py +10 -13
- model_compression_toolkit/core/common/graph/base_node.py +3 -3
- model_compression_toolkit/core/common/graph/edge.py +2 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +2 -4
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +2 -3
- model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +1 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +24 -23
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +110 -112
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +114 -0
- model_compression_toolkit/core/common/mixed_precision/{kpi_tools/kpi_data.py → resource_utilization_tools/resource_utilization_data.py} +19 -19
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +105 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +26 -0
- model_compression_toolkit/core/common/mixed_precision/{kpi_tools/kpi_methods.py → resource_utilization_tools/ru_methods.py} +61 -61
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +75 -71
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -4
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +34 -34
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/network_editors/actions.py +3 -3
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +12 -12
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +2 -2
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +2 -2
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +7 -7
- model_compression_toolkit/core/common/pruning/prune_graph.py +2 -3
- model_compression_toolkit/core/common/pruning/pruner.py +7 -7
- model_compression_toolkit/core/common/pruning/pruning_config.py +1 -1
- model_compression_toolkit/core/common/pruning/pruning_info.py +2 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +7 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +4 -6
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -4
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +8 -6
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +4 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +4 -7
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +3 -3
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +3 -3
- model_compression_toolkit/core/common/user_info.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +3 -3
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +2 -2
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +4 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +2 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py +3 -3
- model_compression_toolkit/core/keras/hessian/trace_hessian_calculator_keras.py +1 -2
- model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py +5 -6
- model_compression_toolkit/core/keras/keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +2 -4
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +7 -7
- model_compression_toolkit/core/keras/reader/common.py +2 -2
- model_compression_toolkit/core/keras/reader/node_builder.py +1 -1
- model_compression_toolkit/core/keras/{kpi_data_facade.py → resource_utilization_data_facade.py} +25 -24
- model_compression_toolkit/core/keras/tf_tensor_numpy.py +4 -2
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +3 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +6 -11
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/hessian/activation_trace_hessian_calculator_pytorch.py +3 -7
- model_compression_toolkit/core/pytorch/hessian/trace_hessian_calculator_pytorch.py +1 -2
- model_compression_toolkit/core/pytorch/hessian/weights_trace_hessian_calculator_pytorch.py +2 -2
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +3 -3
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +5 -7
- model_compression_toolkit/core/pytorch/reader/reader.py +2 -2
- model_compression_toolkit/core/pytorch/{kpi_data_facade.py → resource_utilization_data_facade.py} +24 -22
- model_compression_toolkit/core/pytorch/utils.py +3 -2
- model_compression_toolkit/core/runner.py +43 -42
- model_compression_toolkit/data_generation/common/data_generation.py +18 -18
- model_compression_toolkit/data_generation/common/model_info_exctractors.py +1 -1
- model_compression_toolkit/data_generation/keras/keras_data_generation.py +7 -10
- model_compression_toolkit/data_generation/keras/model_info_exctractors.py +2 -1
- model_compression_toolkit/data_generation/keras/optimization_functions/image_initilization.py +2 -1
- model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py +2 -4
- model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py +2 -1
- model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py +8 -11
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -3
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -3
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +8 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +7 -8
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +19 -12
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +10 -11
- model_compression_toolkit/gptq/common/gptq_graph.py +3 -3
- model_compression_toolkit/gptq/common/gptq_training.py +14 -12
- model_compression_toolkit/gptq/keras/gptq_training.py +10 -8
- model_compression_toolkit/gptq/keras/graph_info.py +1 -1
- model_compression_toolkit/gptq/keras/quantization_facade.py +15 -17
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +4 -5
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +1 -2
- model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -8
- model_compression_toolkit/gptq/pytorch/graph_info.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +11 -13
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -4
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +1 -2
- model_compression_toolkit/logger.py +1 -13
- model_compression_toolkit/pruning/keras/pruning_facade.py +11 -12
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +11 -12
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -14
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -8
- model_compression_toolkit/qat/keras/quantization_facade.py +20 -22
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -3
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +12 -14
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -3
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +1 -1
- model_compression_toolkit/target_platform_capabilities/immutable.py +4 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +4 -8
- model_compression_toolkit/target_platform_capabilities/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/fusing.py +43 -8
- model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +13 -18
- model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +2 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attribute_filter.py +2 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/current_tpc.py +2 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +5 -5
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +1 -2
- model_compression_toolkit/trainable_infrastructure/common/base_trainable_quantizer.py +13 -13
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +14 -7
- model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +5 -5
- model_compression_toolkit/trainable_infrastructure/keras/base_keras_quantizer.py +2 -3
- model_compression_toolkit/trainable_infrastructure/keras/load_model.py +4 -5
- model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py +3 -4
- model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi.py +0 -112
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_aggregation_methods.py +0 -105
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_functions_mapping.py +0 -26
- {mct_nightly-1.11.0.20240320.400.dist-info → mct_nightly-1.11.0.20240322.404.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.11.0.20240320.400.dist-info → mct_nightly-1.11.0.20240322.404.dist-info}/WHEEL +0 -0
- {mct_nightly-1.11.0.20240320.400.dist-info → mct_nightly-1.11.0.20240322.404.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/common/mixed_precision/{kpi_tools → resource_utilization_tools}/__init__.py +0 -0
|
@@ -19,27 +19,28 @@ from tqdm import tqdm
|
|
|
19
19
|
from typing import Dict, List, Tuple, Callable
|
|
20
20
|
|
|
21
21
|
from model_compression_toolkit.logger import Logger
|
|
22
|
-
from model_compression_toolkit.core.common.mixed_precision.
|
|
22
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget
|
|
23
23
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import MixedPrecisionSearchManager
|
|
24
24
|
|
|
25
25
|
# Limit ILP solver runtime in seconds
|
|
26
26
|
SOLVER_TIME_LIMIT = 60
|
|
27
27
|
|
|
28
|
+
|
|
28
29
|
def mp_integer_programming_search(search_manager: MixedPrecisionSearchManager,
|
|
29
|
-
|
|
30
|
+
target_resource_utilization: ResourceUtilization = None) -> List[int]:
|
|
30
31
|
"""
|
|
31
32
|
Searching and returning a mixed-precision configuration using an ILP optimization solution.
|
|
32
33
|
It first builds a mapping from each layer's index (in the model) to a dictionary that maps the
|
|
33
34
|
bitwidth index to the observed sensitivity of the model when using that bitwidth for that layer.
|
|
34
35
|
Then, it creates a mapping from each node's index (in the graph) to a dictionary
|
|
35
36
|
that maps the bitwidth index to the contribution of configuring this node with this
|
|
36
|
-
bitwidth to the minimal possible
|
|
37
|
+
bitwidth to the minimal possible resource utilization of the model.
|
|
37
38
|
Then, and using these mappings, it builds an LP problem and finds an optimal solution.
|
|
38
39
|
If a solution could not be found, exception is thrown.
|
|
39
40
|
|
|
40
41
|
Args:
|
|
41
42
|
search_manager: MixedPrecisionSearchManager object to be used for problem formalization.
|
|
42
|
-
|
|
43
|
+
target_resource_utilization: Target resource utilization to constrain our LP problem with some resources limitations (like model' weights memory
|
|
43
44
|
consumption).
|
|
44
45
|
|
|
45
46
|
Returns:
|
|
@@ -50,11 +51,11 @@ def mp_integer_programming_search(search_manager: MixedPrecisionSearchManager,
|
|
|
50
51
|
# Build a mapping from each layer's index (in the model) to a dictionary that maps the
|
|
51
52
|
# bitwidth index to the observed sensitivity of the model when using that bitwidth for that layer.
|
|
52
53
|
|
|
53
|
-
if
|
|
54
|
-
Logger.critical("
|
|
55
|
-
"
|
|
54
|
+
if target_resource_utilization is None or search_manager is None:
|
|
55
|
+
Logger.critical("Invalid parameters: 'target_resource_utilization' and 'search_manager' must not be 'None' "
|
|
56
|
+
"for mixed-precision search. Ensure valid inputs are provided.")
|
|
56
57
|
|
|
57
|
-
layer_to_metrics_mapping = _build_layer_to_metrics_mapping(search_manager,
|
|
58
|
+
layer_to_metrics_mapping = _build_layer_to_metrics_mapping(search_manager, target_resource_utilization)
|
|
58
59
|
|
|
59
60
|
# Init variables to find their values when solving the lp problem.
|
|
60
61
|
layer_to_indicator_vars_mapping, layer_to_objective_vars_mapping = _init_problem_vars(layer_to_metrics_mapping)
|
|
@@ -63,7 +64,7 @@ def mp_integer_programming_search(search_manager: MixedPrecisionSearchManager,
|
|
|
63
64
|
lp_problem = _formalize_problem(layer_to_indicator_vars_mapping,
|
|
64
65
|
layer_to_metrics_mapping,
|
|
65
66
|
layer_to_objective_vars_mapping,
|
|
66
|
-
|
|
67
|
+
target_resource_utilization,
|
|
67
68
|
search_manager)
|
|
68
69
|
|
|
69
70
|
# Use default PULP solver. Limit runtime in seconds
|
|
@@ -81,7 +82,7 @@ def mp_integer_programming_search(search_manager: MixedPrecisionSearchManager,
|
|
|
81
82
|
in layer_to_indicator_vars_mapping.values()]
|
|
82
83
|
).flatten()
|
|
83
84
|
|
|
84
|
-
if
|
|
85
|
+
if target_resource_utilization.bops < np.inf:
|
|
85
86
|
return search_manager.config_reconstruction_helper.reconstruct_config_from_virtual_graph(config)
|
|
86
87
|
else:
|
|
87
88
|
return config
|
|
@@ -122,7 +123,7 @@ def _init_problem_vars(layer_to_metrics_mapping: Dict[int, Dict[int, float]]) ->
|
|
|
122
123
|
def _formalize_problem(layer_to_indicator_vars_mapping: Dict[int, Dict[int, LpVariable]],
|
|
123
124
|
layer_to_metrics_mapping: Dict[int, Dict[int, float]],
|
|
124
125
|
layer_to_objective_vars_mapping: Dict[int, LpVariable],
|
|
125
|
-
|
|
126
|
+
target_resource_utilization: ResourceUtilization,
|
|
126
127
|
search_manager: MixedPrecisionSearchManager) -> LpProblem:
|
|
127
128
|
"""
|
|
128
129
|
Formalize the LP problem by defining all inequalities that define the solution space.
|
|
@@ -134,8 +135,8 @@ def _formalize_problem(layer_to_indicator_vars_mapping: Dict[int, Dict[int, LpVa
|
|
|
134
135
|
evaluation.
|
|
135
136
|
layer_to_objective_vars_mapping: Dictionary that maps each node's index to a bitwidth variable we find its
|
|
136
137
|
value.
|
|
137
|
-
|
|
138
|
-
search_manager: MixedPrecisionSearchManager object to be used for
|
|
138
|
+
target_resource_utilization: Target resource utilization to reduce our feasible solution space.
|
|
139
|
+
search_manager: MixedPrecisionSearchManager object to be used for resource utilization constraints formalization.
|
|
139
140
|
|
|
140
141
|
Returns:
|
|
141
142
|
The formalized LP problem.
|
|
@@ -155,9 +156,9 @@ def _formalize_problem(layer_to_indicator_vars_mapping: Dict[int, Dict[int, LpVa
|
|
|
155
156
|
lp_problem += lpSum(
|
|
156
157
|
[v for v in layer_to_indicator_vars_mapping[layer].values()]) == 1
|
|
157
158
|
|
|
158
|
-
# Bound the feasible solution space with the desired
|
|
159
|
-
# Creates separate constraints for weights
|
|
160
|
-
if
|
|
159
|
+
# Bound the feasible solution space with the desired resource utilization values.
|
|
160
|
+
# Creates separate constraints for weights utilization and activation utilization.
|
|
161
|
+
if target_resource_utilization is not None:
|
|
161
162
|
indicators = []
|
|
162
163
|
for layer in layer_to_metrics_mapping.keys():
|
|
163
164
|
for _, indicator in layer_to_indicator_vars_mapping[layer].items():
|
|
@@ -166,73 +167,76 @@ def _formalize_problem(layer_to_indicator_vars_mapping: Dict[int, Dict[int, LpVa
|
|
|
166
167
|
indicators_arr = np.array(indicators)
|
|
167
168
|
indicators_matrix = np.diag(indicators_arr)
|
|
168
169
|
|
|
169
|
-
for target,
|
|
170
|
-
if not np.isinf(
|
|
171
|
-
|
|
172
|
-
else search_manager.
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
170
|
+
for target, ru_value in target_resource_utilization.get_resource_utilization_dict().items():
|
|
171
|
+
if not np.isinf(ru_value):
|
|
172
|
+
non_conf_ru_vector = None if search_manager.non_conf_ru_dict is None \
|
|
173
|
+
else search_manager.non_conf_ru_dict.get(target)
|
|
174
|
+
_add_set_of_ru_constraints(search_manager=search_manager,
|
|
175
|
+
target=target,
|
|
176
|
+
target_resource_utilization_value=ru_value,
|
|
177
|
+
indicators_matrix=indicators_matrix,
|
|
178
|
+
lp_problem=lp_problem,
|
|
179
|
+
non_conf_ru_vector=non_conf_ru_vector)
|
|
179
180
|
else: # pragma: no cover
|
|
180
|
-
|
|
181
|
-
|
|
181
|
+
Logger.critical("Unable to execute mixed-precision search: 'target_resource_utilization' is None. "
|
|
182
|
+
"A valid 'target_resource_utilization' is required.")
|
|
182
183
|
return lp_problem
|
|
183
184
|
|
|
184
185
|
|
|
185
|
-
def
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
186
|
+
def _add_set_of_ru_constraints(search_manager: MixedPrecisionSearchManager,
|
|
187
|
+
target: RUTarget,
|
|
188
|
+
target_resource_utilization_value: float,
|
|
189
|
+
indicators_matrix: np.ndarray,
|
|
190
|
+
lp_problem: LpProblem,
|
|
191
|
+
non_conf_ru_vector: np.ndarray):
|
|
191
192
|
"""
|
|
192
|
-
Adding a constraint for the Lp problem for the given
|
|
193
|
+
Adding a constraint for the Lp problem for the given target resource utilization.
|
|
193
194
|
The update to the Lp problem object is done inplace.
|
|
194
195
|
|
|
195
196
|
Args:
|
|
196
|
-
search_manager: MixedPrecisionSearchManager object to be used for
|
|
197
|
-
target: A
|
|
198
|
-
|
|
197
|
+
search_manager: MixedPrecisionSearchManager object to be used for resource utilization constraints formalization.
|
|
198
|
+
target: A RUTarget.
|
|
199
|
+
target_resource_utilization_value: Target resource utilization value of the given target resource utilization
|
|
200
|
+
for which the constraint is added.
|
|
199
201
|
indicators_matrix: A diagonal matrix of the Lp problem's indicators.
|
|
200
202
|
lp_problem: An Lp problem object to add constraint to.
|
|
201
|
-
|
|
203
|
+
non_conf_ru_vector: A non-configurable nodes' resource utilization vector.
|
|
202
204
|
|
|
203
205
|
"""
|
|
204
206
|
|
|
205
|
-
|
|
206
|
-
|
|
207
|
+
ru_matrix = search_manager.compute_resource_utilization_matrix(target)
|
|
208
|
+
indicated_ru_matrix = np.matmul(ru_matrix, indicators_matrix)
|
|
207
209
|
# Need to re-organize the tensor such that the configurations' axis will be second,
|
|
208
210
|
# and all metric values' axis will come afterword
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
# In order to get the result
|
|
212
|
-
# Each row represents the
|
|
213
|
-
# to a configuration which implied by the set of indicators will have some
|
|
214
|
-
# (and will contribute to the total
|
|
215
|
-
|
|
216
|
-
np.sum(
|
|
217
|
-
search_manager.
|
|
218
|
-
|
|
219
|
-
# search_manager.
|
|
220
|
-
# get aggregated
|
|
221
|
-
if
|
|
222
|
-
|
|
211
|
+
indicated_ru_matrix = np.moveaxis(indicated_ru_matrix, source=len(indicated_ru_matrix.shape) - 1, destination=1)
|
|
212
|
+
|
|
213
|
+
# In order to get the result resource utilization according to a chosen set of indicators, we sum each row in
|
|
214
|
+
# the result matrix. Each row represents the resource utilization values for a specific resource utilization metric,
|
|
215
|
+
# such that only elements corresponding to a configuration which implied by the set of indicators will have some
|
|
216
|
+
# positive value different than 0 (and will contribute to the total resource utilization).
|
|
217
|
+
ru_sum_vector = np.array([
|
|
218
|
+
np.sum(indicated_ru_matrix[i], axis=0) + # sum of metric values over all configurations in a row
|
|
219
|
+
search_manager.min_ru[target][i] for i in range(indicated_ru_matrix.shape[0])])
|
|
220
|
+
|
|
221
|
+
# search_manager.compute_ru_functions contains a pair of ru_metric and ru_aggregation for each ru target
|
|
222
|
+
# get aggregated ru, considering both configurable and non-configurable nodes
|
|
223
|
+
if non_conf_ru_vector is None or len(non_conf_ru_vector) == 0:
|
|
224
|
+
aggr_ru = search_manager.compute_ru_functions[target][1](ru_sum_vector)
|
|
223
225
|
else:
|
|
224
|
-
|
|
226
|
+
aggr_ru = search_manager.compute_ru_functions[target][1](np.concatenate([ru_sum_vector, non_conf_ru_vector]))
|
|
225
227
|
|
|
226
|
-
for v in
|
|
228
|
+
for v in aggr_ru:
|
|
227
229
|
if isinstance(v, float):
|
|
228
|
-
if v >
|
|
229
|
-
Logger.critical(
|
|
230
|
+
if v > target_resource_utilization_value:
|
|
231
|
+
Logger.critical(
|
|
232
|
+
f"The model cannot be quantized to meet the specified target resource utilization {target.value} "
|
|
233
|
+
f"with the value {target_resource_utilization_value}.") # pragma: no cover
|
|
230
234
|
else:
|
|
231
|
-
lp_problem += v <=
|
|
235
|
+
lp_problem += v <= target_resource_utilization_value
|
|
232
236
|
|
|
233
237
|
|
|
234
238
|
def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
|
|
235
|
-
|
|
239
|
+
target_resource_utilization: ResourceUtilization,
|
|
236
240
|
eps: float = EPS) -> Dict[int, Dict[int, float]]:
|
|
237
241
|
"""
|
|
238
242
|
This function measures the sensitivity of a change in a bitwidth of a layer on the entire model.
|
|
@@ -244,8 +248,8 @@ def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
|
|
|
244
248
|
|
|
245
249
|
Args:
|
|
246
250
|
search_manager: MixedPrecisionSearchManager object to be used for problem formalization.
|
|
247
|
-
|
|
248
|
-
consumption).
|
|
251
|
+
target_resource_utilization: ResourceUtilization to constrain our LP problem with some resources limitations
|
|
252
|
+
(like model' weights memory consumption).
|
|
249
253
|
eps: Epsilon value to manually increase metric value (if necessary) for numerical stability
|
|
250
254
|
|
|
251
255
|
Returns:
|
|
@@ -257,30 +261,30 @@ def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
|
|
|
257
261
|
Logger.info('Starting to evaluate metrics')
|
|
258
262
|
layer_to_metrics_mapping = {}
|
|
259
263
|
|
|
260
|
-
|
|
264
|
+
is_bops_target_resource_utilization = target_resource_utilization.bops < np.inf
|
|
261
265
|
|
|
262
|
-
if
|
|
263
|
-
origin_max_config = search_manager.config_reconstruction_helper.reconstruct_config_from_virtual_graph(search_manager.
|
|
266
|
+
if is_bops_target_resource_utilization:
|
|
267
|
+
origin_max_config = search_manager.config_reconstruction_helper.reconstruct_config_from_virtual_graph(search_manager.max_ru_config)
|
|
264
268
|
max_config_value = search_manager.compute_metric_fn(origin_max_config)
|
|
265
269
|
else:
|
|
266
|
-
max_config_value = search_manager.compute_metric_fn(search_manager.
|
|
270
|
+
max_config_value = search_manager.compute_metric_fn(search_manager.max_ru_config)
|
|
267
271
|
|
|
268
272
|
for node_idx, layer_possible_bitwidths_indices in tqdm(search_manager.layer_to_bitwidth_mapping.items(),
|
|
269
273
|
total=len(search_manager.layer_to_bitwidth_mapping)):
|
|
270
274
|
layer_to_metrics_mapping[node_idx] = {}
|
|
271
275
|
|
|
272
276
|
for bitwidth_idx in layer_possible_bitwidths_indices:
|
|
273
|
-
if search_manager.
|
|
277
|
+
if search_manager.max_ru_config[node_idx] == bitwidth_idx:
|
|
274
278
|
# This is a computation of the metric for the max configuration, assign pre-calculated value
|
|
275
279
|
layer_to_metrics_mapping[node_idx][bitwidth_idx] = max_config_value
|
|
276
280
|
continue
|
|
277
281
|
|
|
278
282
|
# Create a configuration that differs at one layer only from the baseline model
|
|
279
|
-
mp_model_configuration = search_manager.
|
|
283
|
+
mp_model_configuration = search_manager.max_ru_config.copy()
|
|
280
284
|
mp_model_configuration[node_idx] = bitwidth_idx
|
|
281
285
|
|
|
282
286
|
# Build a distance matrix using the function we got from the framework implementation.
|
|
283
|
-
if
|
|
287
|
+
if is_bops_target_resource_utilization:
|
|
284
288
|
# Reconstructing original graph's configuration from virtual graph's configuration
|
|
285
289
|
origin_mp_model_configuration = \
|
|
286
290
|
search_manager.config_reconstruction_helper.reconstruct_config_from_virtual_graph(
|
|
@@ -297,7 +301,7 @@ def _build_layer_to_metrics_mapping(search_manager: MixedPrecisionSearchManager,
|
|
|
297
301
|
metric_value = search_manager.compute_metric_fn(
|
|
298
302
|
mp_model_configuration,
|
|
299
303
|
[node_idx],
|
|
300
|
-
search_manager.
|
|
304
|
+
search_manager.max_ru_config)
|
|
301
305
|
|
|
302
306
|
layer_to_metrics_mapping[node_idx][bitwidth_idx] = max(metric_value, max_config_value + eps)
|
|
303
307
|
|
|
@@ -78,8 +78,7 @@ class SensitivityEvaluation:
|
|
|
78
78
|
self.disable_activation_for_metric = disable_activation_for_metric
|
|
79
79
|
if self.quant_config.use_hessian_based_scores:
|
|
80
80
|
if not isinstance(hessian_info_service, HessianInfoService):
|
|
81
|
-
Logger.
|
|
82
|
-
f" an HessianInfoService object must be provided but is {hessian_info_service}")
|
|
81
|
+
Logger.critical(f"When using Hessian-based approximations for sensitivity evaluation, a valid HessianInfoService object is required; found {type(hessian_info_service)}.")
|
|
83
82
|
self.hessian_info_service = hessian_info_service
|
|
84
83
|
|
|
85
84
|
self.sorted_configurable_nodes_names = graph.get_configurable_sorted_nodes_names(self.fw_info)
|
|
@@ -320,8 +319,7 @@ class SensitivityEvaluation:
|
|
|
320
319
|
node_name = sorted_configurable_nodes_names[node_idx_to_configure]
|
|
321
320
|
layers_to_config = self.conf_node2layers.get(node_name, None)
|
|
322
321
|
if layers_to_config is None:
|
|
323
|
-
Logger.
|
|
324
|
-
f"Couldn't find matching layers in the MP model for node {node_name}.") # pragma: no cover
|
|
322
|
+
Logger.critical(f"Matching layers for node {node_name} not found in the mixed precision model configuration.") # pragma: no cover
|
|
325
323
|
|
|
326
324
|
for current_layer in layers_to_config:
|
|
327
325
|
self.set_layer_to_bitwidth(current_layer, mp_model_configuration[node_idx_to_configure])
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
from typing import List
|
|
17
17
|
|
|
18
|
-
from model_compression_toolkit.core import
|
|
18
|
+
from model_compression_toolkit.core import ResourceUtilization
|
|
19
19
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import \
|
|
20
20
|
MixedPrecisionSearchManager
|
|
21
21
|
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
|
|
@@ -26,29 +26,29 @@ import numpy as np
|
|
|
26
26
|
|
|
27
27
|
def greedy_solution_refinement_procedure(mp_solution: List[int],
|
|
28
28
|
search_manager: MixedPrecisionSearchManager,
|
|
29
|
-
|
|
29
|
+
target_resource_utilization: ResourceUtilization) -> List[int]:
|
|
30
30
|
"""
|
|
31
31
|
A greedy procedure to try and improve a mixed-precision solution that was found by a mixed-precision optimization
|
|
32
32
|
algorithm.
|
|
33
33
|
This procedure tries to increase the bit-width precision of configurable nodes that did not get the maximal
|
|
34
34
|
candidate
|
|
35
35
|
in the found solution.
|
|
36
|
-
It iteratively goes over all such nodes, computes the
|
|
37
|
-
best candidate), filters out all configs that hold the
|
|
36
|
+
It iteratively goes over all such nodes, computes the resource utilization values on a modified configuration (with the node's next
|
|
37
|
+
best candidate), filters out all configs that hold the resource utilization constraints and chooses one of them as an improvement
|
|
38
38
|
step
|
|
39
|
-
The choice is done in a greedy approach where we take the configuration that modifies the
|
|
39
|
+
The choice is done in a greedy approach where we take the configuration that modifies the resource utilization the least.
|
|
40
40
|
|
|
41
41
|
Args:
|
|
42
42
|
mp_solution: A mixed-precision configuration that was found by a mixed-precision optimization algorithm.
|
|
43
43
|
search_manager: A MixedPrecisionSearchManager object.
|
|
44
|
-
|
|
44
|
+
target_resource_utilization: The target resource utilization for the mixed-precision search.
|
|
45
45
|
|
|
46
46
|
Returns: A new, possibly updated, mixed-precision bit-width configuration.
|
|
47
47
|
|
|
48
48
|
"""
|
|
49
|
-
# Refinement is not supported for BOPs
|
|
50
|
-
if
|
|
51
|
-
Logger.info(f'Target
|
|
49
|
+
# Refinement is not supported for BOPs utilization for now...
|
|
50
|
+
if target_resource_utilization.bops < np.inf:
|
|
51
|
+
Logger.info(f'Target resource utilization constraint BOPs - Skipping MP greedy solution refinement')
|
|
52
52
|
return mp_solution
|
|
53
53
|
|
|
54
54
|
new_solution = mp_solution.copy()
|
|
@@ -56,7 +56,7 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
|
|
|
56
56
|
|
|
57
57
|
while changed:
|
|
58
58
|
changed = False
|
|
59
|
-
|
|
59
|
+
nodes_ru = {}
|
|
60
60
|
nodes_next_candidate = {}
|
|
61
61
|
|
|
62
62
|
for node_idx in range(len(mp_solution)):
|
|
@@ -72,32 +72,32 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
|
|
|
72
72
|
kernel_attr = None if kernel_attr is None else kernel_attr[0]
|
|
73
73
|
valid_candidates = _get_valid_candidates_indices(node_candidates, new_solution[node_idx], kernel_attr)
|
|
74
74
|
|
|
75
|
-
# Create a list of
|
|
76
|
-
|
|
75
|
+
# Create a list of ru for the valid candidates.
|
|
76
|
+
updated_ru = []
|
|
77
77
|
for valid_idx in valid_candidates:
|
|
78
|
-
|
|
78
|
+
node_updated_ru = search_manager.compute_resource_utilization_for_config(
|
|
79
79
|
config=search_manager.replace_config_in_index(new_solution, node_idx, valid_idx))
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
# filter out new configs that don't hold the
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
if len(
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
nodes_next_candidate[node_idx] =
|
|
92
|
-
|
|
93
|
-
if len(
|
|
94
|
-
# filter out new configs that don't hold the
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
node_idx_to_upgrade =
|
|
80
|
+
updated_ru.append(node_updated_ru)
|
|
81
|
+
|
|
82
|
+
# filter out new configs that don't hold the resource utilization restrictions
|
|
83
|
+
node_filtered_ru = [(node_idx, ru) for node_idx, ru in zip(valid_candidates, updated_ru) if
|
|
84
|
+
target_resource_utilization.holds_constraints(ru)]
|
|
85
|
+
|
|
86
|
+
if len(node_filtered_ru) > 0:
|
|
87
|
+
sorted_by_ru = sorted(node_filtered_ru, key=lambda node_ru: (node_ru[1].total_memory,
|
|
88
|
+
node_ru[1].weights_memory,
|
|
89
|
+
node_ru[1].activation_memory))
|
|
90
|
+
nodes_ru[node_idx] = sorted_by_ru[0][1]
|
|
91
|
+
nodes_next_candidate[node_idx] = sorted_by_ru[0][0]
|
|
92
|
+
|
|
93
|
+
if len(nodes_ru) > 0:
|
|
94
|
+
# filter out new configs that don't hold the ru restrictions
|
|
95
|
+
node_filtered_ru = [(node_idx, ru) for node_idx, ru in nodes_ru.items()]
|
|
96
|
+
sorted_by_ru = sorted(node_filtered_ru, key=lambda node_ru: (node_ru[1].total_memory,
|
|
97
|
+
node_ru[1].weights_memory,
|
|
98
|
+
node_ru[1].activation_memory))
|
|
99
|
+
|
|
100
|
+
node_idx_to_upgrade = sorted_by_ru[0][0]
|
|
101
101
|
new_solution[node_idx_to_upgrade] = nodes_next_candidate[node_idx_to_upgrade]
|
|
102
102
|
changed = True
|
|
103
103
|
|
|
@@ -158,9 +158,9 @@ class ModelCollector:
|
|
|
158
158
|
for td, sc in zip(tensor_data, self.stats_containers_list):
|
|
159
159
|
if isinstance(sc, (list, tuple)):
|
|
160
160
|
if not isinstance(td, (list, tuple)):
|
|
161
|
-
Logger.
|
|
161
|
+
Logger.critical('\'tensor_data\' must be a list or a tuple if \'stats_containers_list\' contains lists or tuples.') # pragma: no cover
|
|
162
162
|
if len(sc) != len(td):
|
|
163
|
-
Logger.
|
|
163
|
+
Logger.critical('\'tensor_data\' and \'stats_containers_list\' must have matching lengths') # pragma: no cover
|
|
164
164
|
for tdi, sci in zip(td, sc):
|
|
165
165
|
sci.update_statistics(self.fw_impl.to_numpy(tdi))
|
|
166
166
|
else:
|
|
@@ -305,7 +305,7 @@ class ChangeCandidatesActivationQuantizationMethod(BaseAction):
|
|
|
305
305
|
self.activation_quantization_method)
|
|
306
306
|
|
|
307
307
|
if activation_quantization_fn is None:
|
|
308
|
-
|
|
308
|
+
Logger.critical('Unknown activation quantization method specified.') # pragma: no cover
|
|
309
309
|
|
|
310
310
|
qc.activation_quantization_cfg.set_activation_quantization_fn(activation_quantization_fn)
|
|
311
311
|
qc.activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method
|
|
@@ -352,7 +352,7 @@ class ChangeFinalWeightsQuantizationMethod(BaseAction):
|
|
|
352
352
|
weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
|
|
353
353
|
|
|
354
354
|
if weights_quantization_fn is None:
|
|
355
|
-
|
|
355
|
+
Logger.critical('Unknown weights quantization method specified.') # pragma: no cover
|
|
356
356
|
|
|
357
357
|
(node.final_weights_quantization_cfg.get_attr_config(self.attr_name)
|
|
358
358
|
.set_weights_quantization_fn(weights_quantization_fn))
|
|
@@ -401,7 +401,7 @@ class ChangeCandidatesWeightsQuantizationMethod(BaseAction):
|
|
|
401
401
|
weights_quantization_fn = get_weights_quantization_fn(self.weights_quantization_method)
|
|
402
402
|
|
|
403
403
|
if weights_quantization_fn is None:
|
|
404
|
-
|
|
404
|
+
Logger.critical('Unknown weights quantization method specified.') # pragma: no cover
|
|
405
405
|
|
|
406
406
|
attr_qc.set_weights_quantization_fn(weights_quantization_fn)
|
|
407
407
|
attr_qc.weights_quantization_method = self.weights_quantization_method
|
|
@@ -18,7 +18,7 @@ from typing import List, Dict, Tuple
|
|
|
18
18
|
|
|
19
19
|
from model_compression_toolkit.core.common import BaseNode, Graph
|
|
20
20
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
21
|
-
from model_compression_toolkit.core.common.mixed_precision.
|
|
21
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
|
|
22
22
|
from model_compression_toolkit.core.common.pruning.mask.per_channel_mask import MaskIndicator
|
|
23
23
|
from model_compression_toolkit.core.common.pruning.memory_calculator import MemoryCalculator
|
|
24
24
|
from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation
|
|
@@ -30,16 +30,16 @@ from model_compression_toolkit.target_platform_capabilities.target_platform impo
|
|
|
30
30
|
class GreedyMaskCalculator:
|
|
31
31
|
"""
|
|
32
32
|
GreedyMaskCalculator calculates pruning masks for prunable nodes to meet a
|
|
33
|
-
specified target
|
|
33
|
+
specified target resource utilization. It employs a greedy approach to selectively unprune channel
|
|
34
34
|
groups (SIMD groups) based on their importance scores. Initially, all channels are
|
|
35
35
|
pruned (mask set to zero), and the calculator iteratively adds back the most significant
|
|
36
|
-
channel groups until the memory footprint meets the target
|
|
36
|
+
channel groups until the memory footprint meets the target resource utilization or all channels are unpruned.
|
|
37
37
|
"""
|
|
38
38
|
def __init__(self,
|
|
39
39
|
prunable_nodes: List[BaseNode],
|
|
40
40
|
fw_info: FrameworkInfo,
|
|
41
41
|
simd_groups_scores: Dict[BaseNode, np.ndarray],
|
|
42
|
-
|
|
42
|
+
target_resource_utilization: ResourceUtilization,
|
|
43
43
|
graph: Graph,
|
|
44
44
|
fw_impl: PruningFrameworkImplementation,
|
|
45
45
|
tpc: TargetPlatformCapabilities,
|
|
@@ -49,7 +49,7 @@ class GreedyMaskCalculator:
|
|
|
49
49
|
prunable_nodes (List[BaseNode]): Nodes that are eligible for pruning.
|
|
50
50
|
fw_info (FrameworkInfo): Framework-specific information and utilities.
|
|
51
51
|
simd_groups_scores (Dict[BaseNode, np.ndarray]): Importance scores for each SIMG group in a prunable node.
|
|
52
|
-
|
|
52
|
+
target_resource_utilization (ResourceUtilization): The target resource utilization to achieve.
|
|
53
53
|
graph (Graph): The computational graph of the model.
|
|
54
54
|
fw_impl (PruningFrameworkImplementation): Framework-specific implementation details.
|
|
55
55
|
tpc (TargetPlatformCapabilities): Platform-specific constraints and capabilities.
|
|
@@ -57,7 +57,7 @@ class GreedyMaskCalculator:
|
|
|
57
57
|
"""
|
|
58
58
|
self.prunable_nodes = prunable_nodes
|
|
59
59
|
self.fw_info = fw_info
|
|
60
|
-
self.
|
|
60
|
+
self.target_resource_utilization = target_resource_utilization
|
|
61
61
|
self.graph = graph
|
|
62
62
|
self.fw_impl = fw_impl
|
|
63
63
|
self.tpc = tpc
|
|
@@ -86,18 +86,18 @@ class GreedyMaskCalculator:
|
|
|
86
86
|
def compute_mask(self):
|
|
87
87
|
"""
|
|
88
88
|
Computes the pruning mask by iteratively adding SIMD groups to unpruned state
|
|
89
|
-
based on their importance and the target
|
|
89
|
+
based on their importance and the target resource utilization.
|
|
90
90
|
"""
|
|
91
91
|
# Iteratively unprune the graph while monitoring the memory footprint.
|
|
92
92
|
current_memory = self.memory_calculator.get_pruned_graph_memory(masks=self.oc_pruning_mask.get_mask(),
|
|
93
93
|
include_padded_channels=self.tpc.is_simd_padding)
|
|
94
|
-
if current_memory > self.
|
|
95
|
-
Logger.
|
|
96
|
-
|
|
94
|
+
if current_memory > self.target_resource_utilization.weights_memory:
|
|
95
|
+
Logger.critical(f"Insufficient memory for the target resource utilization: current memory {current_memory}, "
|
|
96
|
+
f"target memory {self.target_resource_utilization.weights_memory}.")
|
|
97
97
|
|
|
98
98
|
# Greedily unprune groups (by setting their mask to 1) until the memory target is met
|
|
99
99
|
# or all channels unpruned.
|
|
100
|
-
while current_memory < self.
|
|
100
|
+
while current_memory < self.target_resource_utilization.weights_memory and self.oc_pruning_mask.has_pruned_channel():
|
|
101
101
|
# Select the best SIMD group (best means highest score which means most sensitive group)
|
|
102
102
|
# to add based on the scores.
|
|
103
103
|
node_to_remain, group_to_remain_idx = self._get_most_sensitive_simd_group_candidate()
|
|
@@ -108,7 +108,7 @@ class GreedyMaskCalculator:
|
|
|
108
108
|
include_padded_channels=self.tpc.is_simd_padding)
|
|
109
109
|
|
|
110
110
|
# If the target memory is exceeded, revert the last addition.
|
|
111
|
-
if current_memory > self.
|
|
111
|
+
if current_memory > self.target_resource_utilization.weights_memory:
|
|
112
112
|
self.oc_pruning_mask.set_mask_value_for_simd_group(node=node_to_remain,
|
|
113
113
|
group_index=group_to_remain_idx,
|
|
114
114
|
mask_indicator=MaskIndicator.PRUNED)
|
|
@@ -250,13 +250,13 @@ class LFHImportanceMetric(BaseImportanceMetric):
|
|
|
250
250
|
kernel_attr = self.fw_info.get_kernel_op_attributes(entry_node.type)
|
|
251
251
|
# Ensure only one kernel attribute exists for the given node.
|
|
252
252
|
if len(kernel_attr) != 1:
|
|
253
|
-
Logger.
|
|
253
|
+
Logger.critical(f"Expected a single attribute but found multiple attributes ({len(kernel_attr)}) for node {entry_node}.")
|
|
254
254
|
kernel_attr = kernel_attr[0]
|
|
255
255
|
|
|
256
256
|
# Retrieve and validate the axis index for the output channels.
|
|
257
257
|
oc_axis, _ = self.fw_info.kernel_channels_mapping.get(entry_node.type)
|
|
258
258
|
if oc_axis is None or int(oc_axis) != oc_axis:
|
|
259
|
-
Logger.
|
|
259
|
+
Logger.critical(f"Invalid output channel axis type for node {entry_node}: expected integer but got {oc_axis}.")
|
|
260
260
|
|
|
261
261
|
# Get the number of output channels based on the kernel attribute and axis.
|
|
262
262
|
num_oc = entry_node.get_weights_by_keys(kernel_attr[0]).shape[oc_axis]
|
|
@@ -19,7 +19,7 @@ from typing import List, Dict, Tuple
|
|
|
19
19
|
|
|
20
20
|
from model_compression_toolkit.core.common import BaseNode, Graph
|
|
21
21
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
22
|
-
from model_compression_toolkit.core.common.mixed_precision.
|
|
22
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
|
|
23
23
|
from model_compression_toolkit.core.common.pruning.memory_calculator import MemoryCalculator
|
|
24
24
|
from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation
|
|
25
25
|
from model_compression_toolkit.logger import Logger
|
|
@@ -73,7 +73,7 @@ class PerChannelMask:
|
|
|
73
73
|
mask_indicator: The new value to set in the mask (either PRUNED or REMAINED).
|
|
74
74
|
"""
|
|
75
75
|
if mask_indicator not in [MaskIndicator.PRUNED, MaskIndicator.REMAINED]:
|
|
76
|
-
Logger.
|
|
76
|
+
Logger.critical("Mask value must be either 'MaskIndicator.PRUNED' or 'MaskIndicator.REMAINED'")
|
|
77
77
|
self._mask[node][channel_idx] = mask_indicator.value
|
|
78
78
|
|
|
79
79
|
def has_pruned_channel(self) -> bool:
|
|
@@ -18,7 +18,7 @@ from typing import List, Dict, Tuple
|
|
|
18
18
|
|
|
19
19
|
from model_compression_toolkit.core.common import BaseNode, Graph
|
|
20
20
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
21
|
-
from model_compression_toolkit.core.common.mixed_precision.
|
|
21
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
|
|
22
22
|
from model_compression_toolkit.core.common.pruning.mask.per_channel_mask import PerChannelMask, MaskIndicator
|
|
23
23
|
from model_compression_toolkit.core.common.pruning.memory_calculator import MemoryCalculator
|
|
24
24
|
from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import PruningFrameworkImplementation
|
|
@@ -79,7 +79,7 @@ class PerSIMDGroupMask:
|
|
|
79
79
|
mask_indicator: The new value to set in the mask (either PRUNED or REMAINED).
|
|
80
80
|
"""
|
|
81
81
|
if mask_indicator not in [MaskIndicator.PRUNED, MaskIndicator.REMAINED]:
|
|
82
|
-
Logger.
|
|
82
|
+
Logger.critical("Mask value must be either 'MaskIndicator.PRUNED' or 'MaskIndicator.REMAINED'")
|
|
83
83
|
|
|
84
84
|
# Update the SIMD group mask and corresponding per-channel mask
|
|
85
85
|
self._mask_simd[node][group_index] = mask_indicator.value
|