mct-nightly 1.11.0.20240321.357__py3-none-any.whl → 1.11.0.20240323.408__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240323.408.dist-info}/METADATA +17 -9
- {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240323.408.dist-info}/RECORD +152 -152
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/constants.py +1 -1
- model_compression_toolkit/core/__init__.py +3 -3
- model_compression_toolkit/core/common/collectors/base_collector.py +2 -2
- model_compression_toolkit/core/common/data_loader.py +3 -3
- model_compression_toolkit/core/common/graph/base_graph.py +10 -13
- model_compression_toolkit/core/common/graph/base_node.py +3 -3
- model_compression_toolkit/core/common/graph/edge.py +2 -1
- model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +2 -4
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +2 -3
- model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py +3 -5
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +1 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +24 -23
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +110 -112
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +114 -0
- model_compression_toolkit/core/common/mixed_precision/{kpi_tools/kpi_data.py → resource_utilization_tools/resource_utilization_data.py} +19 -19
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +105 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +26 -0
- model_compression_toolkit/core/common/mixed_precision/{kpi_tools/kpi_methods.py → resource_utilization_tools/ru_methods.py} +61 -61
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +75 -71
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -4
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +34 -34
- model_compression_toolkit/core/common/model_collector.py +2 -2
- model_compression_toolkit/core/common/network_editors/actions.py +3 -3
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +12 -12
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +2 -2
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +2 -2
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +7 -7
- model_compression_toolkit/core/common/pruning/prune_graph.py +2 -3
- model_compression_toolkit/core/common/pruning/pruner.py +7 -7
- model_compression_toolkit/core/common/pruning/pruning_config.py +1 -1
- model_compression_toolkit/core/common/pruning/pruning_info.py +2 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +7 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +4 -2
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +4 -6
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -4
- model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +1 -1
- model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +8 -6
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +4 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +4 -7
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +3 -3
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +3 -3
- model_compression_toolkit/core/common/user_info.py +1 -1
- model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +3 -3
- model_compression_toolkit/core/keras/back2framework/instance_builder.py +2 -2
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +4 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +2 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py +3 -3
- model_compression_toolkit/core/keras/hessian/trace_hessian_calculator_keras.py +1 -2
- model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py +5 -6
- model_compression_toolkit/core/keras/keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +2 -4
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
- model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +7 -7
- model_compression_toolkit/core/keras/reader/common.py +2 -2
- model_compression_toolkit/core/keras/reader/node_builder.py +1 -1
- model_compression_toolkit/core/keras/{kpi_data_facade.py → resource_utilization_data_facade.py} +25 -24
- model_compression_toolkit/core/keras/tf_tensor_numpy.py +4 -2
- model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +3 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +6 -11
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +2 -2
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
- model_compression_toolkit/core/pytorch/hessian/activation_trace_hessian_calculator_pytorch.py +3 -7
- model_compression_toolkit/core/pytorch/hessian/trace_hessian_calculator_pytorch.py +1 -2
- model_compression_toolkit/core/pytorch/hessian/weights_trace_hessian_calculator_pytorch.py +2 -2
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +3 -3
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +5 -7
- model_compression_toolkit/core/pytorch/reader/reader.py +2 -2
- model_compression_toolkit/core/pytorch/{kpi_data_facade.py → resource_utilization_data_facade.py} +24 -22
- model_compression_toolkit/core/pytorch/utils.py +3 -2
- model_compression_toolkit/core/runner.py +43 -42
- model_compression_toolkit/data_generation/common/data_generation.py +18 -18
- model_compression_toolkit/data_generation/common/model_info_exctractors.py +1 -1
- model_compression_toolkit/data_generation/keras/keras_data_generation.py +7 -10
- model_compression_toolkit/data_generation/keras/model_info_exctractors.py +2 -1
- model_compression_toolkit/data_generation/keras/optimization_functions/image_initilization.py +2 -1
- model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py +2 -4
- model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py +2 -1
- model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py +8 -11
- model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -3
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -3
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +8 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +7 -8
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +19 -12
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +2 -2
- model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +10 -11
- model_compression_toolkit/gptq/common/gptq_graph.py +3 -3
- model_compression_toolkit/gptq/common/gptq_training.py +14 -12
- model_compression_toolkit/gptq/keras/gptq_training.py +10 -8
- model_compression_toolkit/gptq/keras/graph_info.py +1 -1
- model_compression_toolkit/gptq/keras/quantization_facade.py +15 -17
- model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +4 -5
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +1 -2
- model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -8
- model_compression_toolkit/gptq/pytorch/graph_info.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +11 -13
- model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -4
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +1 -2
- model_compression_toolkit/logger.py +1 -13
- model_compression_toolkit/pruning/keras/pruning_facade.py +11 -12
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +11 -12
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -14
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -8
- model_compression_toolkit/qat/keras/quantization_facade.py +20 -22
- model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -3
- model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +1 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +12 -14
- model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -3
- model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +1 -1
- model_compression_toolkit/target_platform_capabilities/immutable.py +4 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +4 -8
- model_compression_toolkit/target_platform_capabilities/target_platform/current_tp_model.py +1 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/fusing.py +43 -8
- model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +13 -18
- model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +2 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attribute_filter.py +2 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/current_tpc.py +2 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +5 -5
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +1 -2
- model_compression_toolkit/trainable_infrastructure/common/base_trainable_quantizer.py +13 -13
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +14 -7
- model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +5 -5
- model_compression_toolkit/trainable_infrastructure/keras/base_keras_quantizer.py +2 -3
- model_compression_toolkit/trainable_infrastructure/keras/load_model.py +4 -5
- model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py +3 -4
- model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +3 -3
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi.py +0 -112
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_aggregation_methods.py +0 -105
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_functions_mapping.py +0 -26
- {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240323.408.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240323.408.dist-info}/WHEEL +0 -0
- {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240323.408.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/common/mixed_precision/{kpi_tools → resource_utilization_tools}/__init__.py +0 -0
|
@@ -14,37 +14,72 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
from typing import Any
|
|
17
|
+
from typing import Any, List, Union
|
|
18
18
|
|
|
19
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorSetConcat
|
|
19
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorSetConcat, \
|
|
20
|
+
OperatorsSet
|
|
20
21
|
from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import TargetPlatformModelComponent
|
|
21
22
|
|
|
22
23
|
|
|
23
24
|
class Fusing(TargetPlatformModelComponent):
|
|
25
|
+
"""
|
|
26
|
+
Fusing defines a list of operators that should be combined and treated as a single operator,
|
|
27
|
+
hence no quantization is applied between them.
|
|
28
|
+
"""
|
|
24
29
|
|
|
25
|
-
def __init__(self,
|
|
30
|
+
def __init__(self,
|
|
31
|
+
operator_groups_list: List[Union[OperatorsSet, OperatorSetConcat]],
|
|
32
|
+
name: str = None):
|
|
33
|
+
"""
|
|
34
|
+
Args:
|
|
35
|
+
operator_groups_list (List[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups, each being either an OperatorSetConcat or an OperatorsSet.
|
|
36
|
+
name (str): The name for the Fusing instance. If not provided, it's generated from the operator groups' names.
|
|
37
|
+
"""
|
|
26
38
|
assert isinstance(operator_groups_list,
|
|
27
39
|
list), f'List of operator groups should be of type list but is {type(operator_groups_list)}'
|
|
28
40
|
assert len(operator_groups_list) >= 2, f'Fusing can not be created for a single operators group'
|
|
41
|
+
|
|
42
|
+
# Generate a name from the operator groups if no name is provided
|
|
29
43
|
if name is None:
|
|
30
44
|
name = '_'.join([x.name for x in operator_groups_list])
|
|
45
|
+
|
|
31
46
|
super().__init__(name)
|
|
32
47
|
self.operator_groups_list = operator_groups_list
|
|
33
48
|
|
|
34
|
-
def contains(self, other: Any):
|
|
49
|
+
def contains(self, other: Any) -> bool:
|
|
50
|
+
"""
|
|
51
|
+
Determines if the current Fusing instance contains another Fusing instance.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
other: The other Fusing instance to check against.
|
|
55
|
+
|
|
56
|
+
Returns:
|
|
57
|
+
A boolean indicating whether the other instance is contained within this one.
|
|
58
|
+
"""
|
|
35
59
|
if not isinstance(other, Fusing):
|
|
36
60
|
return False
|
|
61
|
+
|
|
62
|
+
# Check for containment by comparing operator groups
|
|
37
63
|
for i in range(len(self.operator_groups_list) - len(other.operator_groups_list) + 1):
|
|
38
64
|
for j in range(len(other.operator_groups_list)):
|
|
39
|
-
if self.operator_groups_list[i + j] != other.operator_groups_list[j] and not (
|
|
65
|
+
if self.operator_groups_list[i + j] != other.operator_groups_list[j] and not (
|
|
66
|
+
isinstance(self.operator_groups_list[i + j], OperatorSetConcat) and (
|
|
67
|
+
other.operator_groups_list[j] in self.operator_groups_list[i + j].op_set_list)):
|
|
40
68
|
break
|
|
41
69
|
else:
|
|
70
|
+
# If all checks pass, the other Fusing instance is contained
|
|
42
71
|
return True
|
|
72
|
+
# Other Fusing instance is not contained
|
|
43
73
|
return False
|
|
44
74
|
|
|
45
|
-
|
|
46
75
|
def get_info(self):
|
|
76
|
+
"""
|
|
77
|
+
Retrieves information about the Fusing instance, including its name and the sequence of operator groups.
|
|
78
|
+
|
|
79
|
+
Returns:
|
|
80
|
+
A dictionary with the Fusing instance's name as the key and the sequence of operator groups as the value,
|
|
81
|
+
or just the sequence of operator groups if no name is set.
|
|
82
|
+
"""
|
|
47
83
|
if self.name is not None:
|
|
48
84
|
return {self.name: ' -> '.join([x.name for x in self.operator_groups_list])}
|
|
49
|
-
return ' -> '.join([x.name for x in self.operator_groups_list])
|
|
50
|
-
|
|
85
|
+
return ' -> '.join([x.name for x in self.operator_groups_list])
|
model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py
CHANGED
|
@@ -124,7 +124,7 @@ class OpQuantizationConfig:
|
|
|
124
124
|
|
|
125
125
|
Args:
|
|
126
126
|
default_weight_attr_config (AttributeQuantizationConfig): A default attribute quantization configuration for the operation.
|
|
127
|
-
attr_weights_configs_mapping (
|
|
127
|
+
attr_weights_configs_mapping (Dict[str, AttributeQuantizationConfig]): A mapping between an op attribute name and its quantization configuration.
|
|
128
128
|
activation_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for activation quantization.
|
|
129
129
|
activation_n_bits (int): Number of bits to quantize the activations.
|
|
130
130
|
enable_activation_quantization (bool): Whether to quantize the model activations or not.
|
|
@@ -215,24 +215,19 @@ class QuantizationConfigOptions(object):
|
|
|
215
215
|
"""
|
|
216
216
|
|
|
217
217
|
assert isinstance(quantization_config_list,
|
|
218
|
-
list), f'QuantizationConfigOptions options list
|
|
219
|
-
|
|
220
|
-
assert len(quantization_config_list) > 0, f'Options list can not be empty'
|
|
218
|
+
list), f'\'QuantizationConfigOptions\' options list must be a list, but received: {type(quantization_config_list)}.'
|
|
219
|
+
assert len(quantization_config_list) > 0, f'Options list can not be empty.'
|
|
221
220
|
for cfg in quantization_config_list:
|
|
222
|
-
assert isinstance(cfg, OpQuantizationConfig), f'
|
|
223
|
-
f'but found an object type: {type(cfg)}'
|
|
221
|
+
assert isinstance(cfg, OpQuantizationConfig), f'Each option must be an instance of \'OpQuantizationConfig\', but found an object of type: {type(cfg)}.'
|
|
224
222
|
self.quantization_config_list = quantization_config_list
|
|
225
223
|
if len(quantization_config_list) > 1:
|
|
226
|
-
assert base_config is not None, f'
|
|
227
|
-
|
|
228
|
-
assert base_config in quantization_config_list, f"base_config must be in the given quantization config " \
|
|
229
|
-
f"list of options"
|
|
224
|
+
assert base_config is not None, f'For multiple configurations, a \'base_config\' is required for non-mixed-precision optimization.'
|
|
225
|
+
assert base_config in quantization_config_list, f"\'base_config\' must be included in the quantization config options list."
|
|
230
226
|
self.base_config = base_config
|
|
231
227
|
elif len(quantization_config_list) == 1:
|
|
232
228
|
self.base_config = quantization_config_list[0]
|
|
233
229
|
else:
|
|
234
|
-
|
|
235
|
-
"defined in its options list, but list is empty")
|
|
230
|
+
Logger.critical("\'QuantizationConfigOptions\' requires at least one \'OpQuantizationConfig\'; the provided list is empty.")
|
|
236
231
|
|
|
237
232
|
def __eq__(self, other):
|
|
238
233
|
"""
|
|
@@ -280,13 +275,13 @@ class QuantizationConfigOptions(object):
|
|
|
280
275
|
attrs_to_update = list(qc.attr_weights_configs_mapping.keys())
|
|
281
276
|
else:
|
|
282
277
|
if not isinstance(attrs, List):
|
|
283
|
-
Logger.
|
|
278
|
+
Logger.critical(f"Expected a list of attributes but received {type(attrs)}.")
|
|
284
279
|
attrs_to_update = attrs
|
|
285
280
|
|
|
286
281
|
for attr in attrs_to_update:
|
|
287
282
|
if qc.attr_weights_configs_mapping.get(attr) is None:
|
|
288
|
-
Logger.
|
|
289
|
-
|
|
283
|
+
Logger.critical(f'Editing attributes is only possible for existing attributes in the configuration\'s '
|
|
284
|
+
f'weights config mapping; {attr} does not exist in {qc}.')
|
|
290
285
|
self.__edit_quantization_configuration(qc.attr_weights_configs_mapping[attr], kwargs)
|
|
291
286
|
return qc_options
|
|
292
287
|
|
|
@@ -312,7 +307,7 @@ class QuantizationConfigOptions(object):
|
|
|
312
307
|
for attr in list(qc.attr_weights_configs_mapping.keys()):
|
|
313
308
|
new_key = layer_attrs_mapping.get(attr)
|
|
314
309
|
if new_key is None:
|
|
315
|
-
Logger.
|
|
310
|
+
Logger.critical(f"Attribute \'{attr}\' does not exist in the provided attribute mapping.")
|
|
316
311
|
|
|
317
312
|
new_attr_mapping[new_key] = qc.attr_weights_configs_mapping.pop(attr)
|
|
318
313
|
|
|
@@ -323,8 +318,8 @@ class QuantizationConfigOptions(object):
|
|
|
323
318
|
def __edit_quantization_configuration(self, qc, kwargs):
|
|
324
319
|
for k, v in kwargs.items():
|
|
325
320
|
assert hasattr(qc,
|
|
326
|
-
k), f'
|
|
327
|
-
|
|
321
|
+
k), (f'Editing is only possible for existing attributes in the configuration; '
|
|
322
|
+
f'{k} is not an attribute of {qc}.')
|
|
328
323
|
setattr(qc, k, v)
|
|
329
324
|
|
|
330
325
|
def get_info(self):
|
model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py
CHANGED
|
@@ -156,7 +156,7 @@ class TargetPlatformModel(ImmutableClass):
|
|
|
156
156
|
elif isinstance(tp_model_component, OperatorsSetBase):
|
|
157
157
|
self.operator_set.append(tp_model_component)
|
|
158
158
|
else:
|
|
159
|
-
|
|
159
|
+
Logger.critical(f'Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.')
|
|
160
160
|
|
|
161
161
|
def __enter__(self):
|
|
162
162
|
"""
|
|
@@ -192,7 +192,7 @@ class TargetPlatformModel(ImmutableClass):
|
|
|
192
192
|
"""
|
|
193
193
|
opsets_names = [op.name for op in self.operator_set]
|
|
194
194
|
if (len(set(opsets_names)) != len(opsets_names)):
|
|
195
|
-
Logger.
|
|
195
|
+
Logger.critical(f'Operator Sets must have unique names.')
|
|
196
196
|
|
|
197
197
|
def get_default_config(self) -> OpQuantizationConfig:
|
|
198
198
|
"""
|
|
@@ -87,7 +87,7 @@ class AttributeFilter(Filter):
|
|
|
87
87
|
"""
|
|
88
88
|
|
|
89
89
|
if not isinstance(other, AttributeFilter):
|
|
90
|
-
Logger.
|
|
90
|
+
Logger.critical("Not an attribute filter. Cannot perform an 'OR' operation.") # pragma: no cover
|
|
91
91
|
return OrAttributeFilter(self, other)
|
|
92
92
|
|
|
93
93
|
def __and__(self, other: Any):
|
|
@@ -101,7 +101,7 @@ class AttributeFilter(Filter):
|
|
|
101
101
|
AndAttributeFilter that filters with AND between the current AttributeFilter and the passed AttributeFilter.
|
|
102
102
|
"""
|
|
103
103
|
if not isinstance(other, AttributeFilter):
|
|
104
|
-
Logger.
|
|
104
|
+
Logger.critical("Not an attribute filter. Can not perform an 'AND' operation.") # pragma: no cover
|
|
105
105
|
return AndAttributeFilter(self, other)
|
|
106
106
|
|
|
107
107
|
def match(self,
|
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
from model_compression_toolkit.logger import Logger
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
def get_current_tpc():
|
|
@@ -38,7 +39,7 @@ class _CurrentTPC(object):
|
|
|
38
39
|
|
|
39
40
|
"""
|
|
40
41
|
if self.tpc is None:
|
|
41
|
-
|
|
42
|
+
Logger.critical("'TargetPlatformCapabilities' (TPC) instance is not initialized.")
|
|
42
43
|
return self.tpc
|
|
43
44
|
|
|
44
45
|
def reset(self):
|
|
@@ -20,6 +20,7 @@ from model_compression_toolkit.target_platform_capabilities.target_platform.targ
|
|
|
20
20
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
|
|
21
21
|
from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorSetConcat, \
|
|
22
22
|
OperatorsSetBase
|
|
23
|
+
from model_compression_toolkit import DefaultDict
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
class OperationsSetToLayers(TargetPlatformCapabilitiesComponent):
|
|
@@ -29,15 +30,14 @@ class OperationsSetToLayers(TargetPlatformCapabilitiesComponent):
|
|
|
29
30
|
def __init__(self,
|
|
30
31
|
op_set_name: str,
|
|
31
32
|
layers: List[Any],
|
|
32
|
-
attr_mapping: Dict[str,
|
|
33
|
+
attr_mapping: Dict[str, DefaultDict] = None):
|
|
33
34
|
"""
|
|
34
35
|
|
|
35
36
|
Args:
|
|
36
37
|
op_set_name (str): Name of OperatorsSet to associate with layers.
|
|
37
38
|
layers (List[Any]): List of layers/FilterLayerParams to associate with OperatorsSet.
|
|
38
|
-
attr_mapping (
|
|
39
|
-
|
|
40
|
-
dependency).
|
|
39
|
+
attr_mapping (Dict[str, DefaultDict]): A mapping between a general attribute name to a DefaultDict that maps a layer type to the layer's framework name of this attribute.
|
|
40
|
+
|
|
41
41
|
"""
|
|
42
42
|
self.layers = layers
|
|
43
43
|
self.attr_mapping = attr_mapping
|
|
@@ -147,7 +147,7 @@ class OperationsToLayers:
|
|
|
147
147
|
for layer in ops2layers.layers:
|
|
148
148
|
qco_by_opset_name = _current_tpc.get().tp_model.get_config_options_by_operators_set(ops2layers.name)
|
|
149
149
|
if layer in existing_layers:
|
|
150
|
-
Logger.
|
|
150
|
+
Logger.critical(f'Found layer {layer.__name__} in more than one '
|
|
151
151
|
f'OperatorsSet') # pragma: no cover
|
|
152
152
|
else:
|
|
153
153
|
existing_layers.update({layer: qco_by_opset_name})
|
|
@@ -139,8 +139,7 @@ class TargetPlatformCapabilities(ImmutableClass):
|
|
|
139
139
|
if isinstance(tpc_component, OperationsSetToLayers):
|
|
140
140
|
self.op_sets_to_layers += tpc_component
|
|
141
141
|
else:
|
|
142
|
-
Logger.
|
|
143
|
-
f'{type(tpc_component)}') # pragma: no cover
|
|
142
|
+
Logger.critical(f"Attempt to append an unrecognized 'TargetPlatformCapabilitiesComponent' of type: '{type(tpc_component)}'. Ensure the component is compatible.") # pragma: no cover
|
|
144
143
|
|
|
145
144
|
def __enter__(self):
|
|
146
145
|
"""
|
|
@@ -55,9 +55,9 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
|
|
|
55
55
|
for i, (k, v) in enumerate(self.get_sig().parameters.items()):
|
|
56
56
|
if i == 0:
|
|
57
57
|
if v.annotation not in [TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]:
|
|
58
|
-
Logger.
|
|
58
|
+
Logger.critical(f"The first parameter must be either 'TrainableQuantizerWeightsConfig' or 'TrainableQuantizerActivationConfig'.") # pragma: no cover
|
|
59
59
|
elif v.default is v.empty:
|
|
60
|
-
Logger.
|
|
60
|
+
Logger.critical(f"Parameter '{k}' lacks a default value.") # pragma: no cover
|
|
61
61
|
|
|
62
62
|
super(BaseTrainableQuantizer, self).__init__()
|
|
63
63
|
self.quantization_config = quantization_config
|
|
@@ -67,22 +67,22 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
|
|
|
67
67
|
static_quantization_target = getattr(self, QUANTIZATION_TARGET, None)
|
|
68
68
|
|
|
69
69
|
if static_quantization_method is None or static_quantization_target is None:
|
|
70
|
-
Logger.
|
|
71
|
-
|
|
70
|
+
Logger.critical("Quantizer class inheriting from 'BaseTrainableQuantizer' is improperly defined. "
|
|
71
|
+
"Ensure it includes the '@mark_quantizer' decorator and is correctly applied.")
|
|
72
72
|
|
|
73
73
|
if static_quantization_target == QuantizationTarget.Weights:
|
|
74
74
|
self.validate_weights()
|
|
75
75
|
if self.quantization_config.weights_quantization_method not in static_quantization_method:
|
|
76
|
-
Logger.
|
|
77
|
-
f
|
|
76
|
+
Logger.critical(
|
|
77
|
+
f"Quantization method mismatch. Expected methods: {static_quantization_method}, received: {self.quantization_config.weights_quantization_method}.")
|
|
78
78
|
elif static_quantization_target == QuantizationTarget.Activation:
|
|
79
79
|
self.validate_activation()
|
|
80
80
|
if self.quantization_config.activation_quantization_method not in static_quantization_method:
|
|
81
|
-
Logger.
|
|
82
|
-
f
|
|
81
|
+
Logger.critical(
|
|
82
|
+
f"Quantization method mismatch. Expected methods: {static_quantization_method}, received: {self.quantization_config.activation_quantization_method}.")
|
|
83
83
|
else:
|
|
84
|
-
Logger.
|
|
85
|
-
f'
|
|
84
|
+
Logger.critical(
|
|
85
|
+
f"Unrecognized 'QuantizationTarget': {static_quantization_target}.") # pragma: no cover
|
|
86
86
|
|
|
87
87
|
self.quantizer_parameters = {}
|
|
88
88
|
|
|
@@ -145,7 +145,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
|
|
|
145
145
|
|
|
146
146
|
"""
|
|
147
147
|
if self.activation_quantization() or not self.weights_quantization():
|
|
148
|
-
Logger.
|
|
148
|
+
Logger.critical(f'Expected weight quantization configuration; received activation quantization instead.')
|
|
149
149
|
|
|
150
150
|
def validate_activation(self) -> None:
|
|
151
151
|
"""
|
|
@@ -153,7 +153,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
|
|
|
153
153
|
|
|
154
154
|
"""
|
|
155
155
|
if not self.activation_quantization() or self.weights_quantization():
|
|
156
|
-
Logger.
|
|
156
|
+
Logger.critical(f'Expected activation quantization configuration; received weight quantization instead.')
|
|
157
157
|
|
|
158
158
|
def convert2inferable(self) -> BaseInferableQuantizer:
|
|
159
159
|
"""
|
|
@@ -183,7 +183,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
|
|
|
183
183
|
if name in self.quantizer_parameters:
|
|
184
184
|
return self.quantizer_parameters[name][VAR]
|
|
185
185
|
else:
|
|
186
|
-
Logger.
|
|
186
|
+
Logger.critical(f"Variable '{name}' does not exist in quantizer parameters.") # pragma: no cover
|
|
187
187
|
|
|
188
188
|
|
|
189
189
|
@abstractmethod
|
|
@@ -36,7 +36,9 @@ def get_trainable_quantizer_weights_config(
|
|
|
36
36
|
TrainableQuantizerWeightsConfig: an object that contains the quantizer configuration
|
|
37
37
|
"""
|
|
38
38
|
if n.final_weights_quantization_cfg is None:
|
|
39
|
-
Logger.
|
|
39
|
+
Logger.critical(
|
|
40
|
+
"The node requires a 'final_weights_quantization_cfg' configuration to build a "
|
|
41
|
+
"quantizer. Please ensure this configuration is set for the node.")# pragma: no cover
|
|
40
42
|
|
|
41
43
|
final_node_cfg = n.final_weights_quantization_cfg
|
|
42
44
|
final_attr_cfg = final_node_cfg.get_attr_config(attr_name)
|
|
@@ -65,7 +67,9 @@ def get_trainable_quantizer_activation_config(
|
|
|
65
67
|
TrainableQuantizerActivationConfig - an object that contains the quantizer configuration
|
|
66
68
|
"""
|
|
67
69
|
if n.final_activation_quantization_cfg is None:
|
|
68
|
-
Logger.
|
|
70
|
+
Logger.critical(
|
|
71
|
+
"The node requires a 'final_activation_quantization_cfg' configuration to build a "
|
|
72
|
+
"quantizer. Please ensure this configuration is set for the node.")# pragma: no cover
|
|
69
73
|
|
|
70
74
|
final_cfg = n.final_activation_quantization_cfg
|
|
71
75
|
return TrainableQuantizerActivationConfig(final_cfg.activation_quantization_method,
|
|
@@ -93,17 +97,20 @@ def get_trainable_quantizer_quantization_candidates(n: BaseNode, attr: str = Non
|
|
|
93
97
|
if attr is not None:
|
|
94
98
|
# all candidates must have the same weights quantization method
|
|
95
99
|
weights_quantization_methods = set([cfg.weights_quantization_cfg.get_attr_config(attr).weights_quantization_method
|
|
96
|
-
|
|
100
|
+
for cfg in n.candidates_quantization_cfg])
|
|
97
101
|
if len(weights_quantization_methods) > 1:
|
|
98
|
-
Logger.
|
|
99
|
-
|
|
102
|
+
Logger.critical(f"Invalid 'candidates_quantization_cfg': Inconsistent weights "
|
|
103
|
+
f"quantization methods detected: {weights_quantization_methods}. "
|
|
104
|
+
f"Trainable quantizer requires all candidates to have the same weights "
|
|
105
|
+
f"quantization method.") # pragma: no cover
|
|
100
106
|
|
|
101
107
|
# all candidates must have the same activation quantization method
|
|
102
108
|
activation_quantization_methods = set([cfg.activation_quantization_cfg.activation_quantization_method
|
|
103
109
|
for cfg in n.candidates_quantization_cfg])
|
|
104
110
|
if len(activation_quantization_methods) > 1:
|
|
105
|
-
Logger.
|
|
106
|
-
|
|
111
|
+
Logger.critical(f"Invalid 'candidates_quantization_cfg': Inconsistent activation quantization "
|
|
112
|
+
f"methods detected: {activation_quantization_methods}. "
|
|
113
|
+
f"Trainable quantizer requires all candidates to have the same activation quantization method.")# pragma: no cover
|
|
107
114
|
|
|
108
115
|
# get unique lists of candidates
|
|
109
116
|
unique_weights_candidates = n.get_unique_weights_candidates(attr)
|
|
@@ -44,7 +44,7 @@ def get_trainable_quantizer_class(quant_target: QuantizationTarget,
|
|
|
44
44
|
"""
|
|
45
45
|
qat_quantizer_classes = get_all_subclasses(quantizer_base_class)
|
|
46
46
|
if len(qat_quantizer_classes) == 0:
|
|
47
|
-
Logger.
|
|
47
|
+
Logger.critical(f"No quantizer classes inherited from {quantizer_base_class} were detected.") # pragma: no cover
|
|
48
48
|
|
|
49
49
|
filtered_quantizers = list(filter(lambda q_class: getattr(q_class, QUANTIZATION_TARGET, None) is not None and
|
|
50
50
|
getattr(q_class, QUANTIZATION_TARGET) == quant_target and
|
|
@@ -54,9 +54,9 @@ def get_trainable_quantizer_class(quant_target: QuantizationTarget,
|
|
|
54
54
|
qat_quantizer_classes))
|
|
55
55
|
|
|
56
56
|
if len(filtered_quantizers) != 1:
|
|
57
|
-
Logger.
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
57
|
+
Logger.critical(f"Found {len(filtered_quantizers)} quantizers for target {quant_target.value}, "
|
|
58
|
+
f"matching the requested quantization method {quant_method.name} and "
|
|
59
|
+
f"quantizer type {quantizer_id.value}, but exactly one is required. "
|
|
60
|
+
f"Identified quantizers: {filtered_quantizers}.")
|
|
61
61
|
|
|
62
62
|
return filtered_quantizers[0]
|
|
@@ -86,6 +86,5 @@ else:
|
|
|
86
86
|
quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
|
|
87
87
|
|
|
88
88
|
super().__init__(quantization_config)
|
|
89
|
-
Logger.critical(
|
|
90
|
-
'
|
|
91
|
-
'Could not find Tensorflow package.') # pragma: no cover
|
|
89
|
+
Logger.critical("Tensorflow must be installed to use BaseKerasTrainableQuantizer. "
|
|
90
|
+
"The 'tensorflow' package is missing.") # pragma: no cover
|
|
@@ -48,8 +48,8 @@ if FOUND_TF:
|
|
|
48
48
|
KerasTrainableQuantizationWrapper.__name__: KerasTrainableQuantizationWrapper})
|
|
49
49
|
all_trainable_names = list(qi_trainable_custom_objects.keys())
|
|
50
50
|
if len(set(all_trainable_names)) < len(all_trainable_names):
|
|
51
|
-
Logger.
|
|
52
|
-
|
|
51
|
+
Logger.critical("Found multiple quantizers with identical names inheriting from "
|
|
52
|
+
"'BaseKerasTrainableQuantizer' while trying to load a model.")
|
|
53
53
|
|
|
54
54
|
qi_custom_objects = {**qi_trainable_custom_objects}
|
|
55
55
|
|
|
@@ -72,6 +72,5 @@ else:
|
|
|
72
72
|
Returns: A keras Model
|
|
73
73
|
|
|
74
74
|
"""
|
|
75
|
-
Logger.critical(
|
|
76
|
-
'
|
|
77
|
-
'Could not find Tensorflow package.') # pragma: no cover
|
|
75
|
+
Logger.critical("Tensorflow must be installed to use keras_load_quantized_model. "
|
|
76
|
+
"The 'tensorflow' package is missing.") # pragma: no cover
|
|
@@ -91,7 +91,7 @@ if FOUND_TF:
|
|
|
91
91
|
layer_weights_list[weight_keys.index(_weight_name(w.name))] = w
|
|
92
92
|
# Verify all the weights in the list are ready. The "set_weights" method expects all the layer's weights
|
|
93
93
|
if not all(w is not None for w in layer_weights_list):
|
|
94
|
-
Logger.
|
|
94
|
+
Logger.critical(f"Not all weights are set for layer '{self.layer.name}'")
|
|
95
95
|
assert all(w is not None for w in layer_weights_list)
|
|
96
96
|
inferable_quantizers_wrapper.set_weights(layer_weights_list)
|
|
97
97
|
|
|
@@ -110,6 +110,5 @@ else:
|
|
|
110
110
|
layer: A keras layer.
|
|
111
111
|
weights_quantizers: A dictionary between a weight's name to its quantizer.
|
|
112
112
|
"""
|
|
113
|
-
Logger.critical(
|
|
114
|
-
'
|
|
115
|
-
'Could not find Tensorflow package.') # pragma: no cover
|
|
113
|
+
Logger.critical("Tensorflow must be installed to use KerasTrainableQuantizationWrapper. "
|
|
114
|
+
"The 'tensorflow' package is missing.") # pragma: no cover
|
|
@@ -60,6 +60,6 @@ else:
|
|
|
60
60
|
def __init__(self,
|
|
61
61
|
quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
|
|
62
62
|
super().__init__(quantization_config)
|
|
63
|
-
Logger.critical(
|
|
64
|
-
'
|
|
65
|
-
|
|
63
|
+
Logger.critical("PyTorch must be installed to use 'BasePytorchTrainableQuantizer'. "
|
|
64
|
+
"The 'torch' package is missing.") # pragma: no cover
|
|
65
|
+
|
|
@@ -1,112 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
from enum import Enum
|
|
16
|
-
from typing import Dict, Any
|
|
17
|
-
|
|
18
|
-
import numpy as np
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class KPITarget(Enum):
|
|
22
|
-
"""
|
|
23
|
-
Targets for which we define KPIs metrics for mixed-precision search.
|
|
24
|
-
For each target that we care to consider in a mixed-precision search, there should be defined a set of
|
|
25
|
-
kpi computation function, kpi aggregation function, and kpi target (within a KPI object).
|
|
26
|
-
|
|
27
|
-
Whenever adding a kpi metric to KPI class we should add a matching target to this enum.
|
|
28
|
-
|
|
29
|
-
WEIGHTS - Weights memory KPI metric.
|
|
30
|
-
|
|
31
|
-
ACTIVATION - Activation memory KPI metric.
|
|
32
|
-
|
|
33
|
-
TOTAL - Total memory KPI metric.
|
|
34
|
-
|
|
35
|
-
BOPS - Total Bit-Operations KPI Metric.
|
|
36
|
-
|
|
37
|
-
"""
|
|
38
|
-
|
|
39
|
-
WEIGHTS = 'weights'
|
|
40
|
-
ACTIVATION = 'activation'
|
|
41
|
-
TOTAL = 'total'
|
|
42
|
-
BOPS = 'bops'
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
class KPI:
|
|
46
|
-
"""
|
|
47
|
-
Class to represent measurements of performance.
|
|
48
|
-
"""
|
|
49
|
-
|
|
50
|
-
def __init__(self,
|
|
51
|
-
weights_memory: float = np.inf,
|
|
52
|
-
activation_memory: float = np.inf,
|
|
53
|
-
total_memory: float = np.inf,
|
|
54
|
-
bops: float = np.inf):
|
|
55
|
-
"""
|
|
56
|
-
|
|
57
|
-
Args:
|
|
58
|
-
weights_memory: Memory of a model's weights in bytes. Note that this includes only coefficients that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value, while the bias will not).
|
|
59
|
-
activation_memory: Memory of a model's activation in bytes, according to the given activation kpi metric.
|
|
60
|
-
total_memory: The sum of model's activation and weights memory in bytes, according to the given total kpi metric.
|
|
61
|
-
bops: The total bit-operations in the model.
|
|
62
|
-
"""
|
|
63
|
-
self.weights_memory = weights_memory
|
|
64
|
-
self.activation_memory = activation_memory
|
|
65
|
-
self.total_memory = total_memory
|
|
66
|
-
self.bops = bops
|
|
67
|
-
|
|
68
|
-
def __repr__(self):
|
|
69
|
-
return f"Weights_memory: {self.weights_memory}, " \
|
|
70
|
-
f"Activation_memory: {self.activation_memory}, " \
|
|
71
|
-
f"Total_memory: {self.total_memory}, " \
|
|
72
|
-
f"BOPS: {self.bops}"
|
|
73
|
-
|
|
74
|
-
def get_kpi_dict(self) -> Dict[KPITarget, float]:
|
|
75
|
-
"""
|
|
76
|
-
Returns: a dictionary with the KPI object's values for each KPI target.
|
|
77
|
-
"""
|
|
78
|
-
return {KPITarget.WEIGHTS: self.weights_memory,
|
|
79
|
-
KPITarget.ACTIVATION: self.activation_memory,
|
|
80
|
-
KPITarget.TOTAL: self.total_memory,
|
|
81
|
-
KPITarget.BOPS: self.bops}
|
|
82
|
-
|
|
83
|
-
def set_kpi_by_target(self, kpis_mapping: Dict[KPITarget, float]):
|
|
84
|
-
"""
|
|
85
|
-
Setting a KPI object values for each KPI target in the given dictionary.
|
|
86
|
-
|
|
87
|
-
Args:
|
|
88
|
-
kpis_mapping: A mapping from a KPITarget to a matching KPI value.
|
|
89
|
-
|
|
90
|
-
"""
|
|
91
|
-
self.weights_memory = kpis_mapping.get(KPITarget.WEIGHTS, np.inf)
|
|
92
|
-
self.activation_memory = kpis_mapping.get(KPITarget.ACTIVATION, np.inf)
|
|
93
|
-
self.total_memory = kpis_mapping.get(KPITarget.TOTAL, np.inf)
|
|
94
|
-
self.bops = kpis_mapping.get(KPITarget.BOPS, np.inf)
|
|
95
|
-
|
|
96
|
-
def holds_constraints(self, kpi: Any) -> bool:
|
|
97
|
-
"""
|
|
98
|
-
Checks whether the given KPI holds a set of KPI constraints defined by the currect KPI object.
|
|
99
|
-
|
|
100
|
-
Args:
|
|
101
|
-
kpi: A KPI object to check if it holds the constraints.
|
|
102
|
-
|
|
103
|
-
Returns: True if all the given KPI values are not greater than the referenced KPI values.
|
|
104
|
-
|
|
105
|
-
"""
|
|
106
|
-
if not isinstance(kpi, KPI):
|
|
107
|
-
return False
|
|
108
|
-
|
|
109
|
-
return kpi.weights_memory <= self.weights_memory and \
|
|
110
|
-
kpi.activation_memory <= self.activation_memory and \
|
|
111
|
-
kpi.total_memory <= self.total_memory and \
|
|
112
|
-
kpi.bops <= self.bops
|