mct-nightly 1.1.0.6012022.post2521__py3-none-any.whl → 1.1.0.07122021-002414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. {mct_nightly-1.1.0.6012022.post2521.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/METADATA +3 -3
  2. {mct_nightly-1.1.0.6012022.post2521.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/RECORD +72 -76
  3. {mct_nightly-1.1.0.6012022.post2521.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/common/__init__.py +2 -2
  5. model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +2 -2
  6. model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +9 -9
  7. model_compression_toolkit/common/collectors/mean_collector.py +2 -3
  8. model_compression_toolkit/common/collectors/min_max_per_channel_collector.py +3 -6
  9. model_compression_toolkit/common/constants.py +0 -1
  10. model_compression_toolkit/common/framework_implementation.py +6 -22
  11. model_compression_toolkit/common/framework_info.py +7 -39
  12. model_compression_toolkit/common/graph/__init__.py +1 -1
  13. model_compression_toolkit/common/graph/base_graph.py +34 -34
  14. model_compression_toolkit/common/graph/edge.py +3 -3
  15. model_compression_toolkit/common/graph/graph_matchers.py +3 -3
  16. model_compression_toolkit/common/graph/graph_searches.py +4 -4
  17. model_compression_toolkit/common/graph/graph_vis.py +116 -0
  18. model_compression_toolkit/common/graph/{base_node.py → node.py} +27 -11
  19. model_compression_toolkit/common/mixed_precision/bit_width_setter.py +9 -10
  20. model_compression_toolkit/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  21. model_compression_toolkit/common/model_collector.py +12 -14
  22. model_compression_toolkit/common/network_editors/actions.py +23 -19
  23. model_compression_toolkit/common/post_training_quantization.py +7 -20
  24. model_compression_toolkit/common/quantization/node_quantization_config.py +5 -13
  25. model_compression_toolkit/common/quantization/quantization_analyzer.py +7 -11
  26. model_compression_toolkit/common/quantization/quantization_config.py +6 -6
  27. model_compression_toolkit/common/quantization/quantization_params_fn_selection.py +3 -2
  28. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_activations_computation.py +7 -13
  29. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +20 -17
  30. model_compression_toolkit/common/quantization/quantize_node.py +2 -2
  31. model_compression_toolkit/common/quantization/set_node_quantization_config.py +36 -39
  32. model_compression_toolkit/common/{collectors/statistics_collector.py → statistics_collector.py} +30 -26
  33. model_compression_toolkit/common/visualization/tensorboard_writer.py +8 -11
  34. model_compression_toolkit/keras/back2framework/instance_builder.py +4 -4
  35. model_compression_toolkit/keras/back2framework/model_builder.py +34 -47
  36. model_compression_toolkit/keras/constants.py +0 -3
  37. model_compression_toolkit/keras/default_framework_info.py +7 -33
  38. model_compression_toolkit/keras/gradient_ptq/graph_info.py +2 -2
  39. model_compression_toolkit/keras/gradient_ptq/graph_update.py +1 -7
  40. model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +1 -0
  41. model_compression_toolkit/keras/graph_substitutions/substitutions/activation_decomposition.py +8 -10
  42. model_compression_toolkit/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -2
  43. model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +2 -2
  44. model_compression_toolkit/keras/graph_substitutions/substitutions/mark_activation.py +3 -3
  45. model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +4 -3
  46. model_compression_toolkit/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +2 -2
  47. model_compression_toolkit/keras/graph_substitutions/substitutions/scale_equalization.py +9 -9
  48. model_compression_toolkit/keras/graph_substitutions/substitutions/separableconv_decomposition.py +19 -19
  49. model_compression_toolkit/keras/graph_substitutions/substitutions/shift_negative_activation.py +45 -64
  50. model_compression_toolkit/keras/keras_implementation.py +8 -28
  51. model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +2 -2
  52. model_compression_toolkit/keras/quantization_facade.py +1 -5
  53. model_compression_toolkit/keras/quantizer/fake_quant_builder.py +4 -4
  54. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer.py +2 -3
  55. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer_gptq_config.py +4 -8
  56. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_weight_quantizer_gptq_config.py +4 -9
  57. model_compression_toolkit/keras/quantizer/gradient_ptq/config_factory.py +10 -9
  58. model_compression_toolkit/keras/quantizer/gradient_ptq/weight_quantizer_gptq_config.py +1 -9
  59. model_compression_toolkit/keras/quantizer/mixed_precision/quantization_config_factory.py +1 -1
  60. model_compression_toolkit/keras/quantizer/mixed_precision/selective_weights_quantize_config.py +1 -6
  61. model_compression_toolkit/keras/reader/common.py +11 -9
  62. model_compression_toolkit/keras/reader/connectivity_handler.py +9 -15
  63. model_compression_toolkit/keras/reader/nested_model/edges_merger.py +6 -6
  64. model_compression_toolkit/keras/reader/nested_model/nested_model_handler.py +2 -2
  65. model_compression_toolkit/keras/reader/nested_model/nodes_merger.py +3 -3
  66. model_compression_toolkit/keras/reader/nested_model/outputs_merger.py +2 -2
  67. model_compression_toolkit/keras/reader/node_builder.py +15 -65
  68. model_compression_toolkit/keras/reader/reader.py +5 -5
  69. model_compression_toolkit/keras/tensor_marking.py +113 -0
  70. model_compression_toolkit/keras/visualization/nn_visualizer.py +2 -2
  71. model_compression_toolkit/common/collectors/statistics_collector_generator.py +0 -43
  72. model_compression_toolkit/common/graph/functional_node.py +0 -59
  73. model_compression_toolkit/common/model_validation.py +0 -43
  74. model_compression_toolkit/common/node_prior_info.py +0 -29
  75. model_compression_toolkit/keras/keras_model_validation.py +0 -38
  76. model_compression_toolkit/keras/keras_node_prior_info.py +0 -60
  77. {mct_nightly-1.1.0.6012022.post2521.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/LICENSE +0 -0
  78. {mct_nightly-1.1.0.6012022.post2521.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/top_level.txt +0 -0
@@ -16,8 +16,7 @@
16
16
  from abc import ABC, abstractmethod
17
17
  from collections import namedtuple
18
18
 
19
- from model_compression_toolkit.common.graph.base_node import BaseNode
20
- from model_compression_toolkit.common.quantization import quantization_params_generation
19
+ from model_compression_toolkit.common.graph.node import Node
21
20
  from model_compression_toolkit.common.quantization.quantization_params_fn_selection import \
22
21
  get_activation_quantization_params_fn, get_weights_quantization_params_fn
23
22
 
@@ -51,7 +50,7 @@ class BaseAction(ABC):
51
50
  """
52
51
 
53
52
  @abstractmethod
54
- def apply(self, node: BaseNode, graph, fw_info):
53
+ def apply(self, node: Node, graph, fw_info):
55
54
  """
56
55
  Apply an action on the node after matching the node with a node filter.
57
56
 
@@ -82,7 +81,7 @@ class ChangeCandidatesWeightsQuantConfigAttr(BaseAction):
82
81
  """
83
82
  self.kwargs = kwargs
84
83
 
85
- def apply(self, node: BaseNode, graph, fw_info):
84
+ def apply(self, node: Node, graph, fw_info):
86
85
  """
87
86
  Change the attribute 'attr_name' in quant_config with 'attr_value'.
88
87
 
@@ -94,9 +93,10 @@ class ChangeCandidatesWeightsQuantConfigAttr(BaseAction):
94
93
  Returns:
95
94
  The node after its quant_config has been modified.
96
95
  """
97
- for nqc in node.candidates_weights_quantization_cfg:
98
- for attr_name, attr_value in self.kwargs.items():
99
- nqc.set_quant_config_attr(attr_name, attr_value)
96
+ if node.candidates_weights_quantization_cfg is not None:
97
+ for nqc in node.candidates_weights_quantization_cfg:
98
+ for attr_name, attr_value in self.kwargs.items():
99
+ nqc.set_quant_config_attr(attr_name, attr_value)
100
100
 
101
101
 
102
102
  class ChangeFinalWeightsQuantConfigAttr(BaseAction):
@@ -113,7 +113,7 @@ class ChangeFinalWeightsQuantConfigAttr(BaseAction):
113
113
  """
114
114
  self.kwargs = kwargs
115
115
 
116
- def apply(self, node: BaseNode, graph, fw_info):
116
+ def apply(self, node: Node, graph, fw_info):
117
117
  if node.final_weights_quantization_cfg is not None:
118
118
  for attr_name, attr_value in self.kwargs.items():
119
119
  node.final_weights_quantization_cfg.set_quant_config_attr(attr_name, attr_value)
@@ -134,7 +134,7 @@ class ChangeActivationQuantConfigAttr(BaseAction):
134
134
  """
135
135
  self.kwargs = kwargs
136
136
 
137
- def apply(self, node: BaseNode, graph, fw_info):
137
+ def apply(self, node: Node, graph, fw_info):
138
138
  """
139
139
  Change the attribute 'attr_name' in quant_config with 'attr_value'.
140
140
 
@@ -146,8 +146,9 @@ class ChangeActivationQuantConfigAttr(BaseAction):
146
146
  Returns:q
147
147
  The node after its quant_config has been modified.
148
148
  """
149
- for attr_name, attr_value in self.kwargs.items():
150
- node.activation_quantization_cfg.set_quant_config_attr(attr_name, attr_value)
149
+ if node.activation_quantization_cfg is not None:
150
+ for attr_name, attr_value in self.kwargs.items():
151
+ node.activation_quantization_cfg.set_quant_config_attr(attr_name, attr_value)
151
152
 
152
153
 
153
154
  class ChangeQuantizationParamFunction(BaseAction):
@@ -166,7 +167,7 @@ class ChangeQuantizationParamFunction(BaseAction):
166
167
  self.activation_quantization_params_fn = activation_quantization_params_fn
167
168
  self.weights_quantization_params_fn = weights_quantization_params_fn
168
169
 
169
- def apply(self, node: BaseNode, graph, fw_info):
170
+ def apply(self, node: Node, graph, fw_info):
170
171
  """
171
172
  Change the node's weights/activations quantization params function.
172
173
 
@@ -201,7 +202,7 @@ class ChangeActivationQuantizationMethod(BaseAction):
201
202
  """
202
203
  self.activation_quantization_method = activation_quantization_method
203
204
 
204
- def apply(self, node: BaseNode, graph, fw_info):
205
+ def apply(self, node: Node, graph, fw_info):
205
206
  """
206
207
  Change the node's activations quantization function.
207
208
 
@@ -216,12 +217,15 @@ class ChangeActivationQuantizationMethod(BaseAction):
216
217
  """
217
218
  if self.activation_quantization_method is not None:
218
219
 
220
+ out_stats_container = graph.get_out_stats_collector(node)[0] if isinstance(
221
+ graph.get_out_stats_collector(node),
222
+ list) else graph.get_out_stats_collector(
223
+ node)
224
+
219
225
  activation_quantization_params_fn = get_activation_quantization_params_fn(
220
226
  self.activation_quantization_method,
221
- node.activation_quantization_cfg.activation_threshold_method)
222
-
223
- if node.prior_info.is_output_bounded():
224
- activation_quantization_params_fn = quantization_params_generation.no_clipping_selection_min_max
227
+ node.activation_quantization_cfg.activation_threshold_method,
228
+ out_stats_container.use_min_max)
225
229
 
226
230
  node.activation_quantization_cfg.set_activation_quantization_params_fn(activation_quantization_params_fn)
227
231
  activation_quantization_fn = fw_info.activation_quantizer_mapping.get(self.activation_quantization_method)
@@ -248,7 +252,7 @@ class ChangeFinalWeightsQuantizationMethod(BaseAction):
248
252
 
249
253
  self.weights_quantization_method = weights_quantization_method
250
254
 
251
- def apply(self, node: BaseNode, graph, fw_info):
255
+ def apply(self, node: Node, graph, fw_info):
252
256
  """
253
257
  Change the node's weights quantization function.
254
258
 
@@ -292,7 +296,7 @@ class ChangeCandidtaesWeightsQuantizationMethod(BaseAction):
292
296
  """
293
297
  self.weights_quantization_method = weights_quantization_method
294
298
 
295
- def apply(self, node: BaseNode, graph, fw_info):
299
+ def apply(self, node: Node, graph, fw_info):
296
300
  """
297
301
  Change the node's weights quantization function.
298
302
 
@@ -35,13 +35,11 @@ from model_compression_toolkit.common.network_editors.actions import EditRule
35
35
  from model_compression_toolkit.common.network_editors.edit_network import edit_network_graph
36
36
  from model_compression_toolkit.common.mixed_precision.mixed_precision_quantization_config import \
37
37
  MixedPrecisionQuantizationConfig
38
- from model_compression_toolkit.common.quantization.quantization_params_fn_selection import \
39
- get_activation_quantization_params_fn
40
38
  from model_compression_toolkit.common.quantization.quantize_graph_weights import quantize_graph_weights
41
39
  from model_compression_toolkit.common.bias_correction.compute_bias_correction_of_graph import compute_bias_correction_of_graph
42
40
 
43
41
  from model_compression_toolkit.common.quantization.quantization_analyzer import analyzer_graph
44
- from model_compression_toolkit.common.quantization.quantization_config import DEFAULTCONFIG, ThresholdSelectionMethod
42
+ from model_compression_toolkit.common.quantization.quantization_config import DEFAULTCONFIG
45
43
  from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
46
44
  from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_computation import \
47
45
  calculate_quantization_params
@@ -364,21 +362,6 @@ def _prepare_model_for_quantization(in_model: Any,
364
362
  if tb_w is not None:
365
363
  tb_w.add_graph(transformed_graph, 'pre_statistics_collection_substitutions')
366
364
 
367
- #########################################
368
- # Set prior info to nodes
369
- ##########################################
370
- for node in transformed_graph.nodes:
371
- node.prior_info = fw_impl.get_node_prior_info(node=node,
372
- fw_info=fw_info)
373
-
374
-
375
- ######################################
376
- # Add quantization configurations
377
- ######################################
378
- transformed_graph = set_quantization_configuration_to_graph(graph=transformed_graph,
379
- quant_config=quant_config,
380
- fw_info=fw_info)
381
-
382
365
  ######################################
383
366
  # Graph marking points
384
367
  ######################################
@@ -398,7 +381,6 @@ def _prepare_model_for_quantization(in_model: Any,
398
381
  if tb_w is not None:
399
382
  tb_w.add_graph(transformed_graph, 'after_analyzer_graph')
400
383
 
401
-
402
384
  ######################################
403
385
  # Statistic collection
404
386
  ######################################
@@ -409,6 +391,12 @@ def _prepare_model_for_quantization(in_model: Any,
409
391
  for _ in tqdm(range(n_iter)):
410
392
  mi.infer(representative_data_gen())
411
393
 
394
+ ######################################
395
+ # Add quantization configurations
396
+ ######################################
397
+ transformed_graph = set_quantization_configuration_to_graph(transformed_graph,
398
+ quant_config,
399
+ fw_info)
412
400
 
413
401
  ######################################
414
402
  # Edit network according to user specific settings
@@ -469,4 +457,3 @@ def _prepare_model_for_quantization(in_model: Any,
469
457
  assert n.final_weights_quantization_cfg is None
470
458
 
471
459
  return tg_with_bias
472
-
@@ -62,7 +62,8 @@ class NodeActivationQuantizationConfig(BaseNodeNodeQuantizationConfig):
62
62
  def __init__(self,
63
63
  qc: QuantizationConfig,
64
64
  activation_quantization_fn: Callable,
65
- activation_quantization_params_fn: Callable
65
+ activation_quantization_params_fn: Callable,
66
+ activation_is_signed: bool = None
66
67
  ):
67
68
  """
68
69
 
@@ -70,10 +71,11 @@ class NodeActivationQuantizationConfig(BaseNodeNodeQuantizationConfig):
70
71
  qc: QuantizationConfig to create the node's config from.
71
72
  activation_quantization_fn: Function to use when quantizing the node's activations.
72
73
  activation_quantization_params_fn: Function to use when computing the threshold for quantizing a node's activations.
74
+ activation_is_signed: Signedness of the activation quantized range.
73
75
  """
74
-
75
76
  self.activation_quantization_fn = activation_quantization_fn
76
77
  self.activation_quantization_params_fn = activation_quantization_params_fn
78
+ self.activation_is_signed = activation_is_signed
77
79
  self.activation_quantization_params = {}
78
80
  self.activation_threshold_method = qc.activation_threshold_method
79
81
  self.activation_quantization_method = qc.activation_quantization_method
@@ -89,14 +91,6 @@ class NodeActivationQuantizationConfig(BaseNodeNodeQuantizationConfig):
89
91
  self.shift_negative_ratio = qc.shift_negative_ratio
90
92
  self.shift_negative_threshold_recalculation = qc.shift_negative_threshold_recalculation
91
93
 
92
- def generate_quantization_node(self) -> Callable:
93
- """
94
- Returns: Quantization function to use for quantizing the node's activations,
95
- with the node's quantization configuration properties.
96
- """
97
- return self.activation_quantization_fn(self.activation_n_bits,
98
- self.activation_quantization_params)
99
-
100
94
  def set_activation_quantization_fn(self, activation_quantization_fn: Callable):
101
95
  """
102
96
  Sets activation quantization function for the node.
@@ -126,7 +120,6 @@ class NodeActivationQuantizationConfig(BaseNodeNodeQuantizationConfig):
126
120
  activation_params: Dictionary that contains weight quantization params.
127
121
 
128
122
  """
129
- assert self.enable_activation_quantization
130
123
  for param_name, param_value in activation_params.items():
131
124
  self.activation_quantization_params[param_name] = param_value
132
125
 
@@ -205,7 +198,6 @@ class NodeWeightsQuantizationConfig(BaseNodeNodeQuantizationConfig):
205
198
  weights_params: Dictionary that contains weight quantization params.
206
199
 
207
200
  """
208
- assert self.enable_weights_quantization
209
201
  for param_name, param_value in weights_params.items():
210
202
  self.weights_quantization_params[param_name] = param_value
211
203
 
@@ -218,7 +210,7 @@ class NodeWeightsQuantizationConfig(BaseNodeNodeQuantizationConfig):
218
210
  Recalculated weights quantization params from the kernel and channel axis.
219
211
 
220
212
  """
221
- assert self.enable_weights_quantization
213
+
222
214
  if self.weights_quantization_params_fn is not None:
223
215
  self.set_weights_quantization_param(self.weights_quantization_params_fn(tensor_data,
224
216
  p=self.l_p_value,
@@ -21,20 +21,17 @@ from model_compression_toolkit import common
21
21
 
22
22
 
23
23
  def create_tensor2node(graph: common.Graph,
24
- node: common.BaseNode,
25
- fw_info: common.FrameworkInfo):
24
+ node: common.Node):
26
25
  """
27
26
  Force tensor creation and assignment for a node.
28
27
  Args:
29
28
  graph: Graph of the node (for retrieving the current tensor).
30
29
  node: Node to create a tensor for.
31
- fw_info: Specific framework information (for example, output channels index).
32
30
 
33
31
  """
34
32
  current_tensor = graph.get_out_stats_collector(node)
35
- is_list_nostat_collectors = isinstance(current_tensor, list) and len([sc for sc in current_tensor if not isinstance(sc, common.NoStatsCollector)]) == 0
36
- if isinstance(current_tensor, common.NoStatsCollector) or current_tensor is None or is_list_nostat_collectors:
37
- graph.set_out_stats_collector_to_node(node, common.StatsCollector(output_channel_index=fw_info.output_channel_index))
33
+ if isinstance(current_tensor, common.NoStatsContainer) or current_tensor is None:
34
+ graph.set_out_stats_collector_to_node(node, common.StatsContainer())
38
35
 
39
36
 
40
37
  def analyzer_graph(node_analyze_func: Callable,
@@ -56,7 +53,7 @@ def analyzer_graph(node_analyze_func: Callable,
56
53
  """
57
54
  nodes_sorted = topological_sort(graph)
58
55
  for n in nodes_sorted:
59
- sc = node_analyze_func(n, output_channel_index=fw_info.output_channel_index) # Get tensor for the node
56
+ t = node_analyze_func(n, fw_info) # Get tensor for the node
60
57
  # If we use bias correction, and the node has coefficients to quantize, we need to make sure
61
58
  # its previous nodes' tensors are consistent with this node.
62
59
  # TODO: factor tensor marking in case of bias correction.
@@ -64,7 +61,6 @@ def analyzer_graph(node_analyze_func: Callable,
64
61
  for ie in graph.incoming_edges(n):
65
62
  input_node = ie.source_node
66
63
  create_tensor2node(graph,
67
- input_node,
68
- fw_info)
69
- if sc is not None:
70
- graph.set_out_stats_collector_to_node(n, sc)
64
+ input_node)
65
+ if t is not None:
66
+ graph.set_out_stats_collector_to_node(n, t)
@@ -155,12 +155,12 @@ DEFAULTCONFIG = QuantizationConfig(ThresholdSelectionMethod.MSE,
155
155
  ThresholdSelectionMethod.MSE,
156
156
  QuantizationMethod.POWER_OF_TWO,
157
157
  QuantizationMethod.POWER_OF_TWO,
158
- weights_n_bits=8,
159
- activation_n_bits=8,
160
- relu_unbound_correction=False,
161
- weights_bias_correction=True,
162
- weights_per_channel_threshold=True,
163
- input_scaling=False)
158
+ 8,
159
+ 8,
160
+ False,
161
+ True,
162
+ True,
163
+ False)
164
164
 
165
165
 
166
166
 
@@ -23,7 +23,8 @@ from model_compression_toolkit.common.quantization.quantization_params_generatio
23
23
 
24
24
 
25
25
  def get_activation_quantization_params_fn(activation_quantization_method: QuantizationMethod,
26
- activation_threshold_method: ThresholdSelectionMethod) -> Callable:
26
+ activation_threshold_method: ThresholdSelectionMethod,
27
+ use_min_max: bool) -> Callable:
27
28
  """
28
29
  Generate a function for finding activation quantization threshold.
29
30
 
@@ -37,7 +38,7 @@ def get_activation_quantization_params_fn(activation_quantization_method: Quanti
37
38
  """
38
39
  if activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
39
40
  # Use min/max as the threshold if we use NOCLIPPING
40
- if activation_threshold_method == ThresholdSelectionMethod.NOCLIPPING:
41
+ if use_min_max or activation_threshold_method == ThresholdSelectionMethod.NOCLIPPING:
41
42
  params_fn = quantization_params_generation.no_clipping_selection_min_max
42
43
  # Use MSE to search_methods for the optimal threshold.
43
44
  elif activation_threshold_method == ThresholdSelectionMethod.MSE:
@@ -15,13 +15,12 @@
15
15
  import numpy as np
16
16
  from typing import Tuple, Dict
17
17
 
18
- from model_compression_toolkit.common import BaseNode, Graph
19
- from model_compression_toolkit.common.constants import SIGNED
18
+ from model_compression_toolkit.common import Node, Graph
20
19
  from model_compression_toolkit.common.quantization import quantization_params_generation
21
20
 
22
21
 
23
- def get_activations_qparams(n: BaseNode,
24
- graph: Graph) -> Dict[str, float]:
22
+ def get_activations_qparams(n: Node,
23
+ graph: Graph) -> Tuple[Dict[str, float], bool]:
25
24
  """
26
25
  Compute the activations params for a given node in a graph according to a params function.
27
26
 
@@ -30,29 +29,25 @@ def get_activations_qparams(n: BaseNode,
30
29
  graph: Graph the node is in.
31
30
 
32
31
  Returns:
33
- The computed activation quantization params.
32
+ Tuple of the computed quantization params and sign for the node's activations quantization.
34
33
  """
35
-
36
34
  out_stats_container = graph.get_out_stats_collector(n)
37
35
  bins_values, bins_counts = None, None
38
36
 
39
37
  # If the statistics container collected the histogram, we start by filtering outliers using z threshold
40
38
  # filtering, and then computing the threshold based on the filtered histogram.
41
- if out_stats_container.require_collection():
39
+ if out_stats_container.collect_histogram:
42
40
  bins_values, bins_counts = out_stats_container.hc.get_histogram()
43
41
  bins_counts = quantization_params_generation.z_score_filter(n.activation_quantization_cfg.z_threshold,
44
42
  bins_values,
45
43
  bins_counts)
46
44
  min_value, max_value = out_stats_container.get_min_max_values()
47
45
 
48
- if n.prior_info.is_output_bounded():
46
+ if out_stats_container.use_min_max:
49
47
  signed = min_value < 0
50
48
  else:
51
49
  signed = np.any(bins_values < 0)
52
50
 
53
- if n.prior_info.is_output_bounded():
54
- n.activation_quantization_cfg.activation_quantization_params_fn = quantization_params_generation.no_clipping_selection_min_max
55
-
56
51
  activation_params = n.activation_quantization_cfg.activation_quantization_params_fn(bins_values,
57
52
  bins_counts,
58
53
  n.activation_quantization_cfg.l_p_value,
@@ -60,6 +55,5 @@ def get_activations_qparams(n: BaseNode,
60
55
  min_value,
61
56
  max_value,
62
57
  min_threshold=n.activation_quantization_cfg.min_threshold)
63
- activation_params.update({SIGNED: signed})
64
58
 
65
- return activation_params
59
+ return activation_params, signed
@@ -16,7 +16,7 @@ from typing import List
16
16
 
17
17
  from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
18
18
  from model_compression_toolkit.common.framework_info import FrameworkInfo
19
- from model_compression_toolkit.common import Graph, BaseNode, Logger
19
+ from model_compression_toolkit.common import Graph, Node, Logger
20
20
  from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_activations_computation \
21
21
  import \
22
22
  get_activations_qparams
@@ -26,7 +26,7 @@ from model_compression_toolkit.common.quantization.quantization_params_generatio
26
26
 
27
27
  def calculate_quantization_params(graph: Graph,
28
28
  fw_info: FrameworkInfo,
29
- nodes: List[BaseNode] = [],
29
+ nodes: List[Node] = [],
30
30
  specific_nodes: bool = False,
31
31
  fw_impl: FrameworkImplementation = None):
32
32
  """
@@ -48,7 +48,7 @@ def calculate_quantization_params(graph: Graph,
48
48
  """
49
49
 
50
50
  # Create a list of nodes to compute their thresholds
51
- nodes_list: List[BaseNode] = nodes if specific_nodes else graph.nodes()
51
+ nodes_list: List[Node] = nodes if specific_nodes else graph.nodes()
52
52
 
53
53
  for n in nodes_list: # iterate only nodes that we should compute their thresholds
54
54
 
@@ -56,23 +56,25 @@ def calculate_quantization_params(graph: Graph,
56
56
  input_channels_axis, activation_threshold_float = {}, {}, None, None, None, None
57
57
 
58
58
  if fw_info.in_kernel_ops(n): # If the node has a kernel to quantize
59
- if n.is_weights_quantization_enabled():
60
- for candidtae_qc in n.candidates_weights_quantization_cfg:
61
- output_channels_axis, _ = get_channels_axis(candidtae_qc, fw_info, n.layer_class)
62
- weights_params = get_weights_qparams(n.get_weights_by_keys(fw_impl.constants.KERNEL),
63
- candidtae_qc,
64
- output_channels_axis)
65
59
 
66
- candidtae_qc.set_weights_quantization_param(weights_params)
67
- candidtae_qc.weights_channels_axis = output_channels_axis
60
+ for candidtae_qc in n.candidates_weights_quantization_cfg:
61
+ output_channels_axis, _ = get_channels_axis(candidtae_qc, fw_info, n.layer_class)
62
+ weights_params = get_weights_qparams(n.get_weights_by_keys(fw_impl.constants.KERNEL),
63
+ candidtae_qc,
64
+ output_channels_axis)
68
65
 
69
- if n.is_activation_quantization_enabled(): # If node's activations should be quantized as well, we compute its
66
+ candidtae_qc.set_weights_quantization_param(weights_params)
67
+ candidtae_qc.weights_channels_axis = output_channels_axis
68
+
69
+ if n.output_quantization: # If node's activations should be quantized as well, we compute its
70
70
  # activation threshold
71
- activation_params = get_activations_qparams(n=n, graph=graph)
71
+ activation_params, activation_is_signed = get_activations_qparams(n=n,
72
+ graph=graph)
72
73
 
73
74
  elif fw_info.in_activation_ops(n): # If node has no kernel, but its activations should be quantized
74
- if n.is_activation_quantization_enabled():
75
- activation_params = get_activations_qparams(n=n, graph=graph)
75
+ if n.output_quantization:
76
+ activation_params, activation_is_signed = get_activations_qparams(n=n,
77
+ graph=graph)
76
78
  # If node should not be quantized at all
77
79
  elif fw_info.in_no_quantization_ops(n):
78
80
  pass # pragma: no cover
@@ -82,5 +84,6 @@ def calculate_quantization_params(graph: Graph,
82
84
  Logger.warning(f"Warning: unknown layer: {n.layer_class.__name__}")
83
85
 
84
86
  # Create a NodeQuantizationConfig containing all quantization params and attach it to the node
85
- if n.is_activation_quantization_enabled():
86
- n.activation_quantization_cfg.set_activation_quantization_param(activation_params)
87
+ if n.activation_quantization_cfg is not None:
88
+ n.activation_quantization_cfg.set_activation_quantization_param(activation_params)
89
+ n.activation_quantization_cfg.activation_is_signed = activation_is_signed
@@ -19,7 +19,7 @@ import copy
19
19
  from model_compression_toolkit import common
20
20
  from model_compression_toolkit.common import Logger
21
21
  from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
22
- from model_compression_toolkit.common.graph.base_node import BaseNode
22
+ from model_compression_toolkit.common.graph.node import Node
23
23
  from model_compression_toolkit.common.framework_info import FrameworkInfo
24
24
  from model_compression_toolkit.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig
25
25
  from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_weights_computation import \
@@ -27,7 +27,7 @@ from model_compression_toolkit.common.quantization.quantization_params_generatio
27
27
 
28
28
 
29
29
  def get_quantized_kernel_by_weights_qc(fw_info:FrameworkInfo,
30
- n:BaseNode,
30
+ n:Node,
31
31
  weights_qc: NodeWeightsQuantizationConfig,
32
32
  fw_impl: FrameworkImplementation):
33
33
  """
@@ -17,7 +17,8 @@
17
17
  import copy
18
18
  from typing import List
19
19
 
20
- from model_compression_toolkit.common import Logger, BaseNode
20
+ from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
21
+ from model_compression_toolkit.common import Logger
21
22
  from model_compression_toolkit.common.framework_info import FrameworkInfo
22
23
  from model_compression_toolkit.common.graph.base_graph import Graph
23
24
  from model_compression_toolkit.common.mixed_precision.mixed_precision_quantization_config import \
@@ -46,45 +47,38 @@ def set_quantization_configuration_to_graph(graph: Graph,
46
47
  """
47
48
 
48
49
  graph_with_qcs = copy.deepcopy(graph)
50
+
49
51
  for n in graph_with_qcs.nodes:
50
- set_quantization_configs_to_node(node=n,
51
- quant_config=quant_config,
52
- fw_info=fw_info)
52
+ # Set qc only when needed
53
+ quantize_node_weights = False
54
+ quantize_node_activations = False
55
+
56
+ if fw_info.in_kernel_ops(n):
57
+ quantize_node_weights = True
58
+ quantize_node_activations = n.output_quantization
59
+ elif fw_info.in_activation_ops(n):
60
+ quantize_node_activations = True
61
+
62
+ if quantize_node_activations:
63
+ # Create activation QC for this node
64
+ out_sc = graph_with_qcs.get_out_stats_collector(n)
65
+ sc = out_sc[0] if isinstance(out_sc, list) else out_sc
66
+ use_min_max = sc.use_min_max
67
+ n.activation_quantization_cfg = create_node_activation_qc(quant_config,
68
+ fw_info,
69
+ use_min_max)
70
+ if quantize_node_weights:
71
+ # Create weights QC for this node
72
+ weight_channel_axis = fw_info.kernel_channels_mapping.get(n.layer_class)[0]
73
+ n.candidates_weights_quantization_cfg = _create_node_candidates_weights_qc(quant_config,
74
+ fw_info,
75
+ weight_channel_axis)
53
76
  return graph_with_qcs
54
77
 
55
78
 
56
- def set_quantization_configs_to_node(node: BaseNode,
57
- quant_config: QuantizationConfig,
58
- fw_info: FrameworkInfo):
59
- """
60
- Create and set quantization configurations to a node (for both weights and activation).
61
-
62
- Args:
63
- node: Node to set its quantization configurations.
64
- quant_config: Quantization configuration to generate the node's configurations from.
65
- fw_info: Information needed for quantization about the specific framework.
66
-
67
- """
68
- # Create activation QC for this node
69
- node.activation_quantization_cfg = create_node_activation_qc(quant_config,
70
- fw_info)
71
-
72
- enable_activation_quantization = quant_config.enable_activation_quantization and (fw_info.in_activation_ops(node) or fw_info.in_kernel_ops(node))
73
- node.activation_quantization_cfg.enable_activation_quantization = enable_activation_quantization
74
-
75
- # Create weights QC for this node
76
- weight_channel_axis = fw_info.kernel_channels_mapping.get(node.layer_class)[0]
77
- node.candidates_weights_quantization_cfg = _create_node_candidates_weights_qc(quant_config,
78
- fw_info,
79
- weight_channel_axis)
80
-
81
- enable_weights_quantization = quant_config.enable_weights_quantization and fw_info.in_kernel_ops(node)
82
- for qc in node.candidates_weights_quantization_cfg:
83
- qc.enable_weights_quantization = enable_weights_quantization
84
-
85
-
86
79
  def create_node_activation_qc(qc: QuantizationConfig,
87
- fw_info: FrameworkInfo) -> NodeActivationQuantizationConfig:
80
+ fw_info: FrameworkInfo,
81
+ use_min_max: bool) -> NodeActivationQuantizationConfig:
88
82
  """
89
83
  Create a activations quantization configuration from a QuantizationConfig object.
90
84
 
@@ -92,6 +86,7 @@ def create_node_activation_qc(qc: QuantizationConfig,
92
86
  qc: QuantizationConfig to create the node's config from.
93
87
  fw_info: Information about the specific framework the node was created from (e.g., whether or not its
94
88
  weights/activations should be quantized)
89
+ use_min_max: Whether the collected min/max statistics should be used when the threshold is computed or not.
95
90
 
96
91
  Returns:
97
92
  Activation quantization configuration of a node.
@@ -102,7 +97,8 @@ def create_node_activation_qc(qc: QuantizationConfig,
102
97
  Logger.critical('Unknown quantization method for activations')
103
98
 
104
99
  activation_quantization_params_fn = get_activation_quantization_params_fn(qc.activation_quantization_method,
105
- qc.activation_threshold_method)
100
+ qc.activation_threshold_method,
101
+ use_min_max)
106
102
 
107
103
  return NodeActivationQuantizationConfig(qc,
108
104
  activation_quantization_fn,
@@ -139,9 +135,10 @@ def create_node_weights_qc(qc: QuantizationConfig,
139
135
  weight_channel_axis)
140
136
 
141
137
 
138
+
142
139
  def _create_node_candidates_weights_qc(qc: QuantizationConfig,
143
- fw_info: FrameworkInfo,
144
- weight_channel_axis: int) -> List[NodeWeightsQuantizationConfig]:
140
+ fw_info: FrameworkInfo,
141
+ weight_channel_axis: int) -> List[NodeWeightsQuantizationConfig]:
145
142
  """
146
143
  Create a list of candidates of weights quantization configurations for a node.
147
144
 
@@ -164,4 +161,4 @@ def _create_node_candidates_weights_qc(qc: QuantizationConfig,
164
161
  else:
165
162
  candidats.append(create_node_weights_qc(qc, fw_info, weight_channel_axis))
166
163
 
167
- return candidats
164
+ return candidats