mct-nightly 2.4.0.20250616.616__py3-none-any.whl → 2.4.0.20250618.606__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/RECORD +120 -120
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +2 -5
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
  6. model_compression_toolkit/core/common/framework_implementation.py +10 -22
  7. model_compression_toolkit/core/common/framework_info.py +105 -68
  8. model_compression_toolkit/core/common/graph/base_graph.py +15 -42
  9. model_compression_toolkit/core/common/graph/base_node.py +103 -42
  10. model_compression_toolkit/core/common/graph/functional_node.py +18 -1
  11. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
  12. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
  16. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
  17. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
  18. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
  19. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
  20. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
  21. model_compression_toolkit/core/common/model_collector.py +10 -20
  22. model_compression_toolkit/core/common/model_validation.py +1 -4
  23. model_compression_toolkit/core/common/network_editors/actions.py +14 -38
  24. model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
  25. model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
  26. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
  27. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
  28. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
  29. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
  30. model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
  31. model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
  32. model_compression_toolkit/core/common/pruning/pruner.py +1 -6
  33. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
  34. model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
  35. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
  36. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
  37. model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
  38. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
  39. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
  40. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
  41. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
  42. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
  43. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
  44. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
  45. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
  46. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
  47. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
  48. model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
  49. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
  50. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
  51. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  52. model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
  53. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
  54. model_compression_toolkit/core/graph_prep_runner.py +2 -16
  55. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
  56. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
  57. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
  58. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
  59. model_compression_toolkit/core/keras/default_framework_info.py +138 -87
  60. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
  61. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
  62. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
  63. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
  64. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  65. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
  66. model_compression_toolkit/core/keras/keras_implementation.py +15 -35
  67. model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
  68. model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
  69. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
  70. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +0 -2
  71. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
  72. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
  73. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
  74. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
  75. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
  76. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
  77. model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
  78. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
  79. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
  80. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  81. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
  82. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
  83. model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
  84. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
  85. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  86. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
  87. model_compression_toolkit/core/quantization_prep_runner.py +4 -9
  88. model_compression_toolkit/core/runner.py +5 -15
  89. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  90. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  91. model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
  92. model_compression_toolkit/gptq/common/gptq_training.py +1 -8
  93. model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
  94. model_compression_toolkit/gptq/keras/graph_info.py +4 -6
  95. model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
  96. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  97. model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
  98. model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
  99. model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
  100. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  101. model_compression_toolkit/gptq/runner.py +1 -7
  102. model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
  103. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
  104. model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
  105. model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
  106. model_compression_toolkit/ptq/runner.py +1 -4
  107. model_compression_toolkit/qat/common/qat_config.py +2 -6
  108. model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
  109. model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
  110. model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
  111. model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
  112. model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
  113. model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
  114. model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
  115. model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
  116. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
  117. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
  118. {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/WHEEL +0 -0
  119. {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/licenses/LICENSE.md +0 -0
  120. {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/top_level.txt +0 -0

model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py

@@ -23,7 +23,6 @@ from model_compression_toolkit.core import FrameworkInfo, common
 from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.core.pytorch.mixed_precision.configurable_activation_quantizer import \
     ConfigurableActivationQuantizer
 from model_compression_toolkit.core.pytorch.mixed_precision.configurable_weights_quantizer import \
@@ -38,14 +37,12 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
     def __init__(self,
                  graph: common.Graph,
                  append2output=None,
-                 fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
                  return_float_outputs: bool = False):
         """

         Args:
             graph: Graph to build the model from.
             append2output: Nodes to append to model's output.
-            fw_info: Information about the specific framework of the model that is built.
             return_float_outputs: Whether the model returns float tensors or not.
         """

@@ -53,7 +50,6 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):

         super().__init__(graph,
                          append2output,
-                         fw_info,
                          return_float_outputs,
                          wrapper=self.mixed_precision_wrapper,
                          get_activation_quantizer_holder_fn=self.mixed_precision_activation_holder)
@@ -77,17 +73,16 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
             ValueError: if kernel attribute is quantized but not configurable.
         """

-        kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
-        if kernel_attr is None or not n.is_weights_quantization_enabled(kernel_attr):
+        if n.kernel_attr is None or not n.is_weights_quantization_enabled(n.kernel_attr):
             return layer
-        if not n.is_configurable_weight(kernel_attr):  # pragma: no cover
+        if not n.is_configurable_weight(n.kernel_attr):  # pragma: no cover
             raise ValueError(f'Weight wrapper is not expected to be created for non-configurable weight of node {n}.')
         return PytorchQuantizationWrapper(layer,
                                           weights_quantizers={
-                                              kernel_attr: ConfigurableWeightsQuantizer(
+                                              n.kernel_attr: ConfigurableWeightsQuantizer(
                                                   **self._get_weights_configurable_quantizer_kwargs(n,
-                                                                                                    kernel_attr),
-                                                  kernel_attr=kernel_attr)})
+                                                                                                    n.kernel_attr),
+                                                  kernel_attr=n.kernel_attr)})

     def _get_weights_configurable_quantizer_kwargs(self, n: BaseNode, attr: str) -> Dict[str, Any]:
         """
@@ -147,14 +142,13 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
         # activation number of bits (in reversed order).
         # since only kernel attribute is quantized in weights mixed precision,
         # if the node doesn't have a kernel attribute, we only sort by activation_n_bits.
-        n.sort_node_candidates(self.fw_info)
+        n.sort_node_candidates()

         max_candidate_idx = n.find_max_candidate_index()

-        kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
         activation_quantizers = [ConfigurableActivationQuantizer(**{'node_q_cfg': node_q_cfg_candidates,
                                                                      'max_candidate_idx': max_candidate_idx,
-                                                                     'kernel_attr': kernel_attr})] \
+                                                                     'kernel_attr': n.kernel_attr})] \
                                 * num_of_outputs

         # Holder by definition uses a single quantizer for the activation quantization
@@ -177,7 +171,7 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):
         # creating a mapping between graph nodes and model's layers for mixed precision configurability
         model_layers = dict(model.named_children())
         conf_node2layers = {n.name: self._find_layers_in_model_by_node(n, model_layers)
-                            for n in self.graph.get_configurable_sorted_nodes(self.fw_info)}
+                            for n in self.graph.get_configurable_sorted_nodes()}

         return model, user_info, conf_node2layers

@@ -230,8 +224,7 @@ class MixedPrecisionPyTorchModelBuilder(PyTorchModelBuilder):

         """
         # Only layers with kernel op are considered weights configurable
-        kernel_attr = self.fw_info.get_kernel_op_attributes(n.type)[0]
-        weights_quant = False if kernel_attr is None else n.is_weights_quantization_enabled(kernel_attr)
+        weights_quant = False if n.kernel_attr is None else n.is_weights_quantization_enabled(n.kernel_attr)
         act_quant = n.is_activation_quantization_enabled()

         if weights_quant and not act_quant:
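
Across these hunks the builder stops querying `self.fw_info.get_kernel_op_attributes(n.type)[0]` and instead reads the kernel attribute directly from the node via `n.kernel_attr`. Below is a minimal, self-contained sketch of that access pattern; the `Node` stub and `should_wrap` helper are illustrative stand-ins, not MCT API.

```python
from dataclasses import dataclass, field
from typing import Optional, Set


@dataclass
class Node:
    """Illustrative stand-in for a graph node that now carries its own kernel attribute."""
    name: str
    kernel_attr: Optional[str] = None          # e.g. 'weight' for Conv2d/Linear, None for kernel-less ops
    quantized_weights: Set[str] = field(default_factory=set)

    def is_weights_quantization_enabled(self, attr: str) -> bool:
        return attr in self.quantized_weights


def should_wrap(n: Node) -> bool:
    """Mirrors the new check: no fw_info lookup, the node itself knows its kernel attribute."""
    return n.kernel_attr is not None and n.is_weights_quantization_enabled(n.kernel_attr)


conv = Node('conv1', kernel_attr='weight', quantized_weights={'weight'})
relu = Node('relu1')  # no kernel attribute
assert should_wrap(conv) and not should_wrap(relu)
```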

model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py

@@ -30,7 +30,6 @@ from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.core.pytorch.back2framework.instance_builder import node_builder
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
 from mct_quantizers.common.constants import ACTIVATION_HOLDER_QUANTIZER
@@ -364,7 +363,7 @@ class PytorchModel(torch.nn.Module):
         """
         node_to_output_tensors_dict = dict()
         node_to_output_tensors_dict_float = dict()
-        configurable_nodes = self.graph.get_configurable_sorted_nodes_names(DEFAULT_PYTORCH_INFO)
+        configurable_nodes = self.graph.get_configurable_sorted_nodes_names()
         for node in self.node_sort:
             op_func = self._get_op_func(node, configurable_nodes)
             input_tensors = _build_input_tensors_list(node,
@@ -440,7 +439,6 @@ class PyTorchModelBuilder(BaseModelBuilder):
     def __init__(self,
                  graph: common.Graph,
                  append2output=None,
-                 fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
                  return_float_outputs: bool = False,
                  wrapper: Callable = None,
                  get_activation_quantizer_holder_fn: Callable = None):
@@ -449,7 +447,6 @@ class PyTorchModelBuilder(BaseModelBuilder):
         Args:
             graph: Graph to build the model from.
             append2output: Nodes to append to model's output.
-            fw_info: Information about the specific framework of the model that is built.
             return_float_outputs: Whether the model returns float tensors or not.
             wrapper: A function wrapper Pytorch Layers.
             get_activation_quantizer_holder_fn: Function to retrieve a quantization holder for a node.
@@ -457,7 +454,6 @@ class PyTorchModelBuilder(BaseModelBuilder):

         super().__init__(graph,
                          append2output,
-                         fw_info,
                          return_float_outputs)

         self.wrapper = wrapper

model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py

@@ -21,7 +21,6 @@ from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core.pytorch.back2framework.quantization_wrapper.wrapper_quantize_config import \
     WrapperQuantizeConfig
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.core.pytorch.utils import set_model, to_torch_tensor


@@ -93,7 +92,7 @@ class QuantizedLayerWrapper(torch.nn.Module):
         self.layer = n.type(**framework_attr)
         self.layer.load_state_dict({k: torch.Tensor(v) for k, v in n.weights.items()}, strict=False)

-    def _quantize_weights(self, n:BaseNode):
+    def _quantize_weights(self, n: BaseNode):
         """
         Quantize node's weights and load them as the layer's weights.

@@ -104,7 +103,7 @@ class QuantizedLayerWrapper(torch.nn.Module):
             None.
         """

-        self.weight_attrs = DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(n.type)
+        self.weight_attrs = [n.kernel_attr]

         # float_weights is a list of weights for each attribute that we want to quantize.
         float_weights = [n.get_weights_by_keys(attr) for attr in self.weight_attrs]

model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py

@@ -23,7 +23,6 @@ from model_compression_toolkit.core.common import BaseNode
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder, \
     PytorchModel
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO


 class QuantizedPyTorchModel(PytorchModel):
@@ -70,20 +69,17 @@ class QuantizedPyTorchModelBuilder(PyTorchModelBuilder):
     def __init__(self,
                  graph: common.Graph,
                  append2output=None,
-                 fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
                  return_float_outputs: bool = False):
         """

         Args:
             graph: Graph to build the model from.
             append2output: Nodes to append to model's output.
-            fw_info: Information about the specific framework of the model that is built.
             return_float_outputs: Whether the model returns float tensors or not.
         """

         super().__init__(graph,
                          append2output,
-                         fw_info,
                          return_float_outputs)

     def build_model(self) -> Tuple[PytorchModel, UserInformation]:

model_compression_toolkit/core/pytorch/default_framework_info.py

@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from torch.nn import Hardsigmoid, ReLU, ReLU6, Softmax, Sigmoid, GELU, SELU
-from torch.nn.functional import hardsigmoid, relu, relu6, softmax, gelu, selu
+from typing import Any
+from functools import wraps
+
+from torch.nn import Hardsigmoid, ReLU, ReLU6, Softmax, Sigmoid, GELU, SELU, SiLU
+from torch.nn.functional import hardsigmoid, relu, relu6, softmax, gelu, selu, silu
 from torch.nn import Conv2d, ConvTranspose2d, Linear
 from torch import sigmoid

-from model_compression_toolkit.defaultdict import DefaultDict
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo, DEFAULT_KERNEL_ATTRIBUTES
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo, set_fw_info, ChannelAxisMapping
 from mct_quantizers import QuantizationMethod
 from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
 from model_compression_toolkit.core.pytorch.constants import KERNEL
@@ -26,73 +28,97 @@ from model_compression_toolkit.core.pytorch.quantizer.fake_quant_builder import
     symmetric_quantization, uniform_quantization
 from model_compression_toolkit.core.pytorch.quantizer.lut_fake_quant import activation_lut_kmean_quantizer

-"""
-Map each layer to a list of its' weights attributes that should get quantized.
-If a layer that is not listed here is queried, [None] is returned.
-"""
-KERNEL_ATTRIBUTES = DefaultDict({Conv2d: [KERNEL],
-                                 ConvTranspose2d: [KERNEL],
-                                 Linear: [KERNEL]},
-                                DEFAULT_KERNEL_ATTRIBUTES)
-
-"""
-Map a layer to its kernel's output and input channels indices.
-Map's values are tuples of (output_channel_index, input_channel_index).
-Default value is returned for layers that are not included.
-"""
-DEFAULT_CHANNEL_AXIS_DICT = DefaultDict({Conv2d: (0, 1),
-                                         Linear: (0, 1),
-                                         ConvTranspose2d: (1, 0)},
-                                        (None, None))
-
-"""
-Map a layer to its output channel axis.
-Where axis=-1 is the last axis
-"""
-DEFAULT_OUT_CHANNEL_AXIS_DICT = DefaultDict({Conv2d: 1,
-                                             Linear: -1,
-                                             ConvTranspose2d: 1},
-                                            1)
-
-
-"""
-Map from an activation function to its min/max output values (if known).
-The values are used for tensor min/max values initialization.
-"""
-ACTIVATION2MINMAX = {}  # should be an empty dict in Pytorch
-
-"""
-Map from an Pytorch module to its min/max output values (if known).
-The values are used for tensor min/max values initialization.
-"""
-LAYER2MINMAX = {Softmax: (0, SOFTMAX_THRESHOLD),
-                softmax: (0, SOFTMAX_THRESHOLD),
-                Sigmoid: (0, 1),
-                sigmoid: (0, 1),
-                Hardsigmoid: (0, 1),
-                hardsigmoid: (0, 1),
-                ReLU: (0, None),
-                relu: (0, None),
-                ReLU6: (0, None),
-                relu6: (0, None),
-                GELU: (-0.17, None),
-                gelu: (-0.17, None),
-                SELU: (-1.76, None),
-                selu: (-1.76, None),
-                }
-
-"""
-Mapping from a QuantizationMethod to an activation quantizer function.
-"""
-ACTIVATION_QUANTIZER_MAPPING = {QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
-                                QuantizationMethod.SYMMETRIC: symmetric_quantization,
-                                QuantizationMethod.UNIFORM: uniform_quantization,
-                                QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer}
-
-
-DEFAULT_PYTORCH_INFO = FrameworkInfo(ACTIVATION_QUANTIZER_MAPPING,
-                                     DEFAULT_CHANNEL_AXIS_DICT,
-                                     ACTIVATION2MINMAX,
-                                     LAYER2MINMAX,
-                                     KERNEL_ATTRIBUTES,
-                                     DEFAULT_OUT_CHANNEL_AXIS_DICT)
+
+class PyTorchInfo(FrameworkInfo):
+    """
+    Extra field defined to handle Activation layer functions:
+    """
+
+    """
+    Map each layer to it's weight attribute that should get quantized.
+    If a layer that is not listed here is queried, None is returned.
+    """
+    kernel_ops_attribute_mapping = {Conv2d: KERNEL,
+                                    ConvTranspose2d: KERNEL,
+                                    Linear: KERNEL}
+
+    """
+    Map a layer to its kernel's output and input channels indices.
+    Map's values are tuples of (output_channel_index, input_channel_index).
+    Default value is returned for layers that are not included.
+    """
+    kernel_channels_mapping = {Conv2d: ChannelAxisMapping(0, 1),
+                               Linear: ChannelAxisMapping(0, 1),
+                               ConvTranspose2d: ChannelAxisMapping(1, 0)}
+
+    """
+    Map a layer to its output channel axis.
+    Where axis=-1 is the last axis
+    """
+    out_channel_axis_mapping = {Conv2d: 1,
+                                Linear: -1,
+                                ConvTranspose2d: 1}
+
+    """
+    Map from an Pytorch module to its min/max output values (if known).
+    The values are used for tensor min/max values initialization.
+    """
+    _layer_min_max_mapping = {Softmax: (0, SOFTMAX_THRESHOLD),
+                              softmax: (0, SOFTMAX_THRESHOLD),
+                              Sigmoid: (0, 1),
+                              sigmoid: (0, 1),
+                              Hardsigmoid: (0, 1),
+                              hardsigmoid: (0, 1),
+                              ReLU: (0, None),
+                              relu: (0, None),
+                              ReLU6: (0, None),
+                              relu6: (0, None),
+                              GELU: (-0.17, None),
+                              gelu: (-0.17, None),
+                              SELU: (-1.76, None),
+                              selu: (-1.76, None),
+                              silu: (-0.279, None),
+                              SiLU: (-0.279, None),
+                              }
+
+    """
+    Mapping from a QuantizationMethod to an activation quantizer function.
+    """
+    activation_quantizer_mapping = {QuantizationMethod.POWER_OF_TWO: power_of_two_quantization,
+                                    QuantizationMethod.SYMMETRIC: symmetric_quantization,
+                                    QuantizationMethod.UNIFORM: uniform_quantization,
+                                    QuantizationMethod.LUT_POT_QUANTIZER: activation_lut_kmean_quantizer}
+
+    @classmethod
+    def get_kernel_channels(cls, node_type: Any) -> ChannelAxisMapping:
+        """
+        Returns node's channels mapping from kernel_channels_mapping or framework specific default value.
+        Args:
+            node_type: A node type.
+
+        Returns:
+            Node's channels mapping.
+
+        """
+        return cls.kernel_channels_mapping.get(node_type, cls._default_channel_mapping)
+
+    @classmethod
+    def get_out_channel_axis(cls, node_type: Any):
+        """
+        Returns node's output channel mapping from out_channel_axis_mapping or framework specific default value.
+        Args:
+            node_type: A node type.
+
+        Returns:
+            Node's output channel axis.
+
+        """
+        return cls.out_channel_axis_mapping.get(node_type, 1)
+
+
+def set_pytorch_info(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        set_fw_info(PyTorchInfo)
+        return func(*args, **kwargs)
+    return wrapper
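
The module-level `DEFAULT_PYTORCH_INFO` object is replaced by a `FrameworkInfo` subclass that is registered through `set_fw_info` (via the `set_pytorch_info` decorator) and retrieved with `get_fw_info`. Below is a simplified, self-contained sketch of that registration pattern; the classes here are stubs written for illustration, while the real ones live in `model_compression_toolkit.core.common.framework_info` and in the hunk above.

```python
from functools import wraps
from typing import NamedTuple, Optional, Type


class ChannelAxisMapping(NamedTuple):
    output: Optional[int]
    input: Optional[int]


class FrameworkInfo:
    """Stub base: per-framework metadata exposed through class attributes and classmethods."""
    kernel_channels_mapping = {}
    _default_channel_mapping = ChannelAxisMapping(None, None)

    @classmethod
    def get_kernel_channels(cls, node_type) -> ChannelAxisMapping:
        return cls.kernel_channels_mapping.get(node_type, cls._default_channel_mapping)


_fw_info: Optional[Type[FrameworkInfo]] = None


def set_fw_info(fw_info: Type[FrameworkInfo]) -> None:
    """Register the active framework-info class (stand-in for the MCT helper of the same name)."""
    global _fw_info
    _fw_info = fw_info


def get_fw_info() -> Optional[Type[FrameworkInfo]]:
    return _fw_info


class TorchLikeInfo(FrameworkInfo):
    kernel_channels_mapping = {'Conv2d': ChannelAxisMapping(0, 1)}


def set_torch_info(func):
    """Decorator in the spirit of set_pytorch_info: register the info class before the wrapped call."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        set_fw_info(TorchLikeInfo)
        return func(*args, **kwargs)
    return wrapper


@set_torch_info
def facade():
    # Any code running under the facade can query the registered class instead of
    # receiving a fw_info argument.
    return get_fw_info().get_kernel_channels('Conv2d')


assert facade() == ChannelAxisMapping(0, 1)
```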

model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py

@@ -21,19 +21,18 @@ from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 from model_compression_toolkit.core.pytorch.constants import IN_CHANNELS, OUT_CHANNELS, KERNEL_SIZE, KERNEL, BIAS
-from model_compression_toolkit.core.common import FrameworkInfo
+from model_compression_toolkit.core.common.framework_info import get_fw_info


 class FunctionalConvSubstitution(common.BaseSubstitution):
     """
     Substitute functional convolutions with Layers
     """
-    def __init__(self, fw_info: FrameworkInfo):
+    def __init__(self):
         """
         Matches a functional conv node
         """
         func_node = NodeOperationMatcher(conv2d) | NodeOperationMatcher(conv_transpose2d)
-        self.fw_info = fw_info
         super().__init__(matcher_instance=func_node)

     def substitute(self,
@@ -56,7 +55,7 @@ class FunctionalConvSubstitution(common.BaseSubstitution):
         else:
             Logger.critical(f'Substitution filter mismatch. Layer {func_node.type}. Must be {type(Conv2d)} or {type(ConvTranspose2d)}.')  # pragma: no cover

-        out_channel_index, in_channel_index = self.fw_info.kernel_channels_mapping.get(new_layer)
+        out_channel_index, in_channel_index = get_fw_info().get_kernel_channels(new_layer)

         # Create new node of layer convolution
         if 1 not in func_node.weights:
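
The substitution now resolves the kernel's channel axes through `get_fw_info().get_kernel_channels(new_layer)` instead of a `kernel_channels_mapping` dict on an injected `fw_info`. The (output, input) pairs in the mapping match PyTorch's actual weight layouts; a small sanity check (assumes torch is installed, values chosen for illustration):

```python
import torch.nn as nn

# Conv2d weights are laid out (out_channels, in_channels, kH, kW) -> axes (0, 1);
# ConvTranspose2d weights are (in_channels, out_channels, kH, kW) -> axes (1, 0),
# matching ChannelAxisMapping(0, 1) / ChannelAxisMapping(1, 0) in the hunks above.
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
deconv = nn.ConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3)

assert conv.weight.shape[0] == 8 and conv.weight.shape[1] == 3
assert deconv.weight.shape[0] == 3 and deconv.weight.shape[1] == 8
```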

model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py

@@ -46,17 +46,15 @@ class ScaleEqualization(BaseScaleEqualization):
     """

     def __init__(self,
-                 quant_config: QuantizationConfig,
-                 fw_info: FrameworkInfo):
+                 quant_config: QuantizationConfig):
         """
         Initialize a ScaleEqualization object.
         Args:
             quant_config: Quantization configuration.
-            fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
             groups of layers by how they should be quantized, etc.)
         """

-        super().__init__(quant_config=quant_config, fw_info=fw_info, matcher_instance=MATCHER,
+        super().__init__(quant_config=quant_config, matcher_instance=MATCHER,
                          kernel_str=KERNEL, bias_str=BIAS)


@@ -66,15 +64,13 @@ class ScaleEqualizationWithPad(BaseScaleEqualization):
     """

     def __init__(self,
-                 quant_config: QuantizationConfig,
-                 fw_info: FrameworkInfo):
+                 quant_config: QuantizationConfig):
         """
         Initialize a ScaleEqualization object.
         Args:
             quant_config: Quantization configuration.
-            fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
             groups of layers by how they should be quantized, etc.)
         """

-        super().__init__(quant_config=quant_config, fw_info=fw_info, matcher_instance=MATCHER_WITH_PAD,
+        super().__init__(quant_config=quant_config, matcher_instance=MATCHER_WITH_PAD,
                          kernel_str=KERNEL, bias_str=BIAS)

model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py

@@ -214,15 +214,13 @@ def is_padding_node_and_node_has_padding(pad_node_to_consider: BaseNode,


 def pytorch_apply_shift_negative_correction(graph: Graph,
-                                            core_config: CoreConfig,
-                                            fw_info: FrameworkInfo) -> Graph:
+                                            core_config: CoreConfig) -> Graph:
     """
     Apply shift negative correction (SNC) on a graph built from a Pytorch model.

     Args:
         graph: Graph to apply SNC on.
         core_config: Quantization configuration.
-        fw_info: FrameworkInfo object with information about the specific framework's module.

     Returns:
         Graph after SNC.
@@ -230,7 +228,6 @@ def pytorch_apply_shift_negative_correction(graph: Graph,
     snc_node, linear_node, bypass_node, pad_node = shift_negative_activation_node_matchers()
     return apply_shift_negative_correction(graph,
                                            core_config,
-                                           fw_info,
                                            snc_node,
                                            linear_node,
                                            bypass_node,

model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py

@@ -23,7 +23,6 @@ from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESS
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
 from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.core.pytorch.hessian.hessian_scores_calculator_pytorch import \
     HessianScoresCalculatorPytorch
 from model_compression_toolkit.logger import Logger
@@ -92,22 +91,14 @@ class WeightsHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
         for i, ipt_node in enumerate(self.hessian_request.target_nodes):  # Per Interest point weights tensor

             # Check if the target node's layer type is supported.
-            if not DEFAULT_PYTORCH_INFO.is_kernel_op(ipt_node.type):
+            if not ipt_node.is_kernel_op:
                 Logger.critical(f"Hessian information with respect to weights is not supported for "
                                 f"{ipt_node.type} layers.")  # pragma: no cover

-            # Get the weight attributes for the target node type
-            weights_attributes = DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(ipt_node.type)
-
-            # Get the weight tensor for the target node
-            if len(weights_attributes) != 1:  # pragma: no cover
-                Logger.critical(f"Currently, Hessian scores with respect to weights are supported only for nodes with a "
-                                f"single weight attribute. {len(weights_attributes)} attributes found.")
-
-            weights_tensor = getattr(getattr(model, ipt_node.name), weights_attributes[0])
+            weights_tensor = getattr(getattr(model, ipt_node.name), ipt_node.kernel_attr)

             # Get the output channel index
-            output_channel_axis, _ = DEFAULT_PYTORCH_INFO.kernel_channels_mapping.get(ipt_node.type)
+            output_channel_axis = ipt_node.channel_axis.output
             shape_channel_axis = [i for i in range(len(weights_tensor.shape))]
             if self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL:
                 shape_channel_axis.remove(output_channel_axis)
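
With the per-node `channel_axis.output` in hand, the calculator keeps the output-channel axis and collects the remaining axes in `shape_channel_axis` when `PER_OUTPUT_CHANNEL` granularity is requested. A standalone sketch of how such a per-output-channel reduction can use those axes (torch; the `hessian_approx` scores and the final mean are illustrative and not taken from this diff):

```python
import torch

weights_tensor = torch.randn(8, 3, 3, 3)     # e.g. a Conv2d kernel, output channel axis = 0
output_channel_axis = 0
per_output_channel = True                    # stands in for HessianScoresGranularity.PER_OUTPUT_CHANNEL

# Keep the output-channel axis, reduce over all remaining axes.
shape_channel_axis = [i for i in range(len(weights_tensor.shape))]
if per_output_channel:
    shape_channel_axis.remove(output_channel_axis)

hessian_approx = weights_tensor.pow(2)       # illustrative per-element scores
per_channel_scores = hessian_approx.mean(dim=tuple(shape_channel_axis))
assert per_channel_scores.shape == (8,)      # one score per output channel
```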