mct-nightly 2.4.0.20250616.616__py3-none-any.whl → 2.4.0.20250618.606__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/RECORD +120 -120
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +2 -5
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
  6. model_compression_toolkit/core/common/framework_implementation.py +10 -22
  7. model_compression_toolkit/core/common/framework_info.py +105 -68
  8. model_compression_toolkit/core/common/graph/base_graph.py +15 -42
  9. model_compression_toolkit/core/common/graph/base_node.py +103 -42
  10. model_compression_toolkit/core/common/graph/functional_node.py +18 -1
  11. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
  12. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
  16. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
  17. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
  18. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
  19. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
  20. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
  21. model_compression_toolkit/core/common/model_collector.py +10 -20
  22. model_compression_toolkit/core/common/model_validation.py +1 -4
  23. model_compression_toolkit/core/common/network_editors/actions.py +14 -38
  24. model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
  25. model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
  26. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
  27. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
  28. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
  29. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
  30. model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
  31. model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
  32. model_compression_toolkit/core/common/pruning/pruner.py +1 -6
  33. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
  34. model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
  35. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
  36. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
  37. model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
  38. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
  39. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
  40. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
  41. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
  42. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
  43. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
  44. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
  45. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
  46. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
  47. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
  48. model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
  49. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
  50. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
  51. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  52. model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
  53. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
  54. model_compression_toolkit/core/graph_prep_runner.py +2 -16
  55. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
  56. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
  57. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
  58. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
  59. model_compression_toolkit/core/keras/default_framework_info.py +138 -87
  60. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
  61. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
  62. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
  63. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
  64. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  65. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
  66. model_compression_toolkit/core/keras/keras_implementation.py +15 -35
  67. model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
  68. model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
  69. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
  70. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +0 -2
  71. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
  72. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
  73. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
  74. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
  75. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
  76. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
  77. model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
  78. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
  79. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
  80. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  81. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
  82. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
  83. model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
  84. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
  85. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  86. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
  87. model_compression_toolkit/core/quantization_prep_runner.py +4 -9
  88. model_compression_toolkit/core/runner.py +5 -15
  89. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  90. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  91. model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
  92. model_compression_toolkit/gptq/common/gptq_training.py +1 -8
  93. model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
  94. model_compression_toolkit/gptq/keras/graph_info.py +4 -6
  95. model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
  96. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  97. model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
  98. model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
  99. model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
  100. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  101. model_compression_toolkit/gptq/runner.py +1 -7
  102. model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
  103. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
  104. model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
  105. model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
  106. model_compression_toolkit/ptq/runner.py +1 -4
  107. model_compression_toolkit/qat/common/qat_config.py +2 -6
  108. model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
  109. model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
  110. model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
  111. model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
  112. model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
  113. model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
  114. model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
  115. model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
  116. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
  117. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
  118. {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/WHEEL +0 -0
  119. {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/licenses/LICENSE.md +0 -0
  120. {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/top_level.txt +0 -0
model_compression_toolkit/core/common/graph/base_node.py

@@ -14,10 +14,11 @@
 # ==============================================================================
 
 import copy
-from typing import Dict, Any, Tuple, List, Type, Union
+from typing import Dict, Any, Tuple, List, Type, Union, NamedTuple
 
 import numpy as np
 
+from model_compression_toolkit.core.common.framework_info import get_fw_info, ChannelAxisMapping
 from model_compression_toolkit.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE, \
     ACTIVATION_N_BITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER
 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
@@ -34,11 +35,21 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
 WeightAttrT = Union[str, int]
 
 
+class NodeFrameworkInfo(NamedTuple):
+    """
+    Node's specific framework information.
+    """
+    channel_axis: ChannelAxisMapping
+    out_channel_axis: int
+    minmax: Tuple[float, float]
+    kernel_attr: str
+    is_kernel_op: bool
+
+
 class BaseNode:
     """
     Class to represent a node in a graph that represents the model.
     """
-
     def __init__(self,
                  name: str,
                  framework_attr: Dict[str, Any],
@@ -88,6 +99,78 @@ class BaseNode:
         self.prior_info = None
         self.has_activation = has_activation
         self.is_custom = is_custom
+        self.node_fw_info = self._get_fw_node_attrs(layer_class, framework_attr)
+
+    def _get_fw_node_attrs(self, node_type, framework_attr):
+        fw_info = get_fw_info()
+        return None if fw_info is None else NodeFrameworkInfo(
+            fw_info.get_kernel_channels(node_type),
+            fw_info.get_out_channel_axis(node_type),
+            fw_info.get_layer_min_max(node_type, framework_attr),
+            fw_info.get_kernel_op_attribute(node_type),
+            fw_info.is_kernel_op(node_type)
+        )
+
+    def _assert_fw_info_exists(self):
+        """
+        Verify NodeFrameworkInfo was initialized.
+        """
+        assert self.node_fw_info is not None, f"NodeFrameworkInfo not initialized for node {self.name}"  # pragma: no cover
+
+    @property
+    def channel_axis(self) -> ChannelAxisMapping:
+        """
+        Extract channels axis from node's NodeFrameworkInfo.
+
+        Returns:
+            Channels axis named tuple.
+        """
+        self._assert_fw_info_exists()
+        return self.node_fw_info.channel_axis
+
+    @property
+    def out_channel_axis(self) -> int:
+        """
+        Extract output channel axis from node's NodeFrameworkInfo.
+
+        Returns:
+            Output channel axis.
+        """
+        self._assert_fw_info_exists()
+        return self.node_fw_info.out_channel_axis
+
+    @property
+    def minmax(self) -> Tuple[float, float]:
+        """
+        Extract expected min-max activation values from node's NodeFrameworkInfo.
+
+        Returns:
+            A tuple of min-max values.
+        """
+        self._assert_fw_info_exists()
+        return self.node_fw_info.minmax
+
+    @property
+    def kernel_attr(self) -> str:
+        """
+        Extract kernel name from node's NodeFrameworkInfo.
+
+        Returns:
+            Kernel name.
+        """
+        self._assert_fw_info_exists()
+        return self.node_fw_info.kernel_attr
+
+    @property
+    def is_kernel_op(self) -> bool:
+        """
+        Check if kernel exists for the node.
+
+        Returns:
+            Whether the node has a kernel or not.
+        """
+        self._assert_fw_info_exists()
+        return self.node_fw_info.is_kernel_op
 
     @property
     def type(self):
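The hunks above move per-layer framework metadata off the global FrameworkInfo object and cache it on each node as a NodeFrameworkInfo named tuple, built once in __init__ via get_fw_info(). A minimal sketch of the resulting access pattern, assuming `node` is some BaseNode constructed while the framework info is registered (otherwise the properties trip the assert in _assert_fw_info_exists):

    # Hypothetical `node`; the property names are the ones added in this diff.
    if node.is_kernel_op:
        # Replaces the old global lookup fw_info.get_kernel_op_attributes(node.type)[0]
        kernel = node.get_weights_by_keys(node.kernel_attr)
        out_axis = node.out_channel_axis    # output-channel axis for this layer type
    a_min, a_max = node.minmax              # expected activation range (Tuple[float, float])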
@@ -298,14 +381,11 @@
 
         return input_tensors
 
-    def get_num_parameters(self, fw_info) -> Tuple[int,int]:
+    def get_num_parameters(self) -> Tuple[int,int]:
         """
         Compute the number of parameters the node holds.
         It returns a tuple: Number of quantized parameters, number of float parameters.
 
-        Args:
-            fw_info: Framework info to decide which attributes should be quantized.
-
         Returns:
             A tuple of (Number of quantized parameters, number of float parameters).
 
@@ -314,11 +394,10 @@
 
         q_node_num_params = 0
 
-        for attr in fw_info.get_kernel_op_attributes(self.type):
-            if attr is not None:
-                w = self.get_weights_by_keys(attr)
-                if w is not None:
-                    q_node_num_params += w.flatten().shape[0]
+        if self.kernel_attr is not None:
+            w = self.get_weights_by_keys(self.kernel_attr)
+            if w is not None:
+                q_node_num_params += w.flatten().shape[0]
 
         f_node_num_params = total_node_params - q_node_num_params
 
@@ -326,22 +405,19 @@
         assert int(f_node_num_params) == f_node_num_params
         return int(q_node_num_params), int(f_node_num_params)
 
-    def get_memory_bytes(self, fw_info) -> float:
+    def get_memory_bytes(self) -> float:
         """
         Compute the number of bytes the node's memory requires.
 
-        Args:
-            fw_info: Framework info to decide which attributes should be quantized.
-
         Returns: Number of bytes the node's memory requires.
 
         """
         # TODO: this method is used for tensorboard only. If we want to enable logging of other attributes memory
         # then it needs to be modified. But, it might be better to remove this method from the BaseNode completely.
-        kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0]
+        kernel_attr = self.kernel_attr
         if kernel_attr is None:
             return 0
-        q_params, f_params = self.get_num_parameters(fw_info)
+        q_params, f_params = self.get_num_parameters()
         if self.final_weights_quantization_cfg is None:  # float coefficients
             memory = (f_params+q_params) * FP32_BYTES_PER_PARAMETER
         else:
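A worked instance of the float branch above, with illustrative numbers (FP32_BYTES_PER_PARAMETER is presumably 4, one float32 being 4 bytes): a node whose only weight is a 3x3x3x16 kernel holds 432 parameters, all attributed to the kernel by get_num_parameters(), so with no final weights quantization config:

    q_params, f_params = node.get_num_parameters()   # (432, 0) for the kernel-only node
    memory = (f_params + q_params) * 4               # 432 * 4 = 1728 bytes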
@@ -351,15 +427,12 @@
 
         return memory
 
-    def get_unified_weights_candidates_dict(self, fw_info) -> Dict[str, Any]:
+    def get_unified_weights_candidates_dict(self) -> Dict[str, Any]:
         """
         In Mixed-Precision, a node's kernel can have multiple candidates for weights quantization configuration.
         In order to display a single view of a node (for example, for logging in TensorBoard) we need a way
         to create a single dictionary from all candidates.
-        This method is aimed to build such an unified dictionary for a node.
-
-        Args:
-            fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
+        This method is aimed to build such a unified dictionary for a node.
 
         Returns: A dictionary containing information from node's weight quantization configuration candidates.
 
@@ -369,7 +442,7 @@
         # We assume that only the kernel attribute have more than one candidate, since we only allow to
         # quantize the kernel using mixed precision
         # TODO: need to modify if we want to present a unified config for other attributes
-        kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0]
+        kernel_attr = self.kernel_attr
         if kernel_attr is None:
             # This node doesn't have a kernel attribute
             return {}
@@ -437,20 +510,13 @@
         candidates = self.get_all_weights_attr_candidates(attr)
         return all(candidate == candidates[0] for candidate in candidates[1:])
 
-    def has_kernel_weight_to_quantize(self, fw_info):
+    def has_kernel_weight_to_quantize(self):
         """
-        Checks whether the node has kernel attribute that need to be quantized according to the framework info.
+        Checks whether the node has kernel attribute that need to be quantized according to the node's framework info.
 
-        Args:
-            fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
-
-        Returns: Whether the node has weights that need to be quantized.
+        Returns: Whether the node's kernel need to be quantized.
         """
-        attrs = fw_info.get_kernel_op_attributes(self.type)
-        for attr in attrs:
-            if attr and self.get_weights_by_keys(attr) is not None:
-                return True
-        return False
+        return self.kernel_attr and self.get_weights_by_keys(self.kernel_attr) is not None
 
     def has_any_weight_attr_to_quantize(self) -> bool:
         """
@@ -724,7 +790,7 @@ class BaseNode:
             Logger.critical(f"SIMD is expected to be a non-positive integer but found: {_simd}")
         return _simd
 
-    def sort_node_candidates(self, fw_info):
+    def sort_node_candidates(self):
         """
         Sorts the node candidates.
         We assume that the candidates are ordered in the following way (for mixed precision purposes):
@@ -733,16 +799,11 @@
         - If the node doesn't have a kernel we only consider the candidate activation number of bits to sort
         the candidates in descending order.
         The operation is done inplace.
-
-        Args:
-            fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
-
         """
         if self.candidates_quantization_cfg is not None:
-            kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0]
-            if kernel_attr is not None:
+            if self.kernel_attr is not None:
                 self.candidates_quantization_cfg.sort(
-                    key=lambda c: (c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits,
+                    key=lambda c: (c.weights_quantization_cfg.get_attr_config(self.kernel_attr).weights_n_bits,
                                    c.activation_quantization_cfg.activation_n_bits), reverse=True)
             else:
                 self.candidates_quantization_cfg.sort(key=lambda c: c.activation_quantization_cfg.activation_n_bits,
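The sort key above orders candidates by kernel bit-width first, then activation bit-width, both descending, so the maximal-precision candidate lands at index 0. A self-contained illustration with bare (weights_n_bits, activation_n_bits) pairs standing in for candidate configs:

    candidates = [(4, 8), (8, 4), (2, 8), (8, 8)]
    candidates.sort(key=lambda c: (c[0], c[1]), reverse=True)
    assert candidates == [(8, 8), (8, 4), (4, 8), (2, 8)]  # max precision first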
model_compression_toolkit/core/common/graph/functional_node.py

@@ -1,6 +1,21 @@
+# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 from typing import Dict, Any, Tuple, Type, List, Union
 
-from model_compression_toolkit.verify_packages import FOUND_TF
+from model_compression_toolkit.core.common.framework_info import get_fw_info
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 import numpy as np
 
@@ -45,6 +60,7 @@ class FunctionalNode(BaseNode):
             inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
             has_activation: Whether the node has activations that we might want to quantize.
             tensor_input_allocs: A list of indices and strings for allocations input tensors in the node's args and kwargs.
+
         """
 
         super().__init__(name,
@@ -63,6 +79,7 @@ class FunctionalNode(BaseNode):
         self.op_call_args = list(op_call_args)
         self.functional_op = functional_op
         self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs
+        self.node_fw_info = self._get_fw_node_attrs(functional_op, framework_attr)
 
     @property
     def type(self):
model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py

@@ -15,14 +15,13 @@
 import abc
 import uuid
 
-from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.constants import VIRTUAL_ACTIVATION_WEIGHTS_NODE_PREFIX, \
     VIRTUAL_WEIGHTS_SUFFIX, VIRTUAL_ACTIVATION_SUFFIX, FLOAT_BITWIDTH
-from model_compression_toolkit.core.common.framework_info import DEFAULT_KERNEL_ATTRIBUTES
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig
 from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
+from model_compression_toolkit.core.common.framework_info import DEFAULT_KERNEL_ATTRIBUTE
 
 
 class VirtualNode(BaseNode, abc.ABC):
@@ -128,28 +127,23 @@ class VirtualActivationWeightsNode(VirtualNode):
 
     def __init__(self,
                  act_node: BaseNode,
-                 weights_node: BaseNode,
-                 fw_info: FrameworkInfo):
+                 weights_node: BaseNode):
         """
         Init a VirtualActivationWeightsNode object.
 
         Args:
            act_node: The original activation node.
            weights_node: The original weights node.
-           fw_info: A FrameworkInfo object with framework specific information.
         """
         # Validate weights node
-        kernel_attrs = fw_info.get_kernel_op_attributes(weights_node.type)
-        assert len(kernel_attrs) == 1 and kernel_attrs[0] is not None, f'Expected exactly one kernel attr, {kernel_attrs}'
-        kernel_attr = kernel_attrs[0]
         conf_weights = [attr for attr in weights_node.weights if weights_node.is_configurable_weight(attr)]
-        if len(conf_weights) > 1 or len(conf_weights) == 1 and not weights_node.is_configurable_weight(kernel_attr):
+        if len(conf_weights) > 1 or len(conf_weights) == 1 and not weights_node.is_configurable_weight(weights_node.kernel_attr):
             raise NotImplementedError(f'Only kernel weight can be configurable. Got configurable {conf_weights}.')
 
         weights = weights_node.weights.copy()
         act_node_w_rename = {}
         if act_node.weights:
-            if fw_info.get_kernel_op_attributes(act_node) != DEFAULT_KERNEL_ATTRIBUTES:
+            if act_node.kernel_attr != DEFAULT_KERNEL_ATTRIBUTE:
                 raise NotImplementedError(f'Node {act_node} with kernel cannot be used as activation for '
                                           f'VirtualActivationWeightsNode.')
             if act_node.has_any_configurable_weight():
@@ -157,7 +151,7 @@ class VirtualActivationWeightsNode(VirtualNode):
                                           'VirtualActivationWeightsNode.')
         # combine weights from activation and weights
         for w_id, w in act_node.weights.items():
-            if w_id not in weights and not (isinstance(w_id, str) and kernel_attr in w_id):
+            if w_id not in weights and not (isinstance(w_id, str) and weights_node.kernel_attr in w_id):
                 weights[w_id] = w
                 continue
             # if same identifier is used as in weight nodes (or contains the kernel substring), generate a new
@@ -185,7 +179,7 @@ class VirtualActivationWeightsNode(VirtualNode):
         self.original_weights_node = weights_node
 
         v_candidates = []
-        weights_candidates_quantization_cfg = weights_node.get_unique_weights_candidates(kernel_attr)
+        weights_candidates_quantization_cfg = weights_node.get_unique_weights_candidates(weights_node.kernel_attr)
        for c_a in act_node.candidates_quantization_cfg:
            for c_w in weights_candidates_quantization_cfg:
                composed_candidate = CandidateNodeQuantizationConfig(activation_quantization_cfg=c_a.activation_quantization_cfg,
@@ -203,7 +197,7 @@ class VirtualActivationWeightsNode(VirtualNode):
                 v_candidates.append(composed_candidate)
 
         # sorting the candidates by weights number of bits first and then by activation number of bits (reversed order)
-        v_candidates.sort(key=lambda c: (c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits,
+        v_candidates.sort(key=lambda c: (c.weights_quantization_cfg.get_attr_config(weights_node.kernel_attr).weights_n_bits,
                                          c.activation_quantization_cfg.activation_n_bits), reverse=True)
 
         self.candidates_quantization_cfg = v_candidates
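For callers, the net effect of these hunks is a narrower constructor: the kernel attribute is now resolved from weights_node.kernel_attr instead of being looked up through a FrameworkInfo argument. A migration sketch (act_node and weights_node assumed to exist):

    # Before (2.4.0.20250616.616):
    #   v_node = VirtualActivationWeightsNode(act_node, weights_node, fw_info)
    # After (2.4.0.20250618.606):
    v_node = VirtualActivationWeightsNode(act_node, weights_node)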
model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py

@@ -37,20 +37,18 @@ def set_bit_widths(mixed_precision_enable: bool,
     """
     if mixed_precision_enable:
         assert all([len(n.candidates_quantization_cfg) > 0
-                    for n in graph.get_configurable_sorted_nodes(graph.fw_info)]), \
+                    for n in graph.get_configurable_sorted_nodes()]), \
             "All configurable nodes in graph should have at least one candidate configuration in mixed precision mode"
 
         # Get a list of nodes' names we need to finalize (that they have at least one weight qc candidate).
-        sorted_nodes_names = graph.get_configurable_sorted_nodes_names(graph.fw_info)
+        sorted_nodes_names = graph.get_configurable_sorted_nodes_names()
 
         for node in graph.nodes:  # set a specific node qc for each node final qc
             # If it's reused, take the configuration that the base node has
             node_name = node.name if not node.reuse else '_'.join(node.name.split('_')[:-2])
             if node_name in sorted_nodes_names:  # only configurable nodes are in this list
                 node_index_in_graph = sorted_nodes_names.index(node_name)
-                _set_node_final_qc(bit_widths_config[node_index_in_graph],
-                                   node,
-                                   graph.fw_info)
+                _set_node_final_qc(bit_widths_config[node_index_in_graph], node)
             else:
                 if node.is_activation_quantization_enabled():
                     # If we are here, this means that we are in weights-only mixed-precision
@@ -83,8 +81,7 @@ def set_bit_widths(mixed_precision_enable: bool,
 
 
 def _get_node_qc_by_bit_widths(node: BaseNode,
-                               node_bit_width_cfg: int,
-                               fw_info) -> Any:
+                               node_bit_width_cfg: int) -> Any:
     """
     Get the node's quantization configuration that
     matches to the bit width index as in the MP configuration bit_width_cfg.
@@ -93,21 +90,18 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
     Args:
         node: Node to get its quantization configuration candidate.
         node_bit_width_cfg: Configuration which determines the node's desired bit width.
-        fw_info: Information relevant to a specific framework about how layers should be quantized.
 
     Returns:
         Node quantization configuration if it was found, or None otherwise.
     """
     # only the weights kernel attribute is quantized in weights mixed precision at the moment
-    kernel_attr = fw_info.get_kernel_op_attributes(node.type)
-
     if node.is_activation_quantization_enabled():
         qc = node.candidates_quantization_cfg[node_bit_width_cfg]
 
         return qc
 
-    elif kernel_attr is not None:
-        if node.is_weights_quantization_enabled(kernel_attr[0]):
+    elif node.kernel_attr is not None:
+        if node.is_weights_quantization_enabled(node.kernel_attr):
             qc = node.candidates_quantization_cfg[node_bit_width_cfg]
 
             return qc
@@ -116,8 +110,7 @@
 
 
 def _set_node_final_qc(node_bit_width_cfg: int,
-                       node: BaseNode,
-                       fw_info):
+                       node: BaseNode):
     """
     Get the node's quantization configuration that
     matches to the bit width index as in the MP configuration bit_width_cfg, and use it to finalize the node's
@@ -127,12 +120,9 @@
     Args:
         node_bit_width_cfg: Configuration which determines the node's desired bit width.
         node: Node to set its node quantization configuration.
-        fw_info: Information relevant to a specific framework about how layers should be quantized.
 
     """
-    node_qc = _get_node_qc_by_bit_widths(node,
-                                         node_bit_width_cfg,
-                                         fw_info)
+    node_qc = _get_node_qc_by_bit_widths(node, node_bit_width_cfg)
 
     if node_qc is None:
         Logger.critical(f'Node {node.name} quantization configuration from configuration file'  # pragma: no cover
model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py

@@ -22,7 +22,6 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
 
 def filter_candidates_for_mixed_precision(graph: Graph,
                                           target_resource_utilization: ResourceUtilization,
-                                          fw_info: FrameworkInfo,
                                           fqc: FrameworkQuantizationCapabilities):
     """
     Filters out candidates in case of mixed precision search for only weights or activation compression.
@@ -35,7 +34,6 @@ def filter_candidates_for_mixed_precision(graph: Graph,
     Args:
         graph: A graph representation of the model to be quantized.
         target_resource_utilization: The resource utilization of the target device.
-        fw_info: fw_info: Information needed for quantization about the specific framework.
         fqc: FrameworkQuantizationCapabilities object that describes the desired inference target platform.
 
     """
@@ -59,11 +57,10 @@ def filter_candidates_for_mixed_precision(graph: Graph,
     elif tru.activation_restricted() and not tru.weight_restricted():
         # Running mixed precision for activation compression only -
         # filter out candidates weights only configurable node
-        weight_configurable_nodes = [n for n in graph.get_weights_configurable_nodes(fw_info)]
+        weight_configurable_nodes = [n for n in graph.get_weights_configurable_nodes()]
         for n in weight_configurable_nodes:
-            kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
-            base_cfg_nbits = n.get_qco(fqc).base_config.attr_weights_configs_mapping[kernel_attr].weights_n_bits
+            base_cfg_nbits = n.get_qco(fqc).base_config.attr_weights_configs_mapping[n.kernel_attr].weights_n_bits
             filtered_conf = [c for c in n.candidates_quantization_cfg if
-                             c.weights_quantization_cfg.get_attr_config(kernel_attr).enable_weights_quantization and
-                             c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == base_cfg_nbits]
+                             c.weights_quantization_cfg.get_attr_config(n.kernel_attr).enable_weights_quantization and
+                             c.weights_quantization_cfg.get_attr_config(n.kernel_attr).weights_n_bits == base_cfg_nbits]
             n.candidates_quantization_cfg = filtered_conf
model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py

@@ -30,11 +30,10 @@ from model_compression_toolkit.core.common.quantization.node_quantization_config
 class MixedPrecisionRUHelper:
     """ Helper class for resource utilization computations for mixed precision optimization. """
 
-    def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation):
+    def __init__(self, graph: Graph, fw_impl: FrameworkImplementation):
         self.graph = graph
-        self.fw_info = fw_info
         self.fw_impl = fw_impl
-        self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
+        self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl)
 
     def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: Dict[BaseNode, int]) -> Dict[RUTarget, np.ndarray]:
         """
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py

@@ -35,7 +35,6 @@ class BitWidthSearchMethod(Enum):
 
 
 def search_bit_width(graph: Graph,
-                     fw_info: FrameworkInfo,
                      fw_impl: FrameworkImplementation,
                      target_resource_utilization: ResourceUtilization,
                      mp_config: MixedPrecisionQuantizationConfig,
@@ -52,7 +51,6 @@ def search_bit_width(graph: Graph,
 
     Args:
         graph: Graph to search a MP configuration for.
-        fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
         fw_impl: FrameworkImplementation object with specific framework methods implementation.
         target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
         mp_config: Mixed-precision quantization configuration.
@@ -79,7 +77,7 @@ def search_bit_width(graph: Graph,
 
     # Set Sensitivity Evaluator for MP search. It should always work with the original MP graph,
     # even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
-    se = SensitivityEvaluation(graph, mp_config, representative_data_gen=representative_data_gen, fw_info=fw_info,
+    se = SensitivityEvaluation(graph, mp_config, representative_data_gen=representative_data_gen,
                                fw_impl=fw_impl, disable_activation_for_metric=disable_activation_for_metric,
                                hessian_info_service=hessian_info_service)
 
@@ -93,7 +91,6 @@ def search_bit_width(graph: Graph,
 
     # Search manager and LP are highly coupled, so LP search method was moved inside search manager.
     search_manager = MixedPrecisionSearchManager(graph,
-                                                 fw_info=fw_info,
                                                  fw_impl=fw_impl,
                                                  sensitivity_evaluator=se,
                                                  target_resource_utilization=target_resource_utilization,
@@ -105,6 +102,6 @@ def search_bit_width(graph: Graph,
     if mp_config.refine_mp_solution:
         nodes_bit_cfg = greedy_solution_refinement_procedure(nodes_bit_cfg, search_manager, target_resource_utilization)
 
-    topo_bit_cfg = [nodes_bit_cfg[n] for n in graph.get_configurable_sorted_nodes(fw_info)]
+    topo_bit_cfg = [nodes_bit_cfg[n] for n in graph.get_configurable_sorted_nodes()]
     assert len(topo_bit_cfg) == len(nodes_bit_cfg)
     return topo_bit_cfg
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py

@@ -53,7 +53,6 @@ class MixedPrecisionSearchManager:
 
     def __init__(self,
                  graph: Graph,
-                 fw_info: FrameworkInfo,
                  fw_impl: FrameworkImplementation,
                  sensitivity_evaluator: SensitivityEvaluation,
                  target_resource_utilization: ResourceUtilization,
@@ -62,14 +61,12 @@
 
         Args:
             graph: Graph to search for its MP configuration.
-            fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
             fw_impl: FrameworkImplementation object with specific framework methods implementation.
             sensitivity_evaluator: A SensitivityEvaluation which provides a function that evaluates the sensitivity of
                 a bit-width configuration for the MP model.
             target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
         """
 
-        self.fw_info = fw_info
         self.fw_impl = fw_impl
 
         self.original_graph = graph
@@ -81,12 +78,12 @@
         self.target_resource_utilization = target_resource_utilization
         self.mp_config = mp_config
 
-        self.mp_topo_configurable_nodes = self.mp_graph.get_configurable_sorted_nodes(fw_info)
+        self.mp_topo_configurable_nodes = self.mp_graph.get_configurable_sorted_nodes()
 
         self.ru_targets = target_resource_utilization.get_restricted_targets()
-        self.orig_graph_ru_helper = MixedPrecisionRUHelper(self.original_graph, fw_info, fw_impl)
+        self.orig_graph_ru_helper = MixedPrecisionRUHelper(self.original_graph, fw_impl)
 
-        self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config(fw_info)
+        self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config()
 
         self.config_reconstructor = None
         orig_min_config = self.min_ru_config
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py

@@ -124,10 +124,9 @@
     unexpected_qc_error = 'Custom quantization configuration is not expected for non-custom bit mode.'
     unexpected_qc_nodes_error = 'Custom quantization configuration contains unexpected node names.'
 
-    def __init__(self, graph: Graph, fw_impl: FrameworkImplementation, fw_info: FrameworkInfo):
+    def __init__(self, graph: Graph, fw_impl: FrameworkImplementation):
         self.graph = graph
         self.fw_impl = fw_impl
-        self.fw_info = fw_info
 
         # Currently we go over the full graph even if utilization won't be requested for all nodes.
         # We could fill the cache on the fly only for requested nodes, but it's probably negligible.
@@ -544,14 +543,10 @@
             self._validate_custom_qcs(w_qc, bitwidth_mode)
 
         # check if the node has kernel
-        kernel_attrs = self.fw_info.get_kernel_op_attributes(n.type)
-        if len(kernel_attrs) > 1:  # pragma: no cover
-            raise NotImplementedError('Multiple kernel attributes are not supported for BOPS computation.')
-        if not kernel_attrs or not kernel_attrs[0]:
+        if not n.kernel_attr:
            return 0
 
-        kernel_attr = kernel_attrs[0]
-        node_mac = self.fw_impl.get_node_mac_operations(n, self.fw_info)
+        node_mac = self.fw_impl.get_node_mac_operations(n)
         if node_mac == 0:
             return node_mac
 
@@ -559,12 +554,12 @@
         assert len(prev_nodes) == 1, f'Weights node is expected to have exactly one input, {n} has {len(prev_nodes)}'
         a_node = prev_nodes[0]
         if (target_criterion == TargetInclusionCriterion.AnyQuantized and
-                not (a_node.is_activation_quantization_enabled() or n.is_weights_quantization_enabled(kernel_attr))):
+                not (a_node.is_activation_quantization_enabled() or n.is_weights_quantization_enabled(n.kernel_attr))):
             return 0
 
         act_qc = self._extract_qc(a_node, act_qcs)
         a_nbits = self._get_activation_nbits(a_node, bitwidth_mode, act_qc)
-        w_nbits = self._get_weight_nbits(n, kernel_attr, bitwidth_mode, w_qc)
+        w_nbits = self._get_weight_nbits(n, n.kernel_attr, bitwidth_mode, w_qc)
         node_bops = a_nbits * w_nbits * node_mac
         return node_bops
 
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py

@@ -15,7 +15,7 @@
 import copy
 from typing import Callable, Any
 
-from model_compression_toolkit.core import FrameworkInfo, ResourceUtilization, CoreConfig, QuantizationErrorMethod
+from model_compression_toolkit.core import ResourceUtilization, CoreConfig, QuantizationErrorMethod
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
     ResourceUtilizationCalculator, BitwidthMode, TargetInclusionCriterion
@@ -27,7 +27,6 @@ def compute_resource_utilization_data(in_model: Any,
                                       representative_data_gen: Callable,
                                       core_config: CoreConfig,
                                       fqc: FrameworkQuantizationCapabilities,
-                                      fw_info: FrameworkInfo,
                                       fw_impl: FrameworkImplementation) -> ResourceUtilization:
     """
     Compute Resource Utilization of a model with the default single precision quantization.
@@ -39,7 +38,6 @@ def compute_resource_utilization_data(in_model: Any,
         core_config: CoreConfig containing parameters of how the model should be quantized.
         fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
             the attached framework operator's information.
-        fw_info: Information needed for quantization about the specific framework.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
@@ -55,12 +53,11 @@
     transformed_graph = graph_preparation_runner(in_model,
                                                  representative_data_gen=representative_data_gen,
                                                  quantization_config=core_config.quantization_config,
-                                                 fw_info=fw_info,
                                                  fw_impl=fw_impl,
                                                  fqc=fqc,
                                                  bit_width_config=core_config.bit_width_config,
                                                  mixed_precision_enable=False,
                                                  running_gptq=False)
 
-    ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl, fw_info)
+    ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl)
    return ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused, BitwidthMode.QDefaultSP)
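The same migration shape runs through all of these hunks: fw_info disappears from constructors and call sites, since each node now carries its own framework info. A hedged end-to-end sketch of the updated entry points in the last two files (model, rep_data_gen, fqc, fw_impl, and graph are assumed to be prepared as before):

    ru_calculator = ResourceUtilizationCalculator(graph, fw_impl)   # was (graph, fw_impl, fw_info)
    ru = compute_resource_utilization_data(in_model=model,
                                           representative_data_gen=rep_data_gen,
                                           core_config=CoreConfig(),
                                           fqc=fqc,
                                           fw_impl=fw_impl)         # fw_info argument removed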