mct-nightly 1.1.0.7012022.post2611__py3-none-any.whl → 1.1.0.07122021-002414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/METADATA +3 -3
  2. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/RECORD +72 -76
  3. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/common/__init__.py +2 -2
  5. model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +2 -2
  6. model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +9 -9
  7. model_compression_toolkit/common/collectors/mean_collector.py +2 -3
  8. model_compression_toolkit/common/collectors/min_max_per_channel_collector.py +3 -6
  9. model_compression_toolkit/common/constants.py +0 -1
  10. model_compression_toolkit/common/framework_implementation.py +6 -22
  11. model_compression_toolkit/common/framework_info.py +7 -39
  12. model_compression_toolkit/common/graph/__init__.py +1 -1
  13. model_compression_toolkit/common/graph/base_graph.py +34 -34
  14. model_compression_toolkit/common/graph/edge.py +3 -3
  15. model_compression_toolkit/common/graph/graph_matchers.py +3 -3
  16. model_compression_toolkit/common/graph/graph_searches.py +4 -4
  17. model_compression_toolkit/common/graph/graph_vis.py +116 -0
  18. model_compression_toolkit/common/graph/{base_node.py → node.py} +27 -11
  19. model_compression_toolkit/common/mixed_precision/bit_width_setter.py +9 -10
  20. model_compression_toolkit/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  21. model_compression_toolkit/common/model_collector.py +12 -14
  22. model_compression_toolkit/common/network_editors/actions.py +23 -19
  23. model_compression_toolkit/common/post_training_quantization.py +7 -20
  24. model_compression_toolkit/common/quantization/node_quantization_config.py +5 -13
  25. model_compression_toolkit/common/quantization/quantization_analyzer.py +7 -11
  26. model_compression_toolkit/common/quantization/quantization_config.py +6 -6
  27. model_compression_toolkit/common/quantization/quantization_params_fn_selection.py +3 -2
  28. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_activations_computation.py +7 -13
  29. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +20 -17
  30. model_compression_toolkit/common/quantization/quantize_node.py +2 -2
  31. model_compression_toolkit/common/quantization/set_node_quantization_config.py +36 -39
  32. model_compression_toolkit/common/{collectors/statistics_collector.py → statistics_collector.py} +30 -26
  33. model_compression_toolkit/common/visualization/tensorboard_writer.py +8 -11
  34. model_compression_toolkit/keras/back2framework/instance_builder.py +4 -4
  35. model_compression_toolkit/keras/back2framework/model_builder.py +34 -47
  36. model_compression_toolkit/keras/constants.py +0 -3
  37. model_compression_toolkit/keras/default_framework_info.py +7 -33
  38. model_compression_toolkit/keras/gradient_ptq/graph_info.py +2 -2
  39. model_compression_toolkit/keras/gradient_ptq/graph_update.py +1 -7
  40. model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +1 -0
  41. model_compression_toolkit/keras/graph_substitutions/substitutions/activation_decomposition.py +8 -10
  42. model_compression_toolkit/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -2
  43. model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +2 -2
  44. model_compression_toolkit/keras/graph_substitutions/substitutions/mark_activation.py +3 -3
  45. model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +4 -3
  46. model_compression_toolkit/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +2 -2
  47. model_compression_toolkit/keras/graph_substitutions/substitutions/scale_equalization.py +9 -9
  48. model_compression_toolkit/keras/graph_substitutions/substitutions/separableconv_decomposition.py +19 -19
  49. model_compression_toolkit/keras/graph_substitutions/substitutions/shift_negative_activation.py +45 -64
  50. model_compression_toolkit/keras/keras_implementation.py +8 -28
  51. model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +2 -2
  52. model_compression_toolkit/keras/quantization_facade.py +1 -5
  53. model_compression_toolkit/keras/quantizer/fake_quant_builder.py +4 -4
  54. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer.py +2 -3
  55. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer_gptq_config.py +4 -8
  56. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_weight_quantizer_gptq_config.py +4 -9
  57. model_compression_toolkit/keras/quantizer/gradient_ptq/config_factory.py +10 -9
  58. model_compression_toolkit/keras/quantizer/gradient_ptq/weight_quantizer_gptq_config.py +1 -9
  59. model_compression_toolkit/keras/quantizer/mixed_precision/quantization_config_factory.py +1 -1
  60. model_compression_toolkit/keras/quantizer/mixed_precision/selective_weights_quantize_config.py +1 -6
  61. model_compression_toolkit/keras/reader/common.py +11 -9
  62. model_compression_toolkit/keras/reader/connectivity_handler.py +9 -15
  63. model_compression_toolkit/keras/reader/nested_model/edges_merger.py +6 -6
  64. model_compression_toolkit/keras/reader/nested_model/nested_model_handler.py +2 -2
  65. model_compression_toolkit/keras/reader/nested_model/nodes_merger.py +3 -3
  66. model_compression_toolkit/keras/reader/nested_model/outputs_merger.py +2 -2
  67. model_compression_toolkit/keras/reader/node_builder.py +15 -65
  68. model_compression_toolkit/keras/reader/reader.py +5 -5
  69. model_compression_toolkit/keras/tensor_marking.py +113 -0
  70. model_compression_toolkit/keras/visualization/nn_visualizer.py +2 -2
  71. model_compression_toolkit/common/collectors/statistics_collector_generator.py +0 -43
  72. model_compression_toolkit/common/graph/functional_node.py +0 -59
  73. model_compression_toolkit/common/model_validation.py +0 -43
  74. model_compression_toolkit/common/node_prior_info.py +0 -29
  75. model_compression_toolkit/keras/keras_model_validation.py +0 -38
  76. model_compression_toolkit/keras/keras_node_prior_info.py +0 -60
  77. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/LICENSE +0 -0
  78. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/top_level.txt +0 -0
