mct-nightly 1.1.0.6012022.post2521__py3-none-any.whl → 1.1.0.07122021-002414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. {mct_nightly-1.1.0.6012022.post2521.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/METADATA +3 -3
  2. {mct_nightly-1.1.0.6012022.post2521.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/RECORD +72 -76
  3. {mct_nightly-1.1.0.6012022.post2521.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/common/__init__.py +2 -2
  5. model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +2 -2
  6. model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +9 -9
  7. model_compression_toolkit/common/collectors/mean_collector.py +2 -3
  8. model_compression_toolkit/common/collectors/min_max_per_channel_collector.py +3 -6
  9. model_compression_toolkit/common/constants.py +0 -1
  10. model_compression_toolkit/common/framework_implementation.py +6 -22
  11. model_compression_toolkit/common/framework_info.py +7 -39
  12. model_compression_toolkit/common/graph/__init__.py +1 -1
  13. model_compression_toolkit/common/graph/base_graph.py +34 -34
  14. model_compression_toolkit/common/graph/edge.py +3 -3
  15. model_compression_toolkit/common/graph/graph_matchers.py +3 -3
  16. model_compression_toolkit/common/graph/graph_searches.py +4 -4
  17. model_compression_toolkit/common/graph/graph_vis.py +116 -0
  18. model_compression_toolkit/common/graph/{base_node.py → node.py} +27 -11
  19. model_compression_toolkit/common/mixed_precision/bit_width_setter.py +9 -10
  20. model_compression_toolkit/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  21. model_compression_toolkit/common/model_collector.py +12 -14
  22. model_compression_toolkit/common/network_editors/actions.py +23 -19
  23. model_compression_toolkit/common/post_training_quantization.py +7 -20
  24. model_compression_toolkit/common/quantization/node_quantization_config.py +5 -13
  25. model_compression_toolkit/common/quantization/quantization_analyzer.py +7 -11
  26. model_compression_toolkit/common/quantization/quantization_config.py +6 -6
  27. model_compression_toolkit/common/quantization/quantization_params_fn_selection.py +3 -2
  28. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_activations_computation.py +7 -13
  29. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +20 -17
  30. model_compression_toolkit/common/quantization/quantize_node.py +2 -2
  31. model_compression_toolkit/common/quantization/set_node_quantization_config.py +36 -39
  32. model_compression_toolkit/common/{collectors/statistics_collector.py → statistics_collector.py} +30 -26
  33. model_compression_toolkit/common/visualization/tensorboard_writer.py +8 -11
  34. model_compression_toolkit/keras/back2framework/instance_builder.py +4 -4
  35. model_compression_toolkit/keras/back2framework/model_builder.py +34 -47
  36. model_compression_toolkit/keras/constants.py +0 -3
  37. model_compression_toolkit/keras/default_framework_info.py +7 -33
  38. model_compression_toolkit/keras/gradient_ptq/graph_info.py +2 -2
  39. model_compression_toolkit/keras/gradient_ptq/graph_update.py +1 -7
  40. model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +1 -0
  41. model_compression_toolkit/keras/graph_substitutions/substitutions/activation_decomposition.py +8 -10
  42. model_compression_toolkit/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -2
  43. model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +2 -2
  44. model_compression_toolkit/keras/graph_substitutions/substitutions/mark_activation.py +3 -3
  45. model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +4 -3
  46. model_compression_toolkit/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +2 -2
  47. model_compression_toolkit/keras/graph_substitutions/substitutions/scale_equalization.py +9 -9
  48. model_compression_toolkit/keras/graph_substitutions/substitutions/separableconv_decomposition.py +19 -19
  49. model_compression_toolkit/keras/graph_substitutions/substitutions/shift_negative_activation.py +45 -64
  50. model_compression_toolkit/keras/keras_implementation.py +8 -28
  51. model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +2 -2
  52. model_compression_toolkit/keras/quantization_facade.py +1 -5
  53. model_compression_toolkit/keras/quantizer/fake_quant_builder.py +4 -4
  54. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer.py +2 -3
  55. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer_gptq_config.py +4 -8
  56. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_weight_quantizer_gptq_config.py +4 -9
  57. model_compression_toolkit/keras/quantizer/gradient_ptq/config_factory.py +10 -9
  58. model_compression_toolkit/keras/quantizer/gradient_ptq/weight_quantizer_gptq_config.py +1 -9
  59. model_compression_toolkit/keras/quantizer/mixed_precision/quantization_config_factory.py +1 -1
  60. model_compression_toolkit/keras/quantizer/mixed_precision/selective_weights_quantize_config.py +1 -6
  61. model_compression_toolkit/keras/reader/common.py +11 -9
  62. model_compression_toolkit/keras/reader/connectivity_handler.py +9 -15
  63. model_compression_toolkit/keras/reader/nested_model/edges_merger.py +6 -6
  64. model_compression_toolkit/keras/reader/nested_model/nested_model_handler.py +2 -2
  65. model_compression_toolkit/keras/reader/nested_model/nodes_merger.py +3 -3
  66. model_compression_toolkit/keras/reader/nested_model/outputs_merger.py +2 -2
  67. model_compression_toolkit/keras/reader/node_builder.py +15 -65
  68. model_compression_toolkit/keras/reader/reader.py +5 -5
  69. model_compression_toolkit/keras/tensor_marking.py +113 -0
  70. model_compression_toolkit/keras/visualization/nn_visualizer.py +2 -2
  71. model_compression_toolkit/common/collectors/statistics_collector_generator.py +0 -43
  72. model_compression_toolkit/common/graph/functional_node.py +0 -59
  73. model_compression_toolkit/common/model_validation.py +0 -43
  74. model_compression_toolkit/common/node_prior_info.py +0 -29
  75. model_compression_toolkit/keras/keras_model_validation.py +0 -38
  76. model_compression_toolkit/keras/keras_node_prior_info.py +0 -60
  77. {mct_nightly-1.1.0.6012022.post2521.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/LICENSE +0 -0
  78. {mct_nightly-1.1.0.6012022.post2521.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/top_level.txt +0 -0
@@ -20,17 +20,20 @@ from typing import Any, Tuple
20
20
 
21
21
  import numpy as np
22
22
 
23
- from model_compression_toolkit.common.framework_info import FrameworkInfo, ChannelAxis
24
23
  from model_compression_toolkit.common.collectors.histogram_collector import HistogramCollector
25
24
  from model_compression_toolkit.common.collectors.mean_collector import MeanCollector
26
25
  from model_compression_toolkit.common.collectors.min_max_per_channel_collector import MinMaxPerChannelCollector
27
26
 
28
27
 
29
- class BaseStatsCollector(object):
28
+ class BaseStatsContainer(object):
30
29
  """
31
- Base class for statistics collection (contains multiple collectors such as mean collector,
30
+ Base class for statistics collection container (contain multiple statistics collector such as mean collector,
32
31
  histogram collector, etc.).
33
32
  """
33
+ def __init__(self):
34
+ # Disable histogram collection. Enable in specific collectors if needed
35
+ self.collect_histogram = False
36
+ self.use_min_max = False
34
37
 
35
38
  def require_collection(self) -> bool:
36
39
  """
@@ -50,13 +53,12 @@ class BaseStatsCollector(object):
50
53
  raise Exception(f'update_statistics is not implemented in {self.__class__.__name__}')
51
54
 
52
55
 
53
- class StatsCollector(BaseStatsCollector):
56
+ class StatsContainer(BaseStatsContainer):
54
57
  """
55
58
  Class to wrap all statistics that are being collected for an input/output node.
56
59
  """
57
60
 
58
61
  def __init__(self,
59
- output_channel_index: ChannelAxis,
60
62
  init_min_value: float = None,
61
63
  init_max_value: float = None):
62
64
  """
@@ -64,17 +66,18 @@ class StatsCollector(BaseStatsCollector):
64
66
  Set initial min/max values if are known.
65
67
 
66
68
  Args:
67
- output_channel_index: Index of output channels.
68
69
  init_min_value: Initial min value for min/max stored values.
69
70
  init_max_value: Initial max value for min/max stored values.
70
71
  """
71
72
 
72
73
  super().__init__()
73
- self.hc = HistogramCollector()
74
- self.mc = MeanCollector(axis=output_channel_index)
74
+ self.use_min_max = is_number(init_min_value) and is_number(init_max_value)
75
+ self.collect_histogram = True
76
+ if self.collect_histogram:
77
+ self.hc = HistogramCollector()
78
+ self.mc = MeanCollector()
75
79
  self.mpcc = MinMaxPerChannelCollector(init_min_value=init_min_value,
76
- init_max_value=init_max_value,
77
- axis=output_channel_index)
80
+ init_max_value=init_max_value)
78
81
 
79
82
  def update_statistics(self, x: Any):
80
83
  """
@@ -85,7 +88,8 @@ class StatsCollector(BaseStatsCollector):
85
88
  """
86
89
 
87
90
  x = standardize_tensor(x)
88
- self.hc.update(x)
91
+ if self.collect_histogram:
92
+ self.hc.update(x)
89
93
  self.mc.update(x)
90
94
  self.mpcc.update(x)
91
95
 
@@ -139,7 +143,7 @@ class StatsCollector(BaseStatsCollector):
139
143
  return True
140
144
 
141
145
 
142
- class NoStatsCollector(BaseStatsCollector):
146
+ class NoStatsContainer(BaseStatsContainer):
143
147
  """
144
148
  Class that inherits from base tensor.
145
149
  Indicating that for a point in a graph we should not gather statistics.
@@ -207,51 +211,51 @@ def standardize_tensor(x: Any) -> np.ndarray:
207
211
  return x
208
212
 
209
213
 
210
- def shift_statistics(collector: BaseStatsCollector,
211
- shift_value: np.ndarray) -> BaseStatsCollector:
214
+ def shift_statistics(collector: BaseStatsContainer,
215
+ shift_value: np.ndarray) -> BaseStatsContainer:
212
216
  """
213
- Shift all statistics in collectors of a statistics collector by a
217
+ Shift all statistics in collectors of a statistics container by a
214
218
  value (or a value per-channel).
215
219
 
216
220
  Args:
217
- collector: Statistics collector to shift its collectors.
221
+ collector: Statistics container to shift its collectors.
218
222
  shift_value: Value to shift all statistics by.
219
223
 
220
224
  Returns:
221
- New copy of the collector with shifted statistics.
225
+ New copy of the container with shifted statistics.
222
226
 
223
227
  """
224
228
 
225
229
  shifted_collector = deepcopy(collector)
226
- if isinstance(collector, StatsCollector):
230
+ if isinstance(collector, StatsContainer):
227
231
  shifted_collector.mpcc.shift(shift_value)
228
232
  shifted_collector.mc.shift(shift_value)
229
- if shifted_collector.require_collection():
233
+ if shifted_collector.collect_histogram:
230
234
  shifted_collector.hc.shift(shift_value)
231
235
 
232
236
  return shifted_collector
233
237
 
234
238
 
235
- def scale_statistics(collector: BaseStatsCollector,
236
- scale_value: np.ndarray) -> BaseStatsCollector:
239
+ def scale_statistics(collector: BaseStatsContainer,
240
+ scale_value: np.ndarray) -> BaseStatsContainer:
237
241
  """
238
- Scale all statistics in collectors of a statistics collector
242
+ Scale all statistics in collectors of a statistics container
239
243
  by a factor (or a factor per-channel).
240
244
 
241
245
  Args:
242
- collector: Statistics collector to shift its collectors.
246
+ collector: Statistics container to shift its collectors.
243
247
  scale_value: Value to shift all statistics by.
244
248
 
245
249
  Returns:
246
- New copy of the collector with scaled statistics.
250
+ New copy of the container with scaled statistics.
247
251
 
248
252
  """
249
253
 
250
254
  scaled_collector = deepcopy(collector)
251
- if isinstance(collector, StatsCollector):
255
+ if isinstance(collector, StatsContainer):
252
256
  scaled_collector.mpcc.scale(scale_value)
253
257
  scaled_collector.mc.scale(scale_value)
254
- if scaled_collector.require_collection():
258
+ if scaled_collector.collect_histogram:
255
259
  scaled_collector.hc.scale(scale_value)
256
260
 
257
261
  return scaled_collector
@@ -29,10 +29,10 @@ from tensorboard.compat.proto.summary_pb2 import HistogramProto
29
29
  from tensorboard.compat.proto.summary_pb2 import Summary
30
30
  from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
31
31
  from tensorboard.summary.writer.event_file_writer import EventFileWriter
32
- from typing import List, Any, Dict
32
+ from typing import List, Any, Dict, Callable
33
33
 
34
- from model_compression_toolkit.common import Graph, BaseNode
35
- from model_compression_toolkit.common.collectors.statistics_collector import BaseStatsCollector
34
+ from model_compression_toolkit.common import Graph, Node
35
+ from model_compression_toolkit.common.statistics_collector import BaseStatsContainer
36
36
 
37
37
  DEVICE_STEP_STATS = "/device:CPU:0"
38
38
 
@@ -138,7 +138,7 @@ class TensorboardWriter(object):
138
138
  bucket_limit=bins.tolist(),
139
139
  bucket=counts.tolist())
140
140
 
141
- def __create_histo_event(statistics_collector: BaseStatsCollector):
141
+ def __create_histo_event(statistics_collector: BaseStatsContainer):
142
142
  """
143
143
  Create an event of histogram, and attach it to a list of events outside
144
144
  the scope called 'events'.
@@ -186,7 +186,7 @@ class TensorboardWriter(object):
186
186
 
187
187
  """
188
188
 
189
- def __get_node_attr(n: BaseNode) -> Dict[str, Any]:
189
+ def __get_node_attr(n: Node) -> Dict[str, Any]:
190
190
  """
191
191
  Create a dictionary to display as the node's attributes.
192
192
  The dictionary contains information from node's framework attributes, quantization attributes
@@ -203,10 +203,7 @@ class TensorboardWriter(object):
203
203
  if n.quantization_attr is not None:
204
204
  attr.update(n.quantization_attr)
205
205
 
206
- # To log quantization configurations we need to check
207
- # if they exist at all, as we can log the initial graph,
208
- # which its nodes do not have configurations yet.
209
- # Log final config or unified candidates, not both
206
+ # log final config or unified candidates, not both
210
207
  if n.final_weights_quantization_cfg is not None:
211
208
  attr.update(n.final_weights_quantization_cfg.__dict__)
212
209
  elif n.candidates_weights_quantization_cfg is not None:
@@ -216,7 +213,7 @@ class TensorboardWriter(object):
216
213
  attr.update(n.activation_quantization_cfg.__dict__)
217
214
  return attr
218
215
 
219
- def __get_node_output_dims(n: BaseNode) -> List[tuple]:
216
+ def __get_node_output_dims(n: Node) -> List[tuple]:
220
217
  """
221
218
  Get node's output shapes. If the first dimension in an output shape is None,
222
219
  it means the batch size is dynamic, and it's replaced with -1 to mark it.
@@ -240,7 +237,7 @@ class TensorboardWriter(object):
240
237
  dims = [(-1,) + output_shape[1:] if output_shape[0] is None else output_shape]
241
238
  return dims
242
239
 
243
- def __create_node_stats(n: BaseNode):
240
+ def __create_node_stats(n: Node):
244
241
  """
245
242
  Create a NodeExecStats for a node in the graph. A NodeExecStats contains the
246
243
  memory and compute time a node requires.
@@ -21,7 +21,7 @@ from networkx.algorithms.dag import topological_sort
21
21
 
22
22
  from tensorflow.keras.layers import Layer
23
23
  from model_compression_toolkit import common
24
- from model_compression_toolkit.common import Graph, BaseNode
24
+ from model_compression_toolkit.common import Graph, Node
25
25
  from model_compression_toolkit.keras.constants import LAYER_NAME
26
26
 
27
27
 
@@ -35,7 +35,7 @@ class OperationHandler(object):
35
35
  self.node_to_fw_op_dict = instance_builder(self.node_sort) # hold dictionary from node to its equivalent
36
36
  # Keras layer
37
37
 
38
- def get_node_op_function(self, n: BaseNode) -> Layer:
38
+ def get_node_op_function(self, n: Node) -> Layer:
39
39
  """
40
40
  Get the Keras layer that was built from the passed node.
41
41
 
@@ -58,7 +58,7 @@ class OperationHandler(object):
58
58
  return op_func
59
59
 
60
60
 
61
- def node_builder(n: common.BaseNode) -> Layer:
61
+ def node_builder(n: common.Node) -> Layer:
62
62
  """
63
63
  Build a Keras layer from a node.
64
64
 
@@ -78,7 +78,7 @@ def node_builder(n: common.BaseNode) -> Layer:
78
78
  return node_instance
79
79
 
80
80
 
81
- def instance_builder(toposort: List[BaseNode]) -> Dict[BaseNode, Layer]:
81
+ def instance_builder(toposort: List[Node]) -> Dict[Node, Layer]:
82
82
  """
83
83
  Build a dictionary of nodes to their corresponding Keras
84
84
  layers, given a list of nodes.
@@ -20,26 +20,24 @@ import tensorflow as tf
20
20
  if tf.__version__ < "2.6":
21
21
  from tensorflow.keras.layers import Input
22
22
  from tensorflow.python.keras.layers.core import TFOpLambda
23
- from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
24
- from tensorflow.python.keras.layers import Layer
25
23
  else:
26
24
  from keras import Input
27
25
  from keras.layers.core import TFOpLambda
28
- from keras.engine.base_layer import TensorFlowOpLayer, Layer
29
26
 
30
27
  from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
28
+ from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
29
+ from tensorflow.python.keras.layers import Layer
31
30
  from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
32
31
  from typing import Tuple, Any, Dict, List
33
32
  from tensorflow.python.util.object_identity import Reference as TFReference
34
- from model_compression_toolkit.common.graph.functional_node import FunctionalNode
33
+
35
34
 
36
35
  from model_compression_toolkit import common
37
36
  from model_compression_toolkit.common.framework_info import FrameworkInfo
38
37
  from model_compression_toolkit.keras.default_framework_info import DEFAULT_KERAS_INFO
39
- from model_compression_toolkit.keras.quantizer.mixed_precision.quantization_config_factory import \
40
- quantization_config_builder_mixed_precision
38
+ from model_compression_toolkit.keras.quantizer.mixed_precision.quantization_config_factory import quantization_config_builder_mixed_precision
41
39
  from model_compression_toolkit.keras.quantizer.gradient_ptq.config_factory import quantization_config_builder_gptq
42
- from model_compression_toolkit.common import BaseNode, Graph
40
+ from model_compression_toolkit.common import Node, Graph
43
41
  from model_compression_toolkit.common.graph.edge import EDGE_SINK_INDEX
44
42
  from model_compression_toolkit.keras.back2framework.instance_builder import OperationHandler
45
43
  from model_compression_toolkit.keras.reader.connectivity_handler import OutTensor
@@ -82,9 +80,9 @@ def is_layer_fake_quant(layer: Layer) -> bool:
82
80
  isinstance(layer, TFOpLambda) and layer.symbol == FQ_NODE_OP_V2_4)
83
81
 
84
82
 
85
- def build_input_tensors_list(node: BaseNode,
83
+ def build_input_tensors_list(node: Node,
86
84
  graph: Graph,
87
- node_to_output_tensors_dict: Dict[BaseNode, List[TFReference]]) -> List[List[TFReference]]:
85
+ node_to_output_tensors_dict: Dict[Node, List[TFReference]]) -> List[List[TFReference]]:
88
86
  """
89
87
  Given a node, build a list of input tensors the node gets. The list is built
90
88
  based on the node's incoming edges and previous nodes' output tensors.
@@ -107,10 +105,10 @@ def build_input_tensors_list(node: BaseNode,
107
105
  return input_tensors
108
106
 
109
107
 
110
- def run_operation(n: BaseNode,
108
+ def run_operation(n: Node,
111
109
  input_tensors: List[List[TFReference]],
112
110
  op_func: Layer,
113
- input_nodes_to_input_tensors: Dict[BaseNode, Any],
111
+ input_nodes_to_input_tensors: Dict[Node, Any],
114
112
  mode: ModelBuilderMode = ModelBuilderMode.QUANTIZED) -> List[TFReference]:
115
113
  """
116
114
  Applying the layer (op_func) to the input tensors (input_tensors).
@@ -131,43 +129,33 @@ def run_operation(n: BaseNode,
131
129
 
132
130
  if len(input_tensors) == 0: # Placeholder handling
133
131
  out_tensors_of_n = input_nodes_to_input_tensors[n]
134
- if n.is_activation_quantization_enabled():
135
- if mode in [ModelBuilderMode.QUANTIZED, ModelBuilderMode.GPTQ, ModelBuilderMode.MIXEDPRECISION]:
136
- # Adding a fake quant node to Input when in GPTQ mode because quantize_model doesn't quantize the
137
- # input layer
138
- fake_quant = n.activation_quantization_cfg.generate_quantization_node()
139
-
140
- if fake_quant is None:
141
- raise Exception(f'{n.name} should be quantized, but activation quantization function is None')
142
-
132
+ if mode in [ModelBuilderMode.QUANTIZED, ModelBuilderMode.GPTQ, ModelBuilderMode.MIXEDPRECISION]:
133
+ # Adding a fake quant node to Input when in GPTQ mode because quantize_model doesn't quantize the input layer
134
+ assert n.activation_quantization_cfg is not None # Input layers should always have activation config
135
+ fake_quant = n.activation_quantization_cfg.activation_quantization_fn(n.activation_quantization_cfg.activation_n_bits,
136
+ n.activation_quantization_cfg.activation_is_signed,
137
+ n.activation_quantization_cfg.activation_quantization_params)
138
+ if fake_quant is not None:
143
139
  out_tensors_of_n = fake_quant(out_tensors_of_n)
144
-
140
+
145
141
  else:
146
142
  input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list] # flat list of lists
147
- # Build a functional node using its args
148
- if isinstance(n, FunctionalNode):
149
- if n.inputs_as_list: # If the first argument should be a list of tensors:
150
- out_tensors_of_n = op_func(input_tensors, *n.op_call_args, **n.op_call_kwargs)
151
- else: # If the input tensors should not be a list but iterated:
152
- out_tensors_of_n = op_func(*input_tensors, *n.op_call_args, **n.op_call_kwargs)
143
+
144
+ # If operator expects a single input tensor, it cannot be a list as it should
145
+ # have a dtype field.
146
+ if len(input_tensors) == 1:
147
+ out_tensors_of_n = op_func(input_tensors[0], **n.op_call_args)
153
148
  else:
154
- # If operator expects a single input tensor, it cannot be a list as it should
155
- # have a dtype field.
156
- if len(input_tensors) == 1:
157
- input_tensors = input_tensors[0]
158
- out_tensors_of_n = op_func(input_tensors)
149
+ out_tensors_of_n = op_func(input_tensors, **n.op_call_args)
159
150
 
160
151
  # Add a fake quant node if the node has an activation threshold.
161
- if n.is_activation_quantization_enabled():
162
- if mode in [ModelBuilderMode.QUANTIZED,
163
- ModelBuilderMode.MIXEDPRECISION]:
164
-
165
- fake_quant = n.activation_quantization_cfg.generate_quantization_node()
166
-
167
- if fake_quant is None:
168
- raise Exception(f'{n.name} should be quantized, but activation quantization function is None')
169
-
170
- out_tensors_of_n = fake_quant(out_tensors_of_n)
152
+ if n.activation_quantization_cfg is not None:
153
+ if mode in [ModelBuilderMode.QUANTIZED, ModelBuilderMode.MIXEDPRECISION] and n.activation_quantization_cfg.enable_activation_quantization:
154
+ fake_quant = n.activation_quantization_cfg.activation_quantization_fn(n.activation_quantization_cfg.activation_n_bits,
155
+ n.activation_quantization_cfg.activation_is_signed,
156
+ n.activation_quantization_cfg.activation_quantization_params)
157
+ if fake_quant is not None:
158
+ out_tensors_of_n = fake_quant(out_tensors_of_n)
171
159
 
172
160
  return out_tensors_of_n
173
161
 
@@ -277,11 +265,10 @@ def model_builder(graph: common.Graph,
277
265
  nodes = graph.find_node_by_name(get_node_name_from_layer(layer))
278
266
  if len(nodes) == 1:
279
267
  node = nodes[0]
280
- # Wrap only if its weights should be quantized
281
- if node.is_weights_quantization_enabled():
282
- return QuantizeWrapper(layer, quantization_config_builder_mixed_precision(node, fw_info))
283
- return layer
284
-
268
+ # does not need to get wrapped as its weights are not quantized
269
+ if node.candidates_weights_quantization_cfg is None:
270
+ return layer
271
+ return QuantizeWrapper(layer, quantization_config_builder_mixed_precision(node, fw_info))
285
272
  elif is_layer_fake_quant(layer):
286
273
  return layer
287
274
  else:
@@ -38,9 +38,6 @@ PAD_SAME = 'same'
38
38
  RELU_MAX_VALUE = 'max_value'
39
39
  THRESHOLD = 'threshold'
40
40
  NEGATIVE_SLOPE = 'negative_slope'
41
- CHANNELS_FORMAT = 'data_format'
42
- CHANNELS_FORMAT_FIRST = 'channels_first'
43
- CHANNELS_FORMAT_LAST = 'channels_last'
44
41
 
45
42
  # Layers variables names:
46
43
  KERNEL = 'kernel'
@@ -14,18 +14,12 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- import tensorflow as tf
18
- if tf.__version__ < "2.6":
19
- from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Reshape, ZeroPadding2D, Dropout, \
20
- MaxPooling2D, Activation, ReLU, GlobalAveragePooling2D, Add, Multiply, AveragePooling2D, UpSampling2D, InputLayer, \
21
- Concatenate, Softmax, PReLU, Flatten, Cropping2D
22
- else:
23
- from keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Reshape, ZeroPadding2D, \
24
- Dropout, MaxPooling2D, Activation, ReLU, GlobalAveragePooling2D, Add, Multiply, AveragePooling2D, UpSampling2D, \
25
- InputLayer, Concatenate, Softmax, PReLU, Flatten, Cropping2D
17
+ from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Dense, Conv2DTranspose, Reshape, ZeroPadding2D, Dropout, \
18
+ MaxPooling2D, Activation, ReLU, GlobalAveragePooling2D, Add, Multiply, AveragePooling2D, UpSampling2D, InputLayer, \
19
+ Concatenate, Softmax, PReLU, Flatten, Cropping2D
26
20
 
27
21
  from model_compression_toolkit.common.defaultdict import DefaultDict
28
- from model_compression_toolkit.common.framework_info import FrameworkInfo, ChannelAxis
22
+ from model_compression_toolkit.common.framework_info import FrameworkInfo
29
23
  from model_compression_toolkit.common.quantization.quantization_config import QuantizationMethod
30
24
  from model_compression_toolkit.common.quantization.quantizers.kmeans_quantizer import kmeans_quantizer
31
25
  from model_compression_toolkit.common.quantization.quantizers.lut_kmeans_quantizer import lut_kmeans_quantizer
@@ -51,9 +45,7 @@ NO_QUANTIZATION = [Reshape,
51
45
  Cropping2D,
52
46
  ZeroPadding2D,
53
47
  Dropout,
54
- MaxPooling2D,
55
- tf.split,
56
- tf.quantization.fake_quant_with_min_max_vars] # TODO: replace with marking
48
+ MaxPooling2D] # TODO: replace with marking
57
49
 
58
50
  ACTIVATION = [Activation,
59
51
  ReLU,
@@ -65,19 +57,7 @@ ACTIVATION = [Activation,
65
57
  UpSampling2D,
66
58
  InputLayer,
67
59
  Concatenate,
68
- PReLU,
69
- tf.add,
70
- tf.multiply,
71
- tf.reduce_mean,
72
- tf.reduce_min,
73
- tf.reduce_sum,
74
- tf.reduce_max,
75
- tf.image.resize,
76
- tf.image.crop_and_resize,
77
- tf.concat]
78
-
79
-
80
-
60
+ PReLU]
81
61
 
82
62
  """
83
63
  Map each layer to a list of its' weights attributes that should get quantized.
@@ -131,11 +111,6 @@ WEIGHTS_QUANTIZER_MAPPING = {QuantizationMethod.POWER_OF_TWO: power_of_two_quant
131
111
  QuantizationMethod.KMEANS: kmeans_quantizer,
132
112
  QuantizationMethod.LUT_QUANTIZER: lut_kmeans_quantizer}
133
113
 
134
- """
135
- Output channel index of the model's layers
136
- """
137
- OUTPUT_CHANNEL_INDEX = ChannelAxis.NHWC
138
-
139
114
  DEFAULT_KERAS_INFO = FrameworkInfo(KERNEL_OPS,
140
115
  ACTIVATION,
141
116
  NO_QUANTIZATION,
@@ -144,5 +119,4 @@ DEFAULT_KERAS_INFO = FrameworkInfo(KERNEL_OPS,
144
119
  DEFAULT_CHANNEL_AXIS_DICT,
145
120
  ACTIVATION2MINMAX,
146
121
  LAYER2MINMAX,
147
- KERNEL_ATTRIBUTES,
148
- OUTPUT_CHANNEL_INDEX)
122
+ KERNEL_ATTRIBUTES)
@@ -20,14 +20,14 @@ from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapp
20
20
  from typing import Tuple, List
21
21
 
22
22
  from model_compression_toolkit.common.graph.base_graph import Graph
23
- from model_compression_toolkit.common.graph.base_node import BaseNode
23
+ from model_compression_toolkit.common.graph.node import Node
24
24
  from model_compression_toolkit.keras.constants import USE_BIAS
25
25
  from model_compression_toolkit.keras.quantizer.gradient_ptq import ActivationQuantizeConfig, WeightQuantizeConfig, ActivationAndWeightQuantizeConfig
26
26
  from model_compression_toolkit.common.framework_info import FrameworkInfo
27
27
  from tensorflow.keras.models import Model
28
28
 
29
29
 
30
- def get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str]]:
30
+ def get_compare_points(input_graph: Graph) -> Tuple[List[Node], List[str]]:
31
31
  """
32
32
  Create a list of nodes with weights in a graph and a corresponding list
33
33
  of their names for tensors comparison purposes.
@@ -15,14 +15,8 @@
15
15
 
16
16
 
17
17
  import copy
18
- import tensorflow as tf
19
-
20
- # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
21
- if tf.__version__ < "2.6":
22
- from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
23
- else:
24
- from keras.engine.base_layer import TensorFlowOpLayer
25
18
 
19
+ from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer
26
20
  from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
27
21
 
28
22
  from model_compression_toolkit import common
@@ -57,6 +57,7 @@ def gptq_training_wrapper(tg: Graph,
57
57
  #########################################
58
58
  # Build two models and compare points
59
59
  #########################################
60
+ # TODO: maybe need to add pre_build substitutions here. Ask Elad
60
61
  compare_points, _ = get_compare_points(tg) # get compare points
61
62
  n = len(compare_points)
62
63
  float_model, float_user_info = model_builder(tg,
@@ -21,7 +21,7 @@ from model_compression_toolkit.common.constants import FLOAT_32, DATA_TYPE
21
21
  from model_compression_toolkit.common.graph.base_graph import Graph
22
22
  from model_compression_toolkit.common.graph.graph_matchers import NodeOperationMatcher, \
23
23
  NodeFrameworkAttrMatcher
24
- from model_compression_toolkit.common.graph.base_node import BaseNode
24
+ from model_compression_toolkit.common.graph.node import Node
25
25
  from model_compression_toolkit.keras.constants import LINEAR, ACTIVATION, TRAINABLE, LAYER_NAME
26
26
 
27
27
 
@@ -49,7 +49,7 @@ class ActivationDecomposition(common.BaseSubstitution):
49
49
 
50
50
  def substitute(self,
51
51
  graph: Graph,
52
- op2d_node: BaseNode) -> Graph:
52
+ op2d_node: Node) -> Graph:
53
53
  """
54
54
  Decompose the activation function in a linear node to a new activation layer.
55
55
  Set activation function in the linear node to 'linear' (y=x).
@@ -70,14 +70,12 @@ class ActivationDecomposition(common.BaseSubstitution):
70
70
  DATA_TYPE: FLOAT_32,
71
71
  ACTIVATION: op2d_node.framework_attr.get(ACTIVATION)}
72
72
 
73
- activation_node = common.graph.BaseNode(activation_node_name,
74
- activation_fw_attr,
75
- op2d_node.output_shape,
76
- op2d_node.output_shape,
77
- {},
78
- Activation)
79
-
80
-
73
+ activation_node = common.graph.Node(activation_node_name,
74
+ activation_fw_attr,
75
+ op2d_node.output_shape,
76
+ op2d_node.output_shape,
77
+ {},
78
+ Activation)
81
79
 
82
80
  graph.add_node(activation_node)
83
81
  graph.reconnect_out_edges(current_node=op2d_node,
@@ -24,7 +24,7 @@ from model_compression_toolkit import common
24
24
  from model_compression_toolkit.common.graph.base_graph import Graph
25
25
  from model_compression_toolkit.common.graph.graph_matchers import EdgeMatcher, NodeOperationMatcher, \
26
26
  NodeFrameworkAttrMatcher
27
- from model_compression_toolkit.common.graph.base_node import BaseNode
27
+ from model_compression_toolkit.common.graph.node import Node
28
28
  from model_compression_toolkit.keras.constants import KERNEL, BIAS, USE_BIAS, LINEAR, ACTIVATION, LAYER_NAME, \
29
29
  GAMMA, BETA, EPSILON, \
30
30
  MOVING_MEAN, \
@@ -51,7 +51,7 @@ class BatchNormalizationFolding(common.BaseSubstitution):
51
51
 
52
52
  def substitute(self,
53
53
  graph: Graph,
54
- edge_nodes: Tuple[BaseNode, BaseNode]) -> Graph:
54
+ edge_nodes: Tuple[Node, Node]) -> Graph:
55
55
  """
56
56
  Fold BatchNormalization into preceding linear layers.
57
57
 
@@ -21,7 +21,7 @@ from model_compression_toolkit import common
21
21
  from model_compression_toolkit.common.framework_info import FrameworkInfo
22
22
  from model_compression_toolkit.common.graph.base_graph import Graph
23
23
  from model_compression_toolkit.common.graph.graph_matchers import NodeOperationMatcher, EdgeMatcher, WalkMatcher
24
- from model_compression_toolkit.common.graph.base_node import BaseNode
24
+ from model_compression_toolkit.common.graph.node import Node
25
25
  from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
26
26
  from model_compression_toolkit.common.constants import THRESHOLD
27
27
  from model_compression_toolkit.keras.constants import KERNEL
@@ -62,7 +62,7 @@ class BaseInputScaling(common.BaseSubstitution):
62
62
 
63
63
  def substitute(self,
64
64
  graph: Graph,
65
- nodes_list: List[BaseNode]) -> Graph:
65
+ nodes_list: List[Node]) -> Graph:
66
66
  """
67
67
  Scale activation threshold for input layers, if they are followed by linear nodes. We first
68
68
  scale their thresholds to a constrained threshold, and then fix it by scaling the linear op weights
@@ -21,7 +21,7 @@ from model_compression_toolkit import common
21
21
  from model_compression_toolkit.common.graph.base_graph import Graph
22
22
  from model_compression_toolkit.common.graph.graph_matchers import NodeOperationMatcher, EdgeMatcher, \
23
23
  NodeFrameworkAttrMatcher
24
- from model_compression_toolkit.common.graph.base_node import BaseNode
24
+ from model_compression_toolkit.common.graph.node import Node
25
25
  from model_compression_toolkit.keras.constants import LINEAR, ACTIVATION
26
26
 
27
27
 
@@ -55,7 +55,7 @@ class MarkActivation(common.BaseSubstitution):
55
55
 
56
56
  def substitute(self,
57
57
  graph: Graph,
58
- edge: Tuple[BaseNode, BaseNode]) -> Graph:
58
+ edge: Tuple[Node, Node]) -> Graph:
59
59
  """
60
60
  Mark the first node in an edge that should not be quantized as so.
61
61
  This can be done due to the following reasons:
@@ -69,5 +69,5 @@ class MarkActivation(common.BaseSubstitution):
69
69
  Graph after applying the substitution.
70
70
  """
71
71
 
72
- edge[0].activation_quantization_cfg.enable_activation_quantization = False
72
+ edge[0].output_quantization = False
73
73
  return graph
@@ -20,11 +20,12 @@ import numpy as np
20
20
  from tensorflow.keras.layers import ReLU, Activation, DepthwiseConv2D, Conv2DTranspose, Conv2D, Dense
21
21
 
22
22
  from model_compression_toolkit import common
23
- from model_compression_toolkit.common import Graph, BaseNode
23
+ from model_compression_toolkit.common import FrameworkInfo, Graph, Node
24
24
  from model_compression_toolkit.common.constants import THRESHOLD
25
25
  from model_compression_toolkit.common.graph.graph_matchers import NodeOperationMatcher, WalkMatcher, \
26
26
  NodeFrameworkAttrMatcher
27
- from model_compression_toolkit.common.collectors.statistics_collector import is_number
27
+ from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
28
+ from model_compression_toolkit.common.statistics_collector import is_number
28
29
  from model_compression_toolkit.keras.constants import KERNEL, BIAS, ACTIVATION, RELU_MAX_VALUE
29
30
  from model_compression_toolkit.keras.constants import RELU
30
31
 
@@ -60,7 +61,7 @@ class ReLUBoundCorrection(common.BaseSubstitution):
60
61
 
61
62
  def substitute(self,
62
63
  graph: Graph,
63
- nodes_list: List[BaseNode]) -> Graph:
64
+ nodes_list: List[Node]) -> Graph:
64
65
  """
65
66
  Transform a list of nodes in a graph to use the entire constrained quantized range.
66
67
  This is done by scaling two linear nodes with a non-linearity between them, if the non-linearity