mct-nightly 1.11.0.20240321.357-py3-none-any.whl → 1.11.0.20240323.408-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240323.408.dist-info}/METADATA +17 -9
  2. {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240323.408.dist-info}/RECORD +152 -152
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/constants.py +1 -1
  5. model_compression_toolkit/core/__init__.py +3 -3
  6. model_compression_toolkit/core/common/collectors/base_collector.py +2 -2
  7. model_compression_toolkit/core/common/data_loader.py +3 -3
  8. model_compression_toolkit/core/common/graph/base_graph.py +10 -13
  9. model_compression_toolkit/core/common/graph/base_node.py +3 -3
  10. model_compression_toolkit/core/common/graph/edge.py +2 -1
  11. model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py +2 -4
  12. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +2 -2
  13. model_compression_toolkit/core/common/hessian/hessian_info_service.py +2 -3
  14. model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py +3 -5
  15. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +1 -2
  16. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +24 -23
  17. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +110 -112
  18. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +114 -0
  19. model_compression_toolkit/core/common/mixed_precision/{kpi_tools/kpi_data.py → resource_utilization_tools/resource_utilization_data.py} +19 -19
  20. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_aggregation_methods.py +105 -0
  21. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +26 -0
  22. model_compression_toolkit/core/common/mixed_precision/{kpi_tools/kpi_methods.py → resource_utilization_tools/ru_methods.py} +61 -61
  23. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +75 -71
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -4
  25. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +34 -34
  26. model_compression_toolkit/core/common/model_collector.py +2 -2
  27. model_compression_toolkit/core/common/network_editors/actions.py +3 -3
  28. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +12 -12
  29. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +2 -2
  30. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +2 -2
  31. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -2
  32. model_compression_toolkit/core/common/pruning/memory_calculator.py +7 -7
  33. model_compression_toolkit/core/common/pruning/prune_graph.py +2 -3
  34. model_compression_toolkit/core/common/pruning/pruner.py +7 -7
  35. model_compression_toolkit/core/common/pruning/pruning_config.py +1 -1
  36. model_compression_toolkit/core/common/pruning/pruning_info.py +2 -2
  37. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +7 -4
  38. model_compression_toolkit/core/common/quantization/node_quantization_config.py +3 -1
  39. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +4 -2
  40. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +4 -6
  41. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +2 -4
  42. model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py +1 -1
  43. model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py +8 -6
  44. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +2 -2
  45. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +4 -6
  46. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +4 -7
  47. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +3 -3
  48. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -1
  49. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +3 -3
  50. model_compression_toolkit/core/common/user_info.py +1 -1
  51. model_compression_toolkit/core/keras/back2framework/factory_model_builder.py +3 -3
  52. model_compression_toolkit/core/keras/back2framework/instance_builder.py +2 -2
  53. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +4 -8
  54. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -2
  55. model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +2 -2
  56. model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py +1 -1
  57. model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py +1 -1
  58. model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  59. model_compression_toolkit/core/keras/hessian/activation_trace_hessian_calculator_keras.py +3 -3
  60. model_compression_toolkit/core/keras/hessian/trace_hessian_calculator_keras.py +1 -2
  61. model_compression_toolkit/core/keras/hessian/weights_trace_hessian_calculator_keras.py +5 -6
  62. model_compression_toolkit/core/keras/keras_implementation.py +1 -1
  63. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -1
  64. model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py +2 -4
  65. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +1 -1
  66. model_compression_toolkit/core/keras/quantizer/fake_quant_builder.py +7 -7
  67. model_compression_toolkit/core/keras/reader/common.py +2 -2
  68. model_compression_toolkit/core/keras/reader/node_builder.py +1 -1
  69. model_compression_toolkit/core/keras/{kpi_data_facade.py → resource_utilization_data_facade.py} +25 -24
  70. model_compression_toolkit/core/keras/tf_tensor_numpy.py +4 -2
  71. model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py +3 -3
  72. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +6 -11
  73. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +2 -2
  74. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py +1 -1
  75. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py +1 -1
  76. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +5 -5
  77. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +1 -1
  78. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +1 -1
  79. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py +1 -1
  80. model_compression_toolkit/core/pytorch/hessian/activation_trace_hessian_calculator_pytorch.py +3 -7
  81. model_compression_toolkit/core/pytorch/hessian/trace_hessian_calculator_pytorch.py +1 -2
  82. model_compression_toolkit/core/pytorch/hessian/weights_trace_hessian_calculator_pytorch.py +2 -2
  83. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -1
  84. model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py +1 -2
  85. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +3 -3
  86. model_compression_toolkit/core/pytorch/pytorch_implementation.py +1 -1
  87. model_compression_toolkit/core/pytorch/reader/graph_builders.py +5 -7
  88. model_compression_toolkit/core/pytorch/reader/reader.py +2 -2
  89. model_compression_toolkit/core/pytorch/{kpi_data_facade.py → resource_utilization_data_facade.py} +24 -22
  90. model_compression_toolkit/core/pytorch/utils.py +3 -2
  91. model_compression_toolkit/core/runner.py +43 -42
  92. model_compression_toolkit/data_generation/common/data_generation.py +18 -18
  93. model_compression_toolkit/data_generation/common/model_info_exctractors.py +1 -1
  94. model_compression_toolkit/data_generation/keras/keras_data_generation.py +7 -10
  95. model_compression_toolkit/data_generation/keras/model_info_exctractors.py +2 -1
  96. model_compression_toolkit/data_generation/keras/optimization_functions/image_initilization.py +2 -1
  97. model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py +2 -4
  98. model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py +2 -1
  99. model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py +8 -11
  100. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +1 -1
  101. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -3
  102. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -3
  103. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +8 -4
  104. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +2 -2
  105. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +7 -8
  106. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +19 -12
  107. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +2 -2
  108. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +10 -11
  109. model_compression_toolkit/gptq/common/gptq_graph.py +3 -3
  110. model_compression_toolkit/gptq/common/gptq_training.py +14 -12
  111. model_compression_toolkit/gptq/keras/gptq_training.py +10 -8
  112. model_compression_toolkit/gptq/keras/graph_info.py +1 -1
  113. model_compression_toolkit/gptq/keras/quantization_facade.py +15 -17
  114. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +4 -5
  115. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +1 -2
  116. model_compression_toolkit/gptq/pytorch/gptq_training.py +10 -8
  117. model_compression_toolkit/gptq/pytorch/graph_info.py +1 -1
  118. model_compression_toolkit/gptq/pytorch/quantization_facade.py +11 -13
  119. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +3 -4
  120. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +1 -2
  121. model_compression_toolkit/logger.py +1 -13
  122. model_compression_toolkit/pruning/keras/pruning_facade.py +11 -12
  123. model_compression_toolkit/pruning/pytorch/pruning_facade.py +11 -12
  124. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -14
  125. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -8
  126. model_compression_toolkit/qat/keras/quantization_facade.py +20 -22
  127. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +2 -3
  128. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +1 -1
  129. model_compression_toolkit/qat/pytorch/quantization_facade.py +12 -14
  130. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -3
  131. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/immutable.py +4 -2
  133. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +4 -8
  134. model_compression_toolkit/target_platform_capabilities/target_platform/current_tp_model.py +1 -1
  135. model_compression_toolkit/target_platform_capabilities/target_platform/fusing.py +43 -8
  136. model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +13 -18
  137. model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +2 -2
  138. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attribute_filter.py +2 -2
  139. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/current_tpc.py +2 -1
  140. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +5 -5
  141. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +1 -2
  142. model_compression_toolkit/trainable_infrastructure/common/base_trainable_quantizer.py +13 -13
  143. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +14 -7
  144. model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py +5 -5
  145. model_compression_toolkit/trainable_infrastructure/keras/base_keras_quantizer.py +2 -3
  146. model_compression_toolkit/trainable_infrastructure/keras/load_model.py +4 -5
  147. model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py +3 -4
  148. model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +3 -3
  149. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi.py +0 -112
  150. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_aggregation_methods.py +0 -105
  151. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_functions_mapping.py +0 -26
  152. {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240323.408.dist-info}/LICENSE.md +0 -0
  153. {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240323.408.dist-info}/WHEEL +0 -0
  154. {mct_nightly-1.11.0.20240321.357.dist-info → mct_nightly-1.11.0.20240323.408.dist-info}/top_level.txt +0 -0
  155. /model_compression_toolkit/core/common/mixed_precision/{kpi_tools → resource_utilization_tools}/__init__.py +0 -0
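The headline change in this nightly is the rename of the mixed-precision "KPI" machinery to "resource utilization": kpi_tools becomes resource_utilization_tools, kpi.py is replaced by resource_utilization.py, and the per-framework kpi_data_facade.py files become resource_utilization_data_facade.py. Alongside it, error reporting across the package moves from Logger.error and bare raise Exception to Logger.critical. A minimal sketch of how the rename may surface to user code follows; the class name ResourceUtilization is inferred from the new file name and is not confirmed by this diff.

# Old import path (1.11.0.20240321.357):
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
target_kpi = KPI(weights_memory=2 * 1024 ** 2)

# Assumed new import path (1.11.0.20240323.408); 'ResourceUtilization' is an
# inferred name, mirroring the deleted KPI class shown at the end of this diff:
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
    ResourceUtilization
target_ru = ResourceUtilization(weights_memory=2 * 1024 ** 2)

The selected hunks below are reproduced with the diff viewer's line-number gutter stripped.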
model_compression_toolkit/target_platform_capabilities/target_platform/fusing.py
@@ -14,37 +14,72 @@
 # ==============================================================================
 
 
-from typing import Any
+from typing import Any, List, Union
 
-from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorSetConcat
+from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorSetConcat, \
+    OperatorsSet
 from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import TargetPlatformModelComponent
 
 
 class Fusing(TargetPlatformModelComponent):
+    """
+    Fusing defines a list of operators that should be combined and treated as a single operator,
+    hence no quantization is applied between them.
+    """
 
-    def __init__(self, operator_groups_list, name=None):
+    def __init__(self,
+                 operator_groups_list: List[Union[OperatorsSet, OperatorSetConcat]],
+                 name: str = None):
+        """
+        Args:
+            operator_groups_list (List[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups, each being either an OperatorSetConcat or an OperatorsSet.
+            name (str): The name for the Fusing instance. If not provided, it's generated from the operator groups' names.
+        """
         assert isinstance(operator_groups_list,
                           list), f'List of operator groups should be of type list but is {type(operator_groups_list)}'
         assert len(operator_groups_list) >= 2, f'Fusing can not be created for a single operators group'
+
+        # Generate a name from the operator groups if no name is provided
         if name is None:
             name = '_'.join([x.name for x in operator_groups_list])
+
         super().__init__(name)
         self.operator_groups_list = operator_groups_list
 
-    def contains(self, other: Any):
+    def contains(self, other: Any) -> bool:
+        """
+        Determines if the current Fusing instance contains another Fusing instance.
+
+        Args:
+            other: The other Fusing instance to check against.
+
+        Returns:
+            A boolean indicating whether the other instance is contained within this one.
+        """
         if not isinstance(other, Fusing):
             return False
+
+        # Check for containment by comparing operator groups
         for i in range(len(self.operator_groups_list) - len(other.operator_groups_list) + 1):
             for j in range(len(other.operator_groups_list)):
-                if self.operator_groups_list[i + j] != other.operator_groups_list[j] and not (isinstance(self.operator_groups_list[i + j], OperatorSetConcat) and (other.operator_groups_list[j] in self.operator_groups_list[i + j].op_set_list)):
+                if self.operator_groups_list[i + j] != other.operator_groups_list[j] and not (
+                        isinstance(self.operator_groups_list[i + j], OperatorSetConcat) and (
+                        other.operator_groups_list[j] in self.operator_groups_list[i + j].op_set_list)):
                     break
             else:
+                # If all checks pass, the other Fusing instance is contained
                 return True
+        # Other Fusing instance is not contained
         return False
 
-
     def get_info(self):
+        """
+        Retrieves information about the Fusing instance, including its name and the sequence of operator groups.
+
+        Returns:
+            A dictionary with the Fusing instance's name as the key and the sequence of operator groups as the value,
+            or just the sequence of operator groups if no name is set.
+        """
         if self.name is not None:
             return {self.name: ' -> '.join([x.name for x in self.operator_groups_list])}
-        return ' -> '.join([x.name for x in self.operator_groups_list])
-
+        return ' -> '.join([x.name for x in self.operator_groups_list])
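A usage sketch for the newly documented Fusing API. It assumes these components are declared inside an active TargetPlatformModel context (TargetPlatformModelComponent registers itself with the current model), that Fusing/OperatorsSet/TargetPlatformModel are exposed under mct.target_platform, and that default_qco is a previously built QuantizationConfigOptions; the OperatorsSet constructor itself is not part of this diff.

import model_compression_toolkit as mct

tp = mct.target_platform
with tp.TargetPlatformModel(default_qco) as tp_model:  # default_qco: hypothetical prepared QuantizationConfigOptions
    conv = tp.OperatorsSet('Conv')
    relu = tp.OperatorsSet('ReLU')
    add = tp.OperatorsSet('Add')

    # Treat Conv -> ReLU -> Add as one operator: no quantization between them.
    conv_relu_add = tp.Fusing([conv, relu, add])      # name defaults to 'Conv_ReLU_Add'
    conv_relu = tp.Fusing([conv, relu], name='conv_act')

    # contains() is a contiguous sub-sequence check over the operator groups:
    assert conv_relu_add.contains(conv_relu)          # Conv -> ReLU appears inside the longer pattern
    assert not conv_relu.contains(conv_relu_add)
    print(conv_relu_add.get_info())                   # {'Conv_ReLU_Add': 'Conv -> ReLU -> Add'}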
model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py
@@ -124,7 +124,7 @@ class OpQuantizationConfig:
 
         Args:
             default_weight_attr_config (AttributeQuantizationConfig): A default attribute quantization configuration for the operation.
-            attr_weights_configs_mapping (dict): A mapping between an op attribute name and its quantization configuration.
+            attr_weights_configs_mapping (Dict[str, AttributeQuantizationConfig]): A mapping between an op attribute name and its quantization configuration.
             activation_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for activation quantization.
             activation_n_bits (int): Number of bits to quantize the activations.
             enable_activation_quantization (bool): Whether to quantize the model activations or not.
@@ -215,24 +215,19 @@ class QuantizationConfigOptions(object):
         """
 
         assert isinstance(quantization_config_list,
-                          list), f'QuantizationConfigOptions options list should be of type list, but is: ' \
-                                 f'{type(quantization_config_list)}'
-        assert len(quantization_config_list) > 0, f'Options list can not be empty'
+                          list), f'\'QuantizationConfigOptions\' options list must be a list, but received: {type(quantization_config_list)}.'
+        assert len(quantization_config_list) > 0, f'Options list can not be empty.'
         for cfg in quantization_config_list:
-            assert isinstance(cfg, OpQuantizationConfig), f'Options should be a list of QuantizationConfig objects, ' \
-                                                          f'but found an object type: {type(cfg)}'
+            assert isinstance(cfg, OpQuantizationConfig), f'Each option must be an instance of \'OpQuantizationConfig\', but found an object of type: {type(cfg)}.'
         self.quantization_config_list = quantization_config_list
         if len(quantization_config_list) > 1:
-            assert base_config is not None, f'When quantization config options contains more than one configuration, ' \
-                                            f'a base_config must be passed for non-mixed-precision optimization process'
-            assert base_config in quantization_config_list, f"base_config must be in the given quantization config " \
-                                                            f"list of options"
+            assert base_config is not None, f'For multiple configurations, a \'base_config\' is required for non-mixed-precision optimization.'
+            assert base_config in quantization_config_list, f"\'base_config\' must be included in the quantization config options list."
             self.base_config = base_config
         elif len(quantization_config_list) == 1:
             self.base_config = quantization_config_list[0]
         else:
-            raise Exception("QuantizationConfigOptions must have at least one OpQuantizationConfig "
-                            "defined in its options list, but list is empty")
+            Logger.critical("\'QuantizationConfigOptions\' requires at least one \'OpQuantizationConfig\'; the provided list is empty.")
 
     def __eq__(self, other):
         """
@@ -280,13 +275,13 @@ class QuantizationConfigOptions(object):
             attrs_to_update = list(qc.attr_weights_configs_mapping.keys())
         else:
             if not isinstance(attrs, List):
-                Logger.error(f"Expecting a list of attribute but got {type(attrs)}.")
+                Logger.critical(f"Expected a list of attributes but received {type(attrs)}.")
             attrs_to_update = attrs
 
         for attr in attrs_to_update:
             if qc.attr_weights_configs_mapping.get(attr) is None:
-                Logger.error(f'Edit attributes is possible only for existing attributes '
-                             f'in the configuration weights config mapping, but {attr} is not an attribute of {qc}.')
+                Logger.critical(f'Editing attributes is only possible for existing attributes in the configuration\'s '
+                                f'weights config mapping; {attr} does not exist in {qc}.')
             self.__edit_quantization_configuration(qc.attr_weights_configs_mapping[attr], kwargs)
         return qc_options
 
@@ -312,7 +307,7 @@ class QuantizationConfigOptions(object):
         for attr in list(qc.attr_weights_configs_mapping.keys()):
             new_key = layer_attrs_mapping.get(attr)
             if new_key is None:
-                Logger.error(f"Attribute {attr} does not exist in the given attribute mapping.")
+                Logger.critical(f"Attribute \'{attr}\' does not exist in the provided attribute mapping.")
 
             new_attr_mapping[new_key] = qc.attr_weights_configs_mapping.pop(attr)
 
@@ -323,8 +318,8 @@ class QuantizationConfigOptions(object):
     def __edit_quantization_configuration(self, qc, kwargs):
         for k, v in kwargs.items():
             assert hasattr(qc,
-                           k), f'Edit attributes is possible only for existing attributes in configuration, ' \
-                               f'but {k} is not an attribute of {qc}'
+                           k), (f'Editing is only possible for existing attributes in the configuration; '
+                                f'{k} is not an attribute of {qc}.')
             setattr(qc, k, v)
 
     def get_info(self):
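The tightened assertions encode a small contract, sketched below with hypothetical eight_bit/four_bit OpQuantizationConfig instances (their construction is outside this diff) and assuming QuantizationConfigOptions is exposed under mct.target_platform. Note that the empty-list branch now reports through Logger.critical, which in MCT's logger logs the message and then raises, so the failure mode remains an exception.

import model_compression_toolkit as mct

tp = mct.target_platform
# eight_bit, four_bit: hypothetical, previously built OpQuantizationConfig instances.
single = tp.QuantizationConfigOptions([eight_bit])
# One option: base_config is inferred (single.base_config is eight_bit).

mixed = tp.QuantizationConfigOptions([eight_bit, four_bit], base_config=eight_bit)
# Several options (mixed precision): base_config is mandatory, must be one of the
# listed options, and serves the non-mixed-precision optimization path.

tp.QuantizationConfigOptions([])  # now fails via Logger.critical instead of a bare raise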
model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py
@@ -156,7 +156,7 @@ class TargetPlatformModel(ImmutableClass):
         elif isinstance(tp_model_component, OperatorsSetBase):
             self.operator_set.append(tp_model_component)
         else:
-            raise Exception(f'Trying to append an unfamiliar TargetPlatformModelComponent of type: {type(tp_model_component)}')
+            Logger.critical(f'Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.')
 
     def __enter__(self):
         """
@@ -192,7 +192,7 @@ class TargetPlatformModel(ImmutableClass):
         """
         opsets_names = [op.name for op in self.operator_set]
         if (len(set(opsets_names)) != len(opsets_names)):
-            Logger.error(f'OperatorsSet must have unique names')
+            Logger.critical(f'Operator Sets must have unique names.')
 
     def get_default_config(self) -> OpQuantizationConfig:
         """
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attribute_filter.py
@@ -87,7 +87,7 @@ class AttributeFilter(Filter):
         """
 
         if not isinstance(other, AttributeFilter):
-            Logger.error("Not an attribute filter. Can not run an OR operation.") # pragma: no cover
+            Logger.critical("Not an attribute filter. Cannot perform an 'OR' operation.") # pragma: no cover
         return OrAttributeFilter(self, other)
 
     def __and__(self, other: Any):
@@ -101,7 +101,7 @@ class AttributeFilter(Filter):
             AndAttributeFilter that filters with AND between the current AttributeFilter and the passed AttributeFilter.
         """
         if not isinstance(other, AttributeFilter):
-            Logger.error("Not an attribute filter. Can not run an AND operation.") # pragma: no cover
+            Logger.critical("Not an attribute filter. Can not perform an 'AND' operation.") # pragma: no cover
         return AndAttributeFilter(self, other)
 
     def match(self,
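For context, these overloaded operators are what make attribute filters composable when binding operator sets to framework layers. A sketch, assuming a concrete equality filter named Eq exists among the AttributeFilter subclasses (none are shown in this diff):

act_is_relu = Eq('activation', 'relu')   # Eq: assumed equality filter on a layer attribute
wide = Eq('units', 512)

either = act_is_relu | wide   # __or__  -> OrAttributeFilter
both = act_is_relu & wide     # __and__ -> AndAttributeFilter
# A non-AttributeFilter operand now aborts via Logger.critical before the return.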
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/current_tpc.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from model_compression_toolkit.logger import Logger
 
 
 def get_current_tpc():
@@ -38,7 +39,7 @@ class _CurrentTPC(object):
 
         """
         if self.tpc is None:
-            raise Exception('TargetPlatformCapabilities is not initialized.')
+            Logger.critical("'TargetPlatformCapabilities' (TPC) instance is not initialized.")
         return self.tpc
 
     def reset(self):
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py
@@ -20,6 +20,7 @@ from model_compression_toolkit.target_platform_capabilities.target_platform.targ
 from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
 from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorSetConcat, \
     OperatorsSetBase
+from model_compression_toolkit import DefaultDict
 
 
 class OperationsSetToLayers(TargetPlatformCapabilitiesComponent):
@@ -29,15 +30,14 @@ class OperationsSetToLayers(TargetPlatformCapabilitiesComponent):
     def __init__(self,
                  op_set_name: str,
                  layers: List[Any],
-                 attr_mapping: Dict[str, Any] = None):
+                 attr_mapping: Dict[str, DefaultDict] = None):
         """
 
         Args:
             op_set_name (str): Name of OperatorsSet to associate with layers.
             layers (List[Any]): List of layers/FilterLayerParams to associate with OperatorsSet.
-            attr_mapping (dict): A mapping between a general attribute name to a DefaultDict that maps a layer type
-                to the layer's framework name of this attribute (the dictionary type is not specified to handle circular
-                dependency).
+            attr_mapping (Dict[str, DefaultDict]): A mapping between a general attribute name to a DefaultDict that maps a layer type to the layer's framework name of this attribute.
+
         """
         self.layers = layers
         self.attr_mapping = attr_mapping
@@ -147,7 +147,7 @@ class OperationsToLayers:
         for layer in ops2layers.layers:
             qco_by_opset_name = _current_tpc.get().tp_model.get_config_options_by_operators_set(ops2layers.name)
             if layer in existing_layers:
-                Logger.error(f'Found layer {layer.__name__} in more than one '
+                Logger.critical(f'Found layer {layer.__name__} in more than one '
                              f'OperatorsSet') # pragma: no cover
             else:
                 existing_layers.update({layer: qco_by_opset_name})
model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py
@@ -139,8 +139,7 @@ class TargetPlatformCapabilities(ImmutableClass):
         if isinstance(tpc_component, OperationsSetToLayers):
             self.op_sets_to_layers += tpc_component
         else:
-            Logger.error(f'Trying to append an unfamiliar TargetPlatformCapabilitiesComponent of type: '
-                         f'{type(tpc_component)}') # pragma: no cover
+            Logger.critical(f"Attempt to append an unrecognized 'TargetPlatformCapabilitiesComponent' of type: '{type(tpc_component)}'. Ensure the component is compatible.") # pragma: no cover
 
     def __enter__(self):
         """
model_compression_toolkit/trainable_infrastructure/common/base_trainable_quantizer.py
@@ -55,9 +55,9 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
         for i, (k, v) in enumerate(self.get_sig().parameters.items()):
             if i == 0:
                 if v.annotation not in [TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]:
-                    Logger.error(f"First parameter must be either TrainableQuantizerWeightsConfig or TrainableQuantizerActivationConfig") # pragma: no cover
+                    Logger.critical(f"The first parameter must be either 'TrainableQuantizerWeightsConfig' or 'TrainableQuantizerActivationConfig'.") # pragma: no cover
             elif v.default is v.empty:
-                Logger.error(f"Parameter {k} doesn't have a default value") # pragma: no cover
+                Logger.critical(f"Parameter '{k}' lacks a default value.") # pragma: no cover
 
         super(BaseTrainableQuantizer, self).__init__()
         self.quantization_config = quantization_config
@@ -67,22 +67,22 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
         static_quantization_target = getattr(self, QUANTIZATION_TARGET, None)
 
         if static_quantization_method is None or static_quantization_target is None:
-            Logger.error("A quantizer class that inherit from BaseTrainableQuantizer is not defined appropriately."
-                         "Either it misses the @mark_quantizer decorator or the decorator is not used correctly.")
+            Logger.critical("Quantizer class inheriting from 'BaseTrainableQuantizer' is improperly defined. "
+                            "Ensure it includes the '@mark_quantizer' decorator and is correctly applied.")
 
         if static_quantization_target == QuantizationTarget.Weights:
             self.validate_weights()
             if self.quantization_config.weights_quantization_method not in static_quantization_method:
-                Logger.error(
-                    f'Quantization method mismatch expected: {static_quantization_method} and got {self.quantization_config.weights_quantization_method}')
+                Logger.critical(
+                    f"Quantization method mismatch. Expected methods: {static_quantization_method}, received: {self.quantization_config.weights_quantization_method}.")
         elif static_quantization_target == QuantizationTarget.Activation:
             self.validate_activation()
             if self.quantization_config.activation_quantization_method not in static_quantization_method:
-                Logger.error(
-                    f'Quantization method mismatch expected: {static_quantization_method} and got {self.quantization_config.activation_quantization_method}')
+                Logger.critical(
+                    f"Quantization method mismatch. Expected methods: {static_quantization_method}, received: {self.quantization_config.activation_quantization_method}.")
         else:
-            Logger.error(
-                f'Unknown Quantization Part:{static_quantization_target}') # pragma: no cover
+            Logger.critical(
+                f"Unrecognized 'QuantizationTarget': {static_quantization_target}.") # pragma: no cover
 
         self.quantizer_parameters = {}
 
@@ -145,7 +145,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
 
         """
         if self.activation_quantization() or not self.weights_quantization():
-            Logger.error(f'Expect weight quantization got activation')
+            Logger.critical(f'Expected weight quantization configuration; received activation quantization instead.')
 
     def validate_activation(self) -> None:
         """
@@ -153,7 +153,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
 
         """
         if not self.activation_quantization() or self.weights_quantization():
-            Logger.error(f'Expect activation quantization got weight')
+            Logger.critical(f'Expected activation quantization configuration; received weight quantization instead.')
 
     def convert2inferable(self) -> BaseInferableQuantizer:
         """
@@ -183,7 +183,7 @@ class BaseTrainableQuantizer(BaseInferableQuantizer):
         if name in self.quantizer_parameters:
             return self.quantizer_parameters[name][VAR]
         else:
-            Logger.error(f'Variable {name} is not exist in quantizers parameters!') # pragma: no cover
+            Logger.critical(f"Variable '{name}' does not exist in quantizer parameters.") # pragma: no cover
 
 
     @abstractmethod
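These validations spell out the contract for trainable quantizers: the class must carry @mark_quantizer metadata, its first __init__ parameter must be annotated with one of the two trainable config types, and every later parameter needs a default. A sketch of a conforming subclass, assuming mark_quantizer, QuantizationTarget, and QuantizationMethod come from mct_quantizers and the import paths below; the decorator's exact argument names are not shown in this diff, and the abstract quantizer methods are omitted.

from mct_quantizers import mark_quantizer, QuantizationTarget, QuantizationMethod
from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig
from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import \
    BaseTrainableQuantizer

@mark_quantizer(quantization_target=QuantizationTarget.Weights,      # sets QUANTIZATION_TARGET
                quantization_method=[QuantizationMethod.SYMMETRIC],  # allowed methods
                identifier=None)                                     # hypothetical identifier value
class MyWeightsQuantizer(BaseTrainableQuantizer):
    def __init__(self,
                 quantization_config: TrainableQuantizerWeightsConfig,  # first param: config type annotation
                 per_channel: bool = True):                             # later params must have defaults
        super().__init__(quantization_config)  # runs the checks shown above
    # (abstract methods such as the quantization call itself are omitted in this sketch)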
model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py
@@ -36,7 +36,9 @@ def get_trainable_quantizer_weights_config(
         TrainableQuantizerWeightsConfig: an object that contains the quantizer configuration
     """
     if n.final_weights_quantization_cfg is None:
-        Logger.error(f'Node must have final_weights_quantization_cfg in order to build quantizer configuration') # pragma: no cover
+        Logger.critical(
+            "The node requires a 'final_weights_quantization_cfg' configuration to build a "
+            "quantizer. Please ensure this configuration is set for the node.") # pragma: no cover
 
     final_node_cfg = n.final_weights_quantization_cfg
     final_attr_cfg = final_node_cfg.get_attr_config(attr_name)
@@ -65,7 +67,9 @@ def get_trainable_quantizer_activation_config(
         TrainableQuantizerActivationConfig - an object that contains the quantizer configuration
     """
     if n.final_activation_quantization_cfg is None:
-        Logger.error(f'Node must have final_activation_quantization_cfg in order to build quantizer configuration') # pragma: no cover
+        Logger.critical(
+            "The node requires a 'final_activation_quantization_cfg' configuration to build a "
+            "quantizer. Please ensure this configuration is set for the node.") # pragma: no cover
 
     final_cfg = n.final_activation_quantization_cfg
     return TrainableQuantizerActivationConfig(final_cfg.activation_quantization_method,
@@ -93,17 +97,20 @@ def get_trainable_quantizer_quantization_candidates(n: BaseNode, attr: str = Non
     if attr is not None:
         # all candidates must have the same weights quantization method
         weights_quantization_methods = set([cfg.weights_quantization_cfg.get_attr_config(attr).weights_quantization_method
-                                            for cfg in n.candidates_quantization_cfg])
+                                           for cfg in n.candidates_quantization_cfg])
         if len(weights_quantization_methods) > 1:
-            Logger.error(f'Unsupported candidates_quantization_cfg with different weights quantization methods: '
-                         f'{weights_quantization_methods}') # pragma: no cover
+            Logger.critical(f"Invalid 'candidates_quantization_cfg': Inconsistent weights "
+                            f"quantization methods detected: {weights_quantization_methods}. "
+                            f"Trainable quantizer requires all candidates to have the same weights "
+                            f"quantization method.") # pragma: no cover
 
         # all candidates must have the same activation quantization method
         activation_quantization_methods = set([cfg.activation_quantization_cfg.activation_quantization_method
                                                for cfg in n.candidates_quantization_cfg])
         if len(activation_quantization_methods) > 1:
-            Logger.error(f'Unsupported candidates_quantization_cfg with different activation quantization methods: '
-                         f'{activation_quantization_methods}') # pragma: no cover
+            Logger.critical(f"Invalid 'candidates_quantization_cfg': Inconsistent activation quantization "
+                            f"methods detected: {activation_quantization_methods}. "
+                            f"Trainable quantizer requires all candidates to have the same activation quantization method.") # pragma: no cover
 
     # get unique lists of candidates
     unique_weights_candidates = n.get_unique_weights_candidates(attr)
model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py
@@ -44,7 +44,7 @@ def get_trainable_quantizer_class(quant_target: QuantizationTarget,
     """
     qat_quantizer_classes = get_all_subclasses(quantizer_base_class)
     if len(qat_quantizer_classes) == 0:
-        Logger.error(f"No quantizers were found that inherit from {quantizer_base_class}.") # pragma: no cover
+        Logger.critical(f"No quantizer classes inherited from {quantizer_base_class} were detected.") # pragma: no cover
 
     filtered_quantizers = list(filter(lambda q_class: getattr(q_class, QUANTIZATION_TARGET, None) is not None and
                                       getattr(q_class, QUANTIZATION_TARGET) == quant_target and
@@ -54,9 +54,9 @@ def get_trainable_quantizer_class(quant_target: QuantizationTarget,
                                       qat_quantizer_classes))
 
     if len(filtered_quantizers) != 1:
-        Logger.error(f"Found {len(filtered_quantizers)} quantizer for target {quant_target.value} " # pragma: no cover
-                     f"that matches the requested quantization method {quant_method.name} and "
-                     f"quantizer type {quantizer_id.value} but there should be exactly one."
-                     f"The possible quantizers that were found are {filtered_quantizers}.")
+        Logger.critical(f"Found {len(filtered_quantizers)} quantizers for target {quant_target.value}, "
+                        f"matching the requested quantization method {quant_method.name} and "
+                        f"quantizer type {quantizer_id.value}, but exactly one is required. "
+                        f"Identified quantizers: {filtered_quantizers}.")
 
     return filtered_quantizers[0]
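This resolver scans all subclasses of a framework base class and requires exactly one match for the (target, method, identifier) triple. A hypothetical call, with keyword names taken from the function body shown here; the value passed as quantizer_id depends on the training scheme and is a placeholder below:

quantizer_cls = get_trainable_quantizer_class(
    quant_target=QuantizationTarget.Weights,
    quantizer_id=my_training_method,                   # hypothetical identifier value
    quant_method=QuantizationMethod.SYMMETRIC,
    quantizer_base_class=BaseKerasTrainableQuantizer)
# Zero or multiple matches now abort via Logger.critical with the diagnostics above.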
model_compression_toolkit/trainable_infrastructure/keras/base_keras_quantizer.py
@@ -86,6 +86,5 @@ else:
                      quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
 
             super().__init__(quantization_config)
-            Logger.critical('Installing tensorflow is mandatory '
-                            'when using BaseKerasQuantizer. '
-                            'Could not find Tensorflow package.') # pragma: no cover
+            Logger.critical("Tensorflow must be installed to use BaseKerasTrainableQuantizer. "
+                            "The 'tensorflow' package is missing.") # pragma: no cover
model_compression_toolkit/trainable_infrastructure/keras/load_model.py
@@ -48,8 +48,8 @@ if FOUND_TF:
                                             KerasTrainableQuantizationWrapper.__name__: KerasTrainableQuantizationWrapper})
         all_trainable_names = list(qi_trainable_custom_objects.keys())
         if len(set(all_trainable_names)) < len(all_trainable_names):
-            Logger.error(f"Found multiple quantizers with the same name that inherit from BaseKerasTrainableQuantizer"
-                         f"while trying to load a model.")
+            Logger.critical("Found multiple quantizers with identical names inheriting from "
+                            "'BaseKerasTrainableQuantizer' while trying to load a model.")
 
         qi_custom_objects = {**qi_trainable_custom_objects}
 
@@ -72,6 +72,5 @@ else:
         Returns: A keras Model
 
         """
-        Logger.critical('Installing tensorflow is mandatory '
-                        'when using keras_load_quantized_model. '
-                        'Could not find Tensorflow package.') # pragma: no cover
+        Logger.critical("Tensorflow must be installed to use keras_load_quantized_model. "
+                        "The 'tensorflow' package is missing.") # pragma: no cover
model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py
@@ -91,7 +91,7 @@ if FOUND_TF:
                 layer_weights_list[weight_keys.index(_weight_name(w.name))] = w
             # Verify all the weights in the list are ready. The "set_weights" method expects all the layer's weights
             if not all(w is not None for w in layer_weights_list):
-                Logger.error(f'Not all weights are set for layer {self.layer.name}')
+                Logger.critical(f"Not all weights are set for layer '{self.layer.name}'")
             assert all(w is not None for w in layer_weights_list)
             inferable_quantizers_wrapper.set_weights(layer_weights_list)
 
@@ -110,6 +110,5 @@ else:
             layer: A keras layer.
             weights_quantizers: A dictionary between a weight's name to its quantizer.
         """
-        Logger.critical('Installing tensorflow is mandatory '
-                        'when using KerasTrainableQuantizationWrapper. '
-                        'Could not find Tensorflow package.') # pragma: no cover
+        Logger.critical("Tensorflow must be installed to use KerasTrainableQuantizationWrapper. "
+                        "The 'tensorflow' package is missing.") # pragma: no cover
model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py
@@ -60,6 +60,6 @@ else:
         def __init__(self,
                      quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
             super().__init__(quantization_config)
-            Logger.critical('Installing Pytorch is mandatory '
-                            'when using BasePytorchTrainableQuantizer. '
-                            'Could not find torch package.') # pragma: no cover
+            Logger.critical("PyTorch must be installed to use 'BasePytorchTrainableQuantizer'. "
+                            "The 'torch' package is missing.") # pragma: no cover
+
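All four framework-conditional files above share the same guard pattern: when the framework package is present, the real class is defined; otherwise a stub with the same name is defined whose constructor fails through Logger.critical, so the import itself stays safe and only use of the class is fatal. A simplified sketch (the real FOUND_TF constant lives in model_compression_toolkit/constants.py, also touched in this diff; the find_spec check here is a stand-in):

import importlib.util

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import \
    BaseTrainableQuantizer

FOUND_TF = importlib.util.find_spec('tensorflow') is not None  # simplified stand-in

if FOUND_TF:
    class BaseKerasTrainableQuantizer(BaseTrainableQuantizer):
        ...  # real implementation lives here
else:
    class BaseKerasTrainableQuantizer(BaseTrainableQuantizer):
        # Same name, but constructing it reports the missing framework.
        def __init__(self, quantization_config):
            super().__init__(quantization_config)
            Logger.critical("Tensorflow must be installed to use BaseKerasTrainableQuantizer. "
                            "The 'tensorflow' package is missing.")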
model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi.py (deleted)
@@ -1,112 +0,0 @@
-# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from enum import Enum
-from typing import Dict, Any
-
-import numpy as np
-
-
-class KPITarget(Enum):
-    """
-    Targets for which we define KPIs metrics for mixed-precision search.
-    For each target that we care to consider in a mixed-precision search, there should be defined a set of
-    kpi computation function, kpi aggregation function, and kpi target (within a KPI object).
-
-    Whenever adding a kpi metric to KPI class we should add a matching target to this enum.
-
-    WEIGHTS - Weights memory KPI metric.
-
-    ACTIVATION - Activation memory KPI metric.
-
-    TOTAL - Total memory KPI metric.
-
-    BOPS - Total Bit-Operations KPI Metric.
-
-    """
-
-    WEIGHTS = 'weights'
-    ACTIVATION = 'activation'
-    TOTAL = 'total'
-    BOPS = 'bops'
-
-
-class KPI:
-    """
-    Class to represent measurements of performance.
-    """
-
-    def __init__(self,
-                 weights_memory: float = np.inf,
-                 activation_memory: float = np.inf,
-                 total_memory: float = np.inf,
-                 bops: float = np.inf):
-        """
-
-        Args:
-            weights_memory: Memory of a model's weights in bytes. Note that this includes only coefficients that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value, while the bias will not).
-            activation_memory: Memory of a model's activation in bytes, according to the given activation kpi metric.
-            total_memory: The sum of model's activation and weights memory in bytes, according to the given total kpi metric.
-            bops: The total bit-operations in the model.
-        """
-        self.weights_memory = weights_memory
-        self.activation_memory = activation_memory
-        self.total_memory = total_memory
-        self.bops = bops
-
-    def __repr__(self):
-        return f"Weights_memory: {self.weights_memory}, " \
-               f"Activation_memory: {self.activation_memory}, " \
-               f"Total_memory: {self.total_memory}, " \
-               f"BOPS: {self.bops}"
-
-    def get_kpi_dict(self) -> Dict[KPITarget, float]:
-        """
-        Returns: a dictionary with the KPI object's values for each KPI target.
-        """
-        return {KPITarget.WEIGHTS: self.weights_memory,
-                KPITarget.ACTIVATION: self.activation_memory,
-                KPITarget.TOTAL: self.total_memory,
-                KPITarget.BOPS: self.bops}
-
-    def set_kpi_by_target(self, kpis_mapping: Dict[KPITarget, float]):
-        """
-        Setting a KPI object values for each KPI target in the given dictionary.
-
-        Args:
-            kpis_mapping: A mapping from a KPITarget to a matching KPI value.
-
-        """
-        self.weights_memory = kpis_mapping.get(KPITarget.WEIGHTS, np.inf)
-        self.activation_memory = kpis_mapping.get(KPITarget.ACTIVATION, np.inf)
-        self.total_memory = kpis_mapping.get(KPITarget.TOTAL, np.inf)
-        self.bops = kpis_mapping.get(KPITarget.BOPS, np.inf)
-
-    def holds_constraints(self, kpi: Any) -> bool:
-        """
-        Checks whether the given KPI holds a set of KPI constraints defined by the currect KPI object.
-
-        Args:
-            kpi: A KPI object to check if it holds the constraints.
-
-        Returns: True if all the given KPI values are not greater than the referenced KPI values.
-
-        """
-        if not isinstance(kpi, KPI):
-            return False
-
-        return kpi.weights_memory <= self.weights_memory and \
-               kpi.activation_memory <= self.activation_memory and \
-               kpi.total_memory <= self.total_memory and \
-               kpi.bops <= self.bops
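For reference, the deleted KPI class in action against the previous nightly: holds_constraints() is a componentwise <= check, with np.inf defaults meaning "unconstrained". Its successor in resource_utilization.py presumably mirrors this surface, but the new file's body (+114 lines) is not shown in this diff.

# Runs against 1.11.0.20240321.357; this import path is removed in this release.
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI, KPITarget

budget = KPI(weights_memory=2 * 1024 ** 2)         # 2 MiB weight budget; other targets unconstrained (inf)
measured = KPI(weights_memory=1.5 * 1024 ** 2,
               activation_memory=4 * 1024 ** 2)

print(budget.holds_constraints(measured))          # True: 1.5 MiB <= 2 MiB, and inf bounds always hold
print(measured.get_kpi_dict()[KPITarget.WEIGHTS])  # 1572864.0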