@@ -17,7 +17,7 @@
17
17
  from tensorflow.keras.layers import ReLU
18
18
 
19
19
  from model_compression_toolkit import common
20
- from model_compression_toolkit.common import Graph, BaseNode
20
+ from model_compression_toolkit.common import Graph, Node
21
21
  from model_compression_toolkit.common.graph.graph_matchers import NodeOperationMatcher,NodeFrameworkAttrMatcher
22
22
  from model_compression_toolkit.keras.constants import RELU_MAX_VALUE
23
23
  from model_compression_toolkit.common.constants import THRESHOLD
@@ -40,7 +40,7 @@ class RemoveReLUUpperBound(common.BaseSubstitution):
40
40
 
41
41
  def substitute(self,
42
42
  graph: Graph,
43
- node: BaseNode) -> Graph:
43
+ node: Node) -> Graph:
44
44
  """
45
45
  Remove ReLU upper bound if its activation threshold bounds it anyway at
46
46
  the same value.
@@ -21,7 +21,7 @@ import numpy as np
21
21
  from tensorflow.keras.layers import DepthwiseConv2D, Conv2D, Dense, Conv2DTranspose, Activation, ReLU, ZeroPadding2D
22
22
 
23
23
  from model_compression_toolkit import common
24
- from model_compression_toolkit.common import Graph, BaseNode
24
+ from model_compression_toolkit.common import Graph, Node
25
25
  from model_compression_toolkit.common.constants import OUTPUT_SCALE, THRESHOLD
26
26
  from model_compression_toolkit.common.defaultdict import DefaultDict
27
27
  from model_compression_toolkit.common.framework_info import FrameworkInfo
@@ -59,7 +59,7 @@ MATCHER_MID_WITH_PAD = WalkMatcher([op2d_node, mid_activation_nodes, zeropad_nod
59
59
 
60
60
 
61
61
  def scale_reshaping(scale: np.ndarray,
62
- op2d: common.BaseNode,
62
+ op2d: common.Node,
63
63
  kernel_channel_mapping: DefaultDict,
64
64
  in_channels: bool = True) -> np.ndarray:
65
65
  """
@@ -87,8 +87,8 @@ def scale_reshaping(scale: np.ndarray,
87
87
  def update_linear_nodes(graph:Graph,
88
88
  qc: QuantizationConfig,
89
89
  fw_info: FrameworkInfo,
90
- first_op2d_node: BaseNode,
91
- second_op2d_node: BaseNode,
90
+ first_op2d_node: Node,
91
+ second_op2d_node: Node,
92
92
  scale_factor: np.ndarray):
93
93
  """
94
94
  Scale the weights of two linear nodes with a scale factor. Each node is scaled in
@@ -132,7 +132,7 @@ def update_linear_nodes(graph:Graph,
132
132
 
133
133
 
134
134
  def calculate_scale_correction(graph: Graph,
135
- activation_node: BaseNode,
135
+ activation_node: Node,
136
136
  eps: float = 1e-6) -> tuple:
137
137
  """
138
138
  Compute a scale factor by the activation node threshold and its outputs statistics in
@@ -172,9 +172,9 @@ def calculate_scale_correction(graph: Graph,
172
172
  def scale_equalization_lnl(graph: Graph,
173
173
  qc: QuantizationConfig,
174
174
  fw_info: FrameworkInfo,
175
- first_op2d_node: BaseNode,
176
- n_node: BaseNode,
177
- second_op2d_node: BaseNode):
175
+ first_op2d_node: Node,
176
+ n_node: Node,
177
+ second_op2d_node: Node):
178
178
  """
179
179
  Compute a scale factor to scale all activation node's outputs such that
180
180
  its maximum per-channel is the constrained threshold of the activation node.
@@ -235,7 +235,7 @@ class BaseScaleEqualization(common.BaseSubstitution):
235
235
 
236
236
  def substitute(self,
237
237
  graph: Graph,
238
- nodes_list: List[BaseNode]) -> Graph:
238
+ nodes_list: List[Node]) -> Graph:
239
239
  """
240
240
  Scale each channel of the weights of two linear nodes,
241
241
  in order to use the entire constrained range when activations are quantized.
@@ -19,7 +19,7 @@ from tensorflow.keras.layers import SeparableConv2D, Conv2D, DepthwiseConv2D
19
19
  from model_compression_toolkit import common
20
20
  from model_compression_toolkit.common.graph.base_graph import Graph
21
21
  from model_compression_toolkit.common.graph.graph_matchers import NodeOperationMatcher
22
- from model_compression_toolkit.common.graph.base_node import BaseNode
22
+ from model_compression_toolkit.common.graph.node import Node
23
23
  from model_compression_toolkit.keras.constants import KERNEL, DEPTHWISE_KERNEL, BIAS, KERNEL_SIZE, PADDING, \
24
24
  STRIDES, USE_BIAS, LINEAR, ACTIVATION, TRAINABLE, FILTERS, PAD_VALID
25
25
 
@@ -57,7 +57,7 @@ class SeparableConvDecomposition(common.BaseSubstitution):
57
57
 
58
58
  def substitute(self,
59
59
  graph: Graph,
60
- separable_node: BaseNode) -> Graph:
60
+ separable_node: Node) -> Graph:
61
61
  """
62
62
  Remove a SeparableConv2D node from the graph, and replace it with two equivalent nodes: DepthwiseConv2D
63
63
  and Conv2D. The SeparableConv2D attributes are split to relevant attributes for each node.
@@ -114,28 +114,28 @@ class SeparableConvDecomposition(common.BaseSubstitution):
114
114
  dw_node_name = separable_node.name + '_dw' if not separable_node.reuse else '_'.join(separable_node.name.split('_')[:-2]) + '_dw_' + '_'.join(separable_node.name.split('_')[-2:])
115
115
 
116
116
  # create new nodes
117
- dw_node = common.graph.BaseNode(dw_node_name,
118
- dw_framework_attr,
119
- separable_node.input_shape,
120
- dw_output_shape,
121
- dw_weights_dict,
122
- dw_layer_class,
123
- reuse=separable_node.reuse,
124
- reuse_group=separable_node.reuse_group)
117
+ dw_node = common.graph.Node(dw_node_name,
118
+ dw_framework_attr,
119
+ separable_node.input_shape,
120
+ dw_output_shape,
121
+ dw_weights_dict,
122
+ dw_layer_class,
123
+ reuse=separable_node.reuse,
124
+ reuse_group=separable_node.reuse_group)
125
125
 
126
126
  # If the SeparableConv2D is reused, we need to keep the pointwise node as reused as well,
127
127
  # so we keep the names convention with adding the suffix of "_reuse_X".
128
128
  pw_node_name = separable_node.name + '_pw' if not separable_node.reuse else '_'.join(separable_node.name.split('_')[:-2]) + '_pw_' + '_'.join(separable_node.name.split('_')[-2:])
129
129
 
130
- pw_node = common.graph.BaseNode(pw_node_name,
131
- pw_framework_attr,
132
- pw_input_shape,
133
- separable_node.output_shape,
134
- pw_weights_dict,
135
- pw_layer_class,
136
- reuse=separable_node.reuse,
137
- reuse_group=separable_node.reuse_group
138
- )
130
+ pw_node = common.graph.Node(pw_node_name,
131
+ pw_framework_attr,
132
+ pw_input_shape,
133
+ separable_node.output_shape,
134
+ pw_weights_dict,
135
+ pw_layer_class,
136
+ reuse=separable_node.reuse,
137
+ reuse_group=separable_node.reuse_group
138
+ )
139
139
 
140
140
  graph.add_node(dw_node)
141
141
  graph.add_node(pw_node)
@@ -16,8 +16,6 @@
16
16
  import tensorflow as tf
17
17
 
18
18
  # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
19
- from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
20
-
21
19
  if tf.__version__ < "2.6":
22
20
  from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
23
21
  else:
@@ -31,14 +29,13 @@ from tensorflow.keras.layers import Activation, Conv2D, Dense, DepthwiseConv2D,
31
29
  from typing import Tuple, Any
32
30
 
33
31
  from model_compression_toolkit import common
34
- from model_compression_toolkit.common import FrameworkInfo, Graph, BaseNode
35
- from model_compression_toolkit.common.constants import FLOAT_32, DATA_TYPE, THRESHOLD, SIGNED
32
+ from model_compression_toolkit.common import FrameworkInfo, Graph, Node
33
+ from model_compression_toolkit.common.constants import FLOAT_32, DATA_TYPE, THRESHOLD
36
34
  from model_compression_toolkit.common.graph.graph_matchers import EdgeMatcher
37
35
  from model_compression_toolkit.common.graph.graph_matchers import NodeOperationMatcher, \
38
36
  NodeFrameworkAttrMatcher
39
37
 
40
- from model_compression_toolkit.common.quantization.set_node_quantization_config import create_node_activation_qc, \
41
- set_quantization_configs_to_node
38
+ from model_compression_toolkit.common.quantization.set_node_quantization_config import create_node_activation_qc
42
39
  from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
43
40
  from model_compression_toolkit.common.quantization.quantization_params_generation.qparams_activations_computation \
44
41
  import \
@@ -103,7 +100,7 @@ PAD_NODE = NodeOperationMatcher(ZeroPadding2D)
103
100
 
104
101
  def create_add_node(add_value: float,
105
102
  prev_node_name: str,
106
- input_shape: tuple) -> BaseNode:
103
+ input_shape: tuple) -> Node:
107
104
  """
108
105
  Create a new Add node, with a constant to add.
109
106
  The name of the node is determined by its previous node's name.
@@ -133,13 +130,13 @@ def create_add_node(add_value: float,
133
130
  CONSTANTS: {1: np.array([[[[add_value]]]],
134
131
  dtype=np.float32)}}
135
132
 
136
- add_node = common.graph.BaseNode(add_node_name,
137
- fw_attr,
138
- input_shape,
139
- input_shape,
140
- weights={},
141
- quantization_attr={},
142
- layer_class=TensorFlowOpLayer)
133
+ add_node = common.graph.Node(add_node_name,
134
+ fw_attr,
135
+ input_shape,
136
+ input_shape,
137
+ weights={},
138
+ quantization_attr={},
139
+ layer_class=TensorFlowOpLayer)
143
140
  return add_node
144
141
 
145
142
 
@@ -150,7 +147,7 @@ def create_pad_node(next_node_name: str,
150
147
  pad_top: int,
151
148
  pad_btm: int,
152
149
  pad_left: int,
153
- pad_right: int) -> BaseNode:
150
+ pad_right: int) -> Node:
154
151
  """
155
152
  Create a pad node with a constant value to pad its input tensor.
156
153
 
@@ -192,17 +189,17 @@ def create_pad_node(next_node_name: str,
192
189
  padded_shape = list(input_shape)
193
190
  padded_shape[1] += pad_top + pad_btm
194
191
  padded_shape[2] += pad_left + pad_right
195
- pad_node = common.graph.BaseNode(pad_node_name,
196
- fw_attr,
197
- input_shape,
198
- tuple(padded_shape),
199
- weights={},
200
- quantization_attr={},
201
- layer_class=TensorFlowOpLayer)
192
+ pad_node = common.graph.Node(pad_node_name,
193
+ fw_attr,
194
+ input_shape,
195
+ tuple(padded_shape),
196
+ weights={},
197
+ quantization_attr={},
198
+ layer_class=TensorFlowOpLayer)
202
199
  return pad_node
203
200
 
204
201
 
205
- def compute_op2d_padding(op2d_node: BaseNode) -> Tuple[int, int, int, int]:
202
+ def compute_op2d_padding(op2d_node: Node) -> Tuple[int, int, int, int]:
206
203
  """
207
204
  Compute the padding around an input tensor of a linear node.
208
205
  This is needed to replace tensorflow 'same' padding with actual number of elements to pad.
@@ -231,7 +228,7 @@ def compute_op2d_padding(op2d_node: BaseNode) -> Tuple[int, int, int, int]:
231
228
  return pad_top, pad_btm, pad_left, pad_right
232
229
 
233
230
 
234
- def op2d_bias_correction(op2d_node: common.BaseNode,
231
+ def op2d_bias_correction(op2d_node: common.Node,
235
232
  shift_to_correct: float):
236
233
  """
237
234
  Compute the correction term to add to the op2d node's bias
@@ -269,9 +266,9 @@ def op2d_bias_correction(op2d_node: common.BaseNode,
269
266
 
270
267
 
271
268
  def insert_node_between_two_nodes(graph: Graph,
272
- node_to_insert: BaseNode,
273
- first_node: BaseNode,
274
- last_node: BaseNode):
269
+ node_to_insert: Node,
270
+ first_node: Node,
271
+ last_node: Node):
275
272
  """
276
273
  Insert a new node in a graph between two nodes.
277
274
 
@@ -293,8 +290,8 @@ def insert_node_between_two_nodes(graph: Graph,
293
290
 
294
291
 
295
292
  def insert_node_after_node(graph: Graph,
296
- node_to_insert: BaseNode,
297
- first_node: BaseNode):
293
+ node_to_insert: Node,
294
+ first_node: Node):
298
295
  """
299
296
  Insert a new node to a graph after an existing node in the graph.
300
297
  Check before insertion that the node (that we add the new node after) has
@@ -316,8 +313,8 @@ def insert_node_after_node(graph: Graph,
316
313
 
317
314
 
318
315
  def insert_node_before_node(graph: Graph,
319
- node_to_insert: BaseNode,
320
- last_node: BaseNode):
316
+ node_to_insert: Node,
317
+ last_node: Node):
321
318
  """
322
319
  Insert a new node to a graph before an existing node in the graph.
323
320
  Check before insertion that the node (that we add the new node before) has
@@ -338,9 +335,9 @@ def insert_node_before_node(graph: Graph,
338
335
 
339
336
 
340
337
  def remove_node_between_two_nodes(graph: Graph,
341
- node_to_remove: BaseNode,
342
- first_node: BaseNode,
343
- last_node: BaseNode):
338
+ node_to_remove: Node,
339
+ first_node: Node,
340
+ last_node: Node):
344
341
  """
345
342
  Remove a node from a graph and connect its previous node to
346
343
  its next node after the removal.
@@ -363,12 +360,12 @@ def remove_node_between_two_nodes(graph: Graph,
363
360
  graph.remove_node(node_to_remove)
364
361
 
365
362
 
366
- def shift_negative_function(graph: Graph,
367
- qc: QuantizationConfig,
368
- non_linear_node: BaseNode,
369
- op2d_node: BaseNode,
363
+ def shift_negative_function(graph,
364
+ qc,
365
+ non_linear_node,
366
+ op2d_node,
370
367
  fw_info: FrameworkInfo,
371
- zero_padding_node: BaseNode = None) -> Graph:
368
+ zero_padding_node=None):
372
369
  """
373
370
  Shift the output of a non-linear activation by its minimal output value (quantized) such
374
371
  that all values after the shifting are positive.
@@ -457,15 +454,6 @@ def shift_negative_function(graph: Graph,
457
454
  add_node.output_shape,
458
455
  pad_top, pad_btm, pad_left, pad_right)
459
456
 
460
- # Set quantization configuration to node, even though we do not quantize it:
461
- set_quantization_configs_to_node(fw_info=fw_info,
462
- node=pad_node,
463
- quant_config=qc)
464
-
465
- pad_node.activation_quantization_cfg.enable_activation_quantization = False
466
- for weight_qc in pad_node.candidates_weights_quantization_cfg:
467
- weight_qc.enable_weights_quantization = False
468
-
469
457
  # Insert a pad node between the add node to the op2d, and create statistics for the pad node
470
458
  insert_node_before_node(graph,
471
459
  node_to_insert=pad_node,
@@ -476,32 +464,25 @@ def shift_negative_function(graph: Graph,
476
464
 
477
465
  op2d_node.input_shape = pad_node.output_shape
478
466
 
479
- set_quantization_configs_to_node(fw_info=fw_info,
480
- node=add_node,
481
- quant_config=qc)
482
-
483
- add_node.activation_quantization_cfg.enable_activation_quantization = False
484
-
485
- for weight_qc in add_node.candidates_weights_quantization_cfg:
486
- weight_qc.enable_weights_quantization = False
487
-
488
467
  add_node.activation_quantization_cfg = create_node_activation_qc(qc,
489
- fw_info)
468
+ fw_info,
469
+ add_node_stats_collector.use_min_max)
490
470
 
491
- add_node.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold,
492
- SIGNED: False})
471
+ add_node.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold})
472
+ add_node.activation_quantization_cfg.activation_is_signed = False
493
473
 
494
474
  if non_linear_node.activation_quantization_cfg.shift_negative_threshold_recalculation:
495
- activation_param = get_activations_qparams(add_node, graph)
496
- assert activation_param.get(SIGNED) == False
475
+ activation_param, activation_is_signed = get_activations_qparams(add_node, graph)
476
+ assert activation_is_signed == False
497
477
  add_node.activation_quantization_cfg.set_activation_quantization_param(activation_param)
478
+ add_node.activation_quantization_cfg.activation_is_signed = False
498
479
 
499
480
  return graph
500
481
 
501
482
 
502
- def get_next_nodes_to_correct(n: BaseNode,
483
+ def get_next_nodes_to_correct(n: Node,
503
484
  graph: Graph,
504
- pad_node_to_consider: BaseNode = None) -> Tuple[Any, Any]:
485
+ pad_node_to_consider: Node = None) -> Tuple[Any, Any]:
505
486
  """
506
487
  Search for the next linear node of a given node. Go over
507
488
  the next nodes of the node and recursively search for a linear node.
@@ -6,11 +6,9 @@ from tensorflow.keras.models import Model
6
6
 
7
7
  from model_compression_toolkit import QuantizationConfig, FrameworkInfo, common, GradientPTQConfig, \
8
8
  MixedPrecisionQuantizationConfig
9
- from model_compression_toolkit.common import Graph, BaseNode
10
- from model_compression_toolkit.common.collectors.statistics_collector import BaseStatsCollector
9
+ from model_compression_toolkit.common import Graph, Node
11
10
  from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
12
11
  from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
13
- from model_compression_toolkit.common.node_prior_info import NodePriorInfo
14
12
  from model_compression_toolkit.common.user_info import UserInformation
15
13
  from model_compression_toolkit.keras.back2framework.model_builder import model_builder
16
14
  from model_compression_toolkit.keras.default_framework_info import DEFAULT_KERAS_INFO
@@ -32,10 +30,9 @@ from model_compression_toolkit.keras.graph_substitutions.substitutions.separable
32
30
  SeparableConvDecomposition
33
31
  from model_compression_toolkit.keras.graph_substitutions.substitutions.shift_negative_activation import \
34
32
  apply_shift_negative_correction
35
- from model_compression_toolkit.keras.keras_node_prior_info import create_node_prior_info
36
33
  from model_compression_toolkit.keras.mixed_precision.sensitivity_evaluation import get_sensitivity_evaluation
37
34
  from model_compression_toolkit.keras.reader.reader import model_reader
38
- from model_compression_toolkit.common.collectors.statistics_collector_generator import create_stats_collector_for_node
35
+ from model_compression_toolkit.keras.tensor_marking import get_node_stats_collector
39
36
  import model_compression_toolkit.keras.constants as keras_constants
40
37
 
41
38
 
@@ -121,22 +118,21 @@ class KerasImplementation(FrameworkImplementation):
121
118
  qc,
122
119
  fw_info)
123
120
 
124
- def attach_sc_to_node(self,
125
- node: BaseNode,
126
- output_channel_index: int) -> BaseStatsCollector:
121
+ def attach_sc_to_node(self, node: Node,
122
+ fw_info: FrameworkInfo) -> common.statistics_collector.BaseStatsContainer:
127
123
  """
128
124
  Return a statistics collector that should be attached to a node's output
129
125
  during statistics collection.
130
126
 
131
127
  Args:
132
128
  node: Node to return its collector.
133
- output_channel_index: Index of output channels of layers in the model's framework.
129
+ fw_info: FrameworkInfo object with information about the specific framework's model
134
130
 
135
131
  Returns:
136
132
  Statistics collector for the node.
137
133
  """
138
- return create_stats_collector_for_node(node,
139
- output_channel_index=output_channel_index)
134
+ return get_node_stats_collector(node,
135
+ fw_info)
140
136
 
141
137
  def get_substitutions_marking(self) -> List[common.BaseSubstitution]:
142
138
  """
@@ -231,6 +227,7 @@ class KerasImplementation(FrameworkImplementation):
231
227
  gptq_config,
232
228
  fw_info)
233
229
 
230
+
234
231
  def get_sensitivity_evaluation_fn(self,
235
232
  graph: Graph,
236
233
  quant_config: MixedPrecisionQuantizationConfig,
@@ -257,20 +254,3 @@ class KerasImplementation(FrameworkImplementation):
257
254
  metrics_weights,
258
255
  representative_data_gen,
259
256
  fw_info)
260
-
261
- def get_node_prior_info(self,
262
- node: BaseNode,
263
- fw_info: FrameworkInfo) -> NodePriorInfo:
264
- """
265
- Get a NodePriorInfo object for a node that represents a Keras layer.
266
-
267
- Args:
268
- node: Node to get its prior info.
269
- fw_info: Framework specific information needed to create the prior info of the node.
270
-
271
- Returns:
272
- NodePriorInfo with information about the node.
273
- """
274
-
275
- return create_node_prior_info(node=node,
276
- fw_info=fw_info)
@@ -20,7 +20,7 @@ from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapp
20
20
  from typing import Callable, List, Any
21
21
 
22
22
  from model_compression_toolkit.common.framework_info import FrameworkInfo
23
- from model_compression_toolkit.common import BaseNode
23
+ from model_compression_toolkit.common import Node
24
24
  from model_compression_toolkit.common.graph.base_graph import Graph
25
25
  from model_compression_toolkit.common.mixed_precision.mixed_precision_quantization_config import \
26
26
  MixedPrecisionQuantizationConfig
@@ -207,7 +207,7 @@ def _build_distance_matrix(baseline_tensors: List[Tensor],
207
207
 
208
208
 
209
209
  def _build_baseline_model(graph: Graph,
210
- interest_points: List[BaseNode]) -> Model:
210
+ interest_points: List[Node]) -> Model:
211
211
  """
212
212
  Build a Keras baseline model to compare inferences of the MP model to.
213
213
  The baseline model is the float model we build from the graph.
@@ -33,7 +33,6 @@ if importlib.util.find_spec("tensorflow") is not None\
33
33
  and importlib.util.find_spec("tensorflow_model_optimization") is not None:
34
34
  from model_compression_toolkit.keras.default_framework_info import DEFAULT_KERAS_INFO
35
35
  from model_compression_toolkit.keras.keras_implementation import KerasImplementation
36
- from model_compression_toolkit.keras.keras_model_validation import KerasModelValidation
37
36
  from tensorflow.keras.models import Model
38
37
 
39
38
  def keras_post_training_quantization(in_model: Model,
@@ -85,8 +84,7 @@ if importlib.util.find_spec("tensorflow") is not None\
85
84
  >>> quantized_model, quantization_info = mct.keras_post_training_quantization(model, repr_datagen)
86
85
 
87
86
  """
88
- KerasModelValidation(model=in_model,
89
- fw_info=fw_info).validate()
87
+
90
88
  return post_training_quantization(in_model,
91
89
  representative_data_gen,
92
90
  n_iter,
@@ -168,8 +166,6 @@ if importlib.util.find_spec("tensorflow") is not None\
168
166
  For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
169
167
 
170
168
  """
171
- KerasModelValidation(model=in_model,
172
- fw_info=fw_info).validate()
173
169
 
174
170
  if target_kpi is None:
175
171
  common.Logger.warning("No KPI was passed. Using non mixed-precision compression process...")
@@ -21,7 +21,7 @@ import tensorflow as tf
21
21
  import numpy as np
22
22
  from tensorflow.python.util.object_identity import Reference as TFReference
23
23
 
24
- from model_compression_toolkit.common.constants import THRESHOLD, SIGNED
24
+ from model_compression_toolkit.common.constants import THRESHOLD
25
25
 
26
26
 
27
27
  def quantizer_min_max_calculator(threshold: np.ndarray,
@@ -52,6 +52,7 @@ def quantizer_min_max_calculator(threshold: np.ndarray,
52
52
 
53
53
 
54
54
  def constraint_quantization(activation_n_bits: int,
55
+ activation_is_signed: bool,
55
56
  quantization_params: dict) -> Callable:
56
57
  """
57
58
  Use a NodeQuantizationConfig to compute a quantizer min/max values, and use it to
@@ -59,15 +60,14 @@ def constraint_quantization(activation_n_bits: int,
59
60
 
60
61
  Args:
61
62
  activation_n_bits: Number of bits to use for quantization.
63
+ activation_is_signed: Whether the quantization range should include negative values or not.
62
64
  quantization_params: Dictionary of specific parameters for this quantization function.
63
65
 
64
66
  Returns:
65
67
  A fake quantization node.
66
68
  """
67
69
  activation_threshold = quantization_params.get(THRESHOLD)
68
- activation_is_signed = quantization_params.get(SIGNED)
69
-
70
- if activation_threshold is None or activation_is_signed is None:
70
+ if activation_threshold is None:
71
71
  return None
72
72
 
73
73
  min_value, max_value = quantizer_min_max_calculator(activation_threshold,
@@ -23,7 +23,7 @@ from model_compression_toolkit.keras.quantizer.base_quantizer import BaseTrainab
23
23
  from model_compression_toolkit.keras.quantizer.gradient_ptq.utils import symmetric_quantizer
24
24
  from model_compression_toolkit.keras.quantizer.gradient_ptq.utils import ste_round
25
25
  from model_compression_toolkit import common
26
- from model_compression_toolkit.common.constants import THRESHOLD, SIGNED
26
+ from model_compression_toolkit.common.constants import THRESHOLD
27
27
 
28
28
 
29
29
  class TrainableQuantizer(BaseTrainableQuantizer):
@@ -171,8 +171,7 @@ class TrainableQuantizer(BaseTrainableQuantizer):
171
171
 
172
172
  threshold_change = np.asarray(new_threshold / old_threshold).flatten()
173
173
  common.Logger.info(f"Layer '{layer.layer.name}' has total threshold change of {str(threshold_change)}")
174
- return {THRESHOLD: new_threshold.numpy().reshape(self.threshold_shape),
175
- SIGNED: self.signed}
174
+ return {THRESHOLD: new_threshold.numpy().reshape(self.threshold_shape)}
176
175
 
177
176
  def get_trainable_parameters(self):
178
177
  """
@@ -17,15 +17,10 @@ from typing import List, Any, Dict
17
17
 
18
18
  from tensorflow.python.training.tracking.data_structures import ListWrapper
19
19
 
20
- from model_compression_toolkit.common.constants import THRESHOLD, SIGNED
20
+ from model_compression_toolkit.common.constants import THRESHOLD
21
21
  from model_compression_toolkit.keras.quantizer.gradient_ptq.activation_quantizer import TrainableQuantizer
22
22
  from model_compression_toolkit.keras.quantizer.gradient_ptq.base_quantizer_gptq_config import BaseQuantizeConfig
23
- import tensorflow as tf
24
- # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
25
- if tf.__version__ < "2.6":
26
- from tensorflow.python.keras.layers import Layer
27
- else:
28
- from keras.engine.base_layer import Layer
23
+ from tensorflow.python.keras.layers import Layer
29
24
  from tensorflow import Tensor
30
25
  from tensorflow_model_optimization.python.core.quantization.keras.quantizers import Quantizer
31
26
 
@@ -37,16 +32,17 @@ class ActivationQuantizeConfig(BaseQuantizeConfig):
37
32
 
38
33
  def __init__(self,
39
34
  activation_quantization_params: dict,
35
+ signed: bool,
40
36
  num_bits: int = 8):
41
37
  """
42
38
  Initialize a TrainableQuantizer and set as the activation quantizer.
43
39
 
44
40
  Args:
45
41
  activation_quantization_params: Parameters to use for quantization.
42
+ signed: Quantization range is signed or unsigned.
46
43
  num_bits: Number of bits to use for quantization.
47
44
  """
48
45
  threshold_values = activation_quantization_params.get(THRESHOLD)
49
- signed = activation_quantization_params.get(SIGNED)
50
46
  self.activation_quantizer = TrainableQuantizer(num_bits=num_bits,
51
47
  per_axis=False,
52
48
  threshold_values=threshold_values,
@@ -16,16 +16,11 @@
16
16
  from typing import List, Tuple, Any, Dict
17
17
 
18
18
  from tensorflow import Tensor
19
- import tensorflow as tf
20
- # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
21
- if tf.__version__ < "2.6":
22
- from tensorflow.python.keras.layers import Layer
23
- else:
24
- from keras.engine.base_layer import Layer
19
+ from tensorflow.python.keras.layers import Layer
25
20
  from tensorflow.python.training.tracking.data_structures import ListWrapper
26
21
  from tensorflow_model_optimization.python.core.quantization.keras.quantizers import Quantizer
27
22
 
28
- from model_compression_toolkit.common.constants import THRESHOLD, SIGNED
23
+ from model_compression_toolkit.common.constants import THRESHOLD
29
24
  from model_compression_toolkit.keras.quantizer.gradient_ptq.weight_quantizer import TrainableWeightQuantizer
30
25
  from model_compression_toolkit.keras.quantizer.gradient_ptq.activation_quantizer import TrainableQuantizer
31
26
  from model_compression_toolkit.keras.quantizer.gradient_ptq.base_quantizer_gptq_config import BaseQuantizeConfig
@@ -44,6 +39,7 @@ class ActivationAndWeightQuantizeConfig(BaseQuantizeConfig):
44
39
  weight_channel_axis: int,
45
40
  weight_num_bits: int,
46
41
  activation_quantization_params: dict,
42
+ activation_signed: bool,
47
43
  activation_num_bits: int = 8,
48
44
  max_lsbs_change: int = 8):
49
45
  """
@@ -54,13 +50,12 @@ class ActivationAndWeightQuantizeConfig(BaseQuantizeConfig):
54
50
  weight_channel_axis: Channel index to quantize when quantizing the weight per-channel.
55
51
  weight_num_bits: Number of bits to use for weight quantization.
56
52
  activation_quantization_params: Parameters to use for the activation quantization.
53
+ activation_signed: Quantization range is signed or unsigned.
57
54
  activation_num_bits: Number of bits to use for quantization of the activation.
58
55
 
59
56
  """
60
57
 
61
58
  activation_threshold_values = activation_quantization_params.get(THRESHOLD)
62
- activation_signed = activation_quantization_params.get(SIGNED)
63
-
64
59
  self.activation_quantizer = TrainableQuantizer(num_bits=activation_num_bits,
65
60
  per_axis=False,
66
61
  threshold_values=activation_threshold_values,