mct-nightly 1.8.0.20052023.post401__py3-none-any.whl → 1.8.0.20230610.post356__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/METADATA +10 -7
  2. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/RECORD +68 -115
  3. model_compression_toolkit/__init__.py +23 -3
  4. model_compression_toolkit/core/common/framework_info.py +1 -1
  5. model_compression_toolkit/core/keras/back2framework/instance_builder.py +16 -9
  6. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +8 -34
  7. model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +5 -1
  8. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +103 -28
  9. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py +39 -44
  10. model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_tflite_exporter.py +1 -1
  11. model_compression_toolkit/exporter/model_exporter/keras/int8_tflite_exporter.py +20 -18
  12. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +3 -3
  13. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py +1 -1
  14. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +36 -9
  15. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +4 -4
  16. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +24 -32
  17. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +31 -8
  18. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +5 -5
  19. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +34 -8
  20. model_compression_toolkit/gptq/keras/gptq_training.py +15 -16
  21. model_compression_toolkit/gptq/keras/graph_info.py +2 -2
  22. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +4 -5
  23. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +5 -7
  24. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -1
  25. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +6 -6
  26. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py +7 -7
  27. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +6 -6
  28. model_compression_toolkit/gptq/pytorch/gptq_training.py +30 -10
  29. model_compression_toolkit/gptq/pytorch/graph_info.py +5 -2
  30. model_compression_toolkit/gptq/pytorch/quantization_facade.py +4 -2
  31. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +4 -4
  32. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +5 -7
  33. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -1
  34. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +7 -7
  35. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +7 -8
  36. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +7 -8
  37. model_compression_toolkit/qat/common/__init__.py +2 -1
  38. model_compression_toolkit/qat/common/qat_config.py +2 -2
  39. model_compression_toolkit/qat/keras/quantization_facade.py +18 -8
  40. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +1 -1
  41. model_compression_toolkit/qat/keras/quantizer/quantization_builder.py +11 -11
  42. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +11 -12
  43. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +12 -13
  44. model_compression_toolkit/qat/pytorch/quantization_facade.py +27 -16
  45. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +2 -2
  46. model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py +31 -4
  47. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +10 -9
  48. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +11 -10
  49. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +2 -1
  50. model_compression_toolkit/target_platform_capabilities/target_platform/op_quantization_config.py +1 -25
  51. model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py → trainable_infrastructure/__init__.py} +3 -10
  52. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/base_trainable_quantizer.py +3 -3
  53. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/get_quantizer_config.py +1 -1
  54. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/get_quantizers.py +3 -3
  55. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/base_keras_quantizer.py +4 -4
  56. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/config_serialization.py +2 -2
  57. model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure → trainable_infrastructure}/keras/load_model.py +16 -23
  58. model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/pytorch/base_pytorch_quantizer.py +3 -3
  59. model_compression_toolkit/quantizers_infrastructure/__init__.py +0 -23
  60. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/base_inferable_quantizer.py +0 -87
  61. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/constants.py +0 -46
  62. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +0 -31
  63. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +0 -53
  64. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/quant_utils.py +0 -49
  65. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/activation_quantization_holder.py +0 -147
  66. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +0 -345
  67. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizer_utils.py +0 -85
  68. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +0 -27
  69. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/__init__.py +0 -14
  70. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +0 -148
  71. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +0 -65
  72. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +0 -86
  73. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +0 -111
  74. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/base_keras_inferable_quantizer.py +0 -56
  75. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/__init__.py +0 -14
  76. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +0 -79
  77. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +0 -179
  78. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +0 -67
  79. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +0 -87
  80. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +0 -163
  81. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/validation_functions.py +0 -66
  82. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/__init__.py +0 -14
  83. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +0 -269
  84. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizer_utils.py +0 -152
  85. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +0 -35
  86. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/__init__.py +0 -14
  87. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_lut_pot_inferable_quantizer.py +0 -96
  88. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_pot_inferable_quantizer.py +0 -62
  89. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_symmetric_inferable_quantizer.py +0 -83
  90. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/activation_inferable_quantizers/activation_uniform_inferable_quantizer.py +0 -100
  91. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_lut_symmetric_inferable_quantizer.py +0 -95
  92. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_pytorch_inferable_quantizer.py +0 -48
  93. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_symmetric_inferable_quantizer.py +0 -70
  94. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/base_uniform_inferable_quantizer.py +0 -57
  95. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +0 -26
  96. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/__init__.py +0 -14
  97. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_pot_inferable_quantizer.py +0 -77
  98. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_lut_symmetric_inferable_quantizer.py +0 -106
  99. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_pot_inferable_quantizer.py +0 -66
  100. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py +0 -104
  101. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/weights_inferable_quantizers/weights_uniform_inferable_quantizer.py +0 -109
  102. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/__init__.py +0 -14
  103. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/__init__.py +0 -14
  104. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/__init__.py +0 -14
  105. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/__init__.py +0 -14
  106. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/LICENSE.md +0 -0
  107. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/WHEEL +0 -0
  108. {mct_nightly-1.8.0.20052023.post401.dist-info → mct_nightly-1.8.0.20230610.post356.dist-info}/top_level.txt +0 -0
  109. /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure → trainable_infrastructure/common}/__init__.py +0 -0
  110. /model_compression_toolkit/{quantizers_infrastructure → trainable_infrastructure/common}/constants.py +0 -0
  111. /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/quant_utils.py +0 -0
  112. /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/common/trainable_quantizer_config.py +0 -0
  113. /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/common → trainable_infrastructure/keras}/__init__.py +0 -0
  114. /model_compression_toolkit/{quantizers_infrastructure/trainable_infrastructure → trainable_infrastructure}/keras/quantizer_utils.py +0 -0
  115. /model_compression_toolkit/{quantizers_infrastructure/inferable_infrastructure/keras → trainable_infrastructure/pytorch}/__init__.py +0 -0
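Taken together, these moves and deletions show the release's main refactor: the inferable quantization infrastructure leaves model_compression_toolkit.quantizers_infrastructure for the external mct_quantizers package, and the trainable pieces are re-homed under model_compression_toolkit.trainable_infrastructure. A minimal sketch of the import migration implied by the diffs below (assuming mct_quantizers is installed alongside the nightly wheel):

    # Old import paths (1.8.0.20052023.post401), removed in this release:
    # from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
    # from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model import keras_load_quantized_model

    # New import paths (1.8.0.20230610.post356), as used throughout the diffs below:
    from mct_quantizers import KerasQuantizationWrapper, KerasActivationQuantizationHolder
    from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model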
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py

@@ -13,6 +13,7 @@
  # limitations under the License.
  # ==============================================================================
  from abc import abstractmethod
+ from functools import partial
  from typing import Tuple, Any, Dict, List, Union, Callable

  import torch
@@ -30,6 +31,7 @@ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAUL
  from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder, BufferHolder
  from model_compression_toolkit.core.pytorch.utils import get_working_device
  from model_compression_toolkit.core.pytorch.constants import BUFFER
+ from mct_quantizers.common.constants import ACTIVATION_HOLDER_QUANTIZER


  def _build_input_tensors_list(node: BaseNode,
@@ -66,7 +68,7 @@ def _run_operation(n: BaseNode,
  input_tensors: List,
  op_func: Any,
  quantize_node_activation_fn,
- is_wrapped: bool) -> Tuple[Union[List,torch.Tensor], Union[List,torch.Tensor]]:
+ use_activation_quantization: bool) -> Tuple[Union[List, torch.Tensor], Union[List, torch.Tensor]]:
  """
  Applying the layer (op_func) to the input tensors (input_tensors).
  If quantized is set to True, and the layer's corresponding node (n) has quantization
@@ -77,7 +79,7 @@ def _run_operation(n: BaseNode,
  input_tensors: List of Pytorch tensors that are the layer's inputs.
  op_func: Module/functional to apply to the input tensors.
  quantize_node_activation_fn: quantization function
- is_wrapped : Flag to indicate if layer is already quantization wrapped so no activation is needed
+ use_activation_quantization: Flag to indicate if we have an activation function.
  Returns:
  A tuple of Pytorch tensors. The Module/functional output tensors after applying the
  Module/functional to the input tensors.
@@ -92,10 +94,10 @@ def _run_operation(n: BaseNode,

  # Add a fake quant node if the node has an activation threshold.
  out_tensors_of_n = out_tensors_of_n_float
- if n.is_activation_quantization_enabled() and not is_wrapped:
+ if use_activation_quantization:
  if isinstance(out_tensors_of_n_float, list):
  out_tensors_of_n_float = torch.cat(out_tensors_of_n_float, dim=0)
- out_tensors_of_n = quantize_node_activation_fn(n, out_tensors_of_n_float)
+ out_tensors_of_n = quantize_node_activation_fn(out_tensors_of_n_float)

  return out_tensors_of_n, out_tensors_of_n_float

@@ -145,7 +147,8 @@ class PytorchModel(torch.nn.Module):
  append2output: List[Any] = None,
  fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
  return_float_outputs: bool = False,
- wrapper: Callable = identity_wrapper):
+ wrapper: Callable = None,
+ get_activation_quantizer_holder_fn: Callable = None):
  """
  Construct a Pytorch model.

@@ -155,17 +158,31 @@ class PytorchModel(torch.nn.Module):
  fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).
  return_float_outputs: Whether the model returns float tensors or not.
  wrapper: A function wrapper Pytorch Layers.
+ get_activation_quantizer_holder_fn: Function to retrieve a quantization holder for a node.
+
  """
  super(PytorchModel, self).__init__()
  self.graph = graph
  self.node_sort = list(topological_sort(graph))
- self.nodes_dict = {}
+ self.node_to_activation_quantization_holder = {}
  self.append2output = append2output
  self.return_float_outputs = return_float_outputs
  self.fw_info = fw_info
  self.wrapper = wrapper
+ self.get_activation_quantizer_holder = get_activation_quantizer_holder_fn
  self._add_modules()

+ # todo: Move to parent class BaseModelBuilder
+ @property
+ def use_activation_holder_during_model_building(self) -> bool:
+ """
+ Returns: Whether or not the model builder uses a PytorchActivationQuantizationHolder during
+ model building (by adding it as a module when converting the graph to a Pytorch model).
+ If so - the model builder expects the activation quantizers not to be wrapped
+ in a PytorchQuantizeWrapper.
+ """
+ return self.get_activation_quantizer_holder is not None
+
  @abstractmethod
  def _quantize_node_activations(self,
  node: BaseNode,
@@ -184,18 +201,50 @@ class PytorchModel(torch.nn.Module):
  raise NotImplemented(f'{self.__class__.__name__} '
  f'have to implement a method for quantization activation nodes.') # pragma: no cover

+ def wrap(self, node):
+ """
+ Wraps a node operation with a wrapper, if one is available.
+
+ Args:
+ node: node to wrap its operation.
+
+ Returns: the node's operation. If a wrapper is available, the operation is wrapped.
+ """
+ if isinstance(node, FunctionalNode):
+ if self.wrapper is None:
+ node_op = node.type
+ else:
+ node_op = self.wrapper(node, node.type)
+ else:
+ if self.wrapper is None or node.type == BufferHolder:
+ node_op = node_builder(node)
+ else:
+ node_op = self.wrapper(node, node_builder(node))
+ return node_op
+
  def _add_modules(self):
- for n in self.node_sort:
- if isinstance(n, FunctionalNode):
+ """
+ Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel
+ """
+ for node in self.node_sort:
+ node_op = self.wrap(node)
+ if isinstance(node, FunctionalNode):
  # for functional layers
- setattr(self, n.name, self.wrapper(n, n.type))
+ setattr(self, node.name, node_op)
  else:
- if n.type == BufferHolder:
- self.add_module(n.name, node_builder(n))
- self.get_submodule(n.name). \
- register_buffer(n.name, torch.Tensor(n.get_weights_by_keys(BUFFER)).to(get_working_device()))
- else:
- self.add_module(n.name, self.wrapper(n, node_builder(n)))
+ self.add_module(node.name, node_op)
+ if node.type == BufferHolder:
+ self.get_submodule(node.name). \
+ register_buffer(node.name,
+ torch.Tensor(node.get_weights_by_keys(BUFFER)).to(get_working_device()))
+
+ # Add activation quantization modules if an activation holder is configured for this node
+ if node.is_activation_quantization_enabled() and self.get_activation_quantizer_holder is not None:
+ activation_quantizer_holder = self.get_activation_quantizer_holder(node)
+ if activation_quantizer_holder is not None:
+ self.add_module(node.name + '_' + ACTIVATION_HOLDER_QUANTIZER, activation_quantizer_holder)
+ self.node_to_activation_quantization_holder.update(
+ {node.name: node.name + '_' + ACTIVATION_HOLDER_QUANTIZER})

  def forward(self,
  *args: Any) -> Any:
@@ -208,28 +257,28 @@ class PytorchModel(torch.nn.Module):
  node_to_output_tensors_dict = dict()
  node_to_output_tensors_dict_float = dict()
  configurable_nodes = self.graph.get_configurable_sorted_nodes_names()
- for n in self.node_sort:
- input_tensors = _build_input_tensors_list(n,
+ for node in self.node_sort:
+ input_tensors = _build_input_tensors_list(node,
  self.graph,
  args,
  node_to_output_tensors_dict)

- op_func = self._get_op_func(n, configurable_nodes)
+ op_func = self._get_op_func(node, configurable_nodes)
+ use_activation_quantization, activation_quantization_fn = self._get_activation_quantization_fn(node)

  # Run node operation and fetch outputs
- out_tensors_of_n, out_tensors_of_n_float = _run_operation(n,
+ out_tensors_of_n, out_tensors_of_n_float = _run_operation(node,
  input_tensors,
  op_func=op_func,
- quantize_node_activation_fn=self._quantize_node_activations,
- is_wrapped=self.wrapper is not identity_wrapper)
+ quantize_node_activation_fn=activation_quantization_fn,
+ use_activation_quantization=use_activation_quantization)

  if isinstance(out_tensors_of_n, list):
- node_to_output_tensors_dict.update({n: out_tensors_of_n})
- node_to_output_tensors_dict_float.update({n: out_tensors_of_n_float})
+ node_to_output_tensors_dict.update({node: out_tensors_of_n})
+ node_to_output_tensors_dict_float.update({node: out_tensors_of_n_float})
  else:
- node_to_output_tensors_dict.update({n: [out_tensors_of_n]})
- node_to_output_tensors_dict_float.update({n: [out_tensors_of_n_float]})
-
+ node_to_output_tensors_dict.update({node: [out_tensors_of_n]})
+ node_to_output_tensors_dict_float.update({node: [out_tensors_of_n_float]})

  if self.append2output:
  outputs = _generate_outputs(self.append2output,
@@ -256,6 +305,28 @@ class PytorchModel(torch.nn.Module):
  """
  return getattr(self, node.name)

+ def _get_activation_quantization_fn(self, node) -> Tuple[bool, bool, Callable]:
+ """
+ Get activation quantization parameters for this node.
+
+ Args:
+ node: Node from which to extract the activation quantization parameters.
+
+ Returns: Flag to indicate if we quantize activations, flag to indicate if we quantize activations
+ using a quantization holder and a quantization function to use for the node's activations.
+ """
+ activation_quantization_holder = self.node_to_activation_quantization_holder.get(node.name)
+ use_activation_quantization = node.is_activation_quantization_enabled()
+ if use_activation_quantization:
+ if activation_quantization_holder is None:
+ activation_quantization_fn = partial(self._quantize_node_activations, node)
+ use_activation_quantization = self.wrapper is None
+ else:
+ activation_quantization_fn = getattr(self, activation_quantization_holder)
+ else:
+ activation_quantization_fn = None
+ return use_activation_quantization, activation_quantization_fn
+

  class PyTorchModelBuilder(BaseModelBuilder):
  """
@@ -267,7 +338,8 @@ class PyTorchModelBuilder(BaseModelBuilder):
  append2output=None,
  fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
  return_float_outputs: bool = False,
- wrapper: Callable = identity_wrapper):
+ wrapper: Callable = None,
+ get_activation_quantizer_holder_fn: Callable = None):
  """

  Args:
@@ -276,6 +348,7 @@ class PyTorchModelBuilder(BaseModelBuilder):
  fw_info: Information about the specific framework of the model that is built.
  return_float_outputs: Whether the model returns float tensors or not.
  wrapper: A function wrapper Pytorch Layers.
+ get_activation_quantizer_holder_fn: Function to retrieve a quantization holder for a node.
  """

  super().__init__(graph,
@@ -284,6 +357,7 @@ class PyTorchModelBuilder(BaseModelBuilder):
  return_float_outputs)

  self.wrapper = wrapper
+ self.get_activation_quantizer_holder_fn = get_activation_quantizer_holder_fn

  def build_model(self) -> Tuple[PytorchModel, UserInformation]:
  """
@@ -294,4 +368,5 @@ class PyTorchModelBuilder(BaseModelBuilder):
  return PytorchModel(self.graph,
  self.append2output,
  return_float_outputs=self.return_float_outputs,
- wrapper=self.wrapper), self.graph.user_info
+ wrapper=self.wrapper,
+ get_activation_quantizer_holder_fn=self.get_activation_quantizer_holder_fn), self.graph.user_info
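The net effect of these hunks: PytorchModel no longer routes quantization decisions through an identity_wrapper default. Weight wrapping is delegated to an optional wrapper, and activation quantization is attached as a separate holder module per node (registered under the node name plus ACTIVATION_HOLDER_QUANTIZER), which forward() resolves via _get_activation_quantization_fn. A standalone sketch of that dispatch rule, with a hypothetical holders dict standing in for the registered sub-modules:

    from functools import partial
    from typing import Callable, Optional, Tuple

    def resolve_activation_quantization(node_name: str,
                                        activation_enabled: bool,
                                        holders: dict,
                                        wrapper: Optional[Callable],
                                        fallback_fn: Callable) -> Tuple[bool, Optional[Callable]]:
        # Mirrors _get_activation_quantization_fn above: prefer a registered holder
        # module; otherwise fall back to _quantize_node_activations, but only when
        # no wrapper already owns the quantization.
        if not activation_enabled:
            return False, None
        holder = holders.get(node_name)
        if holder is None:
            return wrapper is None, partial(fallback_fn, node_name)
        return True, holder  # the holder module itself is callable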
model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py

@@ -14,6 +14,7 @@
  # ==============================================================================
  from typing import Dict, Callable

+ import keras
  import keras.models
  import keras.models
  import tensorflow as tf
@@ -22,9 +23,9 @@ from keras.engine.base_layer import Layer
  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.exporter.model_exporter.keras.base_keras_exporter import \
  BaseKerasExporter
- from model_compression_toolkit.quantizers_infrastructure import KerasQuantizationWrapper
-
+ from mct_quantizers import KerasQuantizationWrapper

+ layers = keras.layers

  class FakelyQuantKerasExporter(BaseKerasExporter):
  """
@@ -69,51 +70,45 @@ class FakelyQuantKerasExporter(BaseKerasExporter):
  Layer after unwrapping.

  """
- assert self.is_layer_exportable_fn(layer), f'Layer {layer.name} is not exportable.'
+
+ # Assert each layer is exportable
+ self.is_layer_exportable_fn(layer)

  # If weights are quantized, use the quantized weight for the new built layer.
- if layer.is_weights_quantization:
- new_layer = layer.layer.__class__.from_config(layer.layer.get_config())
- with tf.name_scope(new_layer.name):
- new_layer.build(layer.input_shape)
-
- # Build a list of the layer's new weights.
- weights_list = []
- # Go over weights, check if they should be quantized, and quantize if this is the case:
- for w in new_layer.weights:
- val = None
- for qw in layer.weights:
- if w.name in qw.name:
- # Use quantized weight if layer attribute should be quantized.
- # For example: check if 'kernel_0' is an attribute
- # that should be quantized. First, extract 'kernel' from variable name, check if the
- # quantize config contains this as an attribute for quantization. If so -
- # Take the quantized weight from the quantize_config and set it to the new layer.
- attribute_name = w.name.split('/')[-1].split(':')[0]
- if attribute_name in layer.weights_quantizers.keys():
- quantizer = layer.weights_quantizers.get(attribute_name)
- val = quantizer(qw)
- else:
- val = qw
- if val is None:
- Logger.error(f'Could not match weight name: {w.name}')
- weights_list.append(val)
-
- new_layer.set_weights(weights_list)
- new_layer.trainable = False
-
- # If activations are also quantized, wrap the layer back using ActivationQuantizeConfig
- # from original wrapper (weights wrapping is no longer needed).
- if layer.is_activation_quantization:
- new_layer = KerasQuantizationWrapper(layer=new_layer,
- activation_quantizers=layer.activation_quantizers)
-
- return new_layer
-
- # If this is a layer with activation quantization only, just return it
- # as activation quantization in the fake-quant case uses the wrapper for quantization.
- return layer
+ if isinstance(layer, KerasQuantizationWrapper):
+ if layer.is_weights_quantization:
+ new_layer = layer.layer.__class__.from_config(layer.layer.get_config())
+
+ # Build a list of the layer's new weights.
+ weights_list = []
+
+ # Create a list of weights for the new created layer
+ if isinstance(layer.layer, layers.DepthwiseConv2D):
+ weights_list.append(layer.get_quantized_weights()['depthwise_kernel'])
+ elif isinstance(layer.layer, (layers.Conv2D, layers.Dense, layers.Conv2DTranspose)):
+ weights_list.append(layer.get_quantized_weights()['kernel'])
+ else:
+ Logger.error(f'KerasQuantizationWrapper should wrap only DepthwiseConv2D, Conv2D, Dense'
+ f' and Conv2DTranspose layers but wrapped layer is {layer.layer}')
+
+ if layer.layer.bias is not None:
+ weights_list.append(layer.layer.bias)
+
+ # In order to add the weights of the layer, we need to build it. To build it
+ # we need to pass its input shape. Not every layer has input_shape since some
+ # layers may have multiple inputs with different input shapes (reused layers for
+ # example). For this reason, we take input shape at index 0 (any input shape
+ # should work since the weights are dependent only at some dimensions which have to
+ # be the same for all inputs).
+ with tf.name_scope(new_layer.name):
+ new_layer.build(layer.get_input_shape_at(0))
+
+ new_layer.set_weights(weights_list)
+ new_layer.trainable = False
+
+ return new_layer

+ return layer

  # clone each layer in the model and apply _unwrap_quantize_wrapper to layers wrapped with a QuantizeWrapper.
  self.exported_model = tf.keras.models.clone_model(self.model,
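The rewritten _unwrap_quantize_wrapper replaces the old variable-name matching with KerasQuantizationWrapper.get_quantized_weights(), and restricts unwrapping to the kernel-holding layer types it knows. A condensed sketch of the unwrap flow for a wrapped Conv2D (wrapped is a hypothetical KerasQuantizationWrapper instance; the calls mirror the hunk above):

    import tensorflow as tf
    from tensorflow.keras import layers

    def unwrap_conv(wrapped) -> layers.Layer:
        # Rebuild a plain layer of the same class, carrying the already-quantized kernel.
        new_layer = wrapped.layer.__class__.from_config(wrapped.layer.get_config())
        weights = [wrapped.get_quantized_weights()['kernel']]  # quantized kernel from the wrapper
        if wrapped.layer.bias is not None:
            weights.append(wrapped.layer.bias)  # bias stays unquantized
        with tf.name_scope(new_layer.name):
            new_layer.build(wrapped.get_input_shape_at(0))  # index 0 covers reused layers
        new_layer.set_weights(weights)
        new_layer.trainable = False
        return new_layer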
model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_tflite_exporter.py

@@ -19,9 +19,9 @@ from typing import Callable
  import keras.models
  import tensorflow as tf

- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.load_model import keras_load_quantized_model
  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
+ from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model


  class FakelyQuantTFLiteExporter(FakelyQuantKerasExporter):
model_compression_toolkit/exporter/model_exporter/keras/int8_tflite_exporter.py

@@ -22,11 +22,9 @@ from keras import Sequential
  from keras.layers import Dense, Conv2D, Reshape
  from keras.models import clone_model

- from model_compression_toolkit import quantizers_infrastructure as qi
  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import FakelyQuantKerasExporter
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import \
- constants as keras_inferable_constants
+ from mct_quantizers import constants as keras_inferable_constants, KerasQuantizationWrapper

  BIAS_INITIALIZER = 'bias_initializer'
  BIAS_REGULARIZER = 'bias_regularizer'
@@ -50,6 +48,7 @@ KERNEL = 'kernel'
  CONV_KERNEL_CHANNEL_AXIS = 3
  CONV_INPUT_CHANNELS_DIM = 4

+
  class INT8TFLiteExporter(FakelyQuantKerasExporter):
  """
  Exporter for INT8 TFLite models.
@@ -75,7 +74,7 @@ class INT8TFLiteExporter(FakelyQuantKerasExporter):

  self.exported_model = None

- def _get_pointwise_layer_to_replace_dense(self, wrapped_layer: qi.KerasQuantizationWrapper) -> keras.layers.Layer:
+ def _get_pointwise_layer_to_replace_dense(self, wrapped_layer: KerasQuantizationWrapper) -> keras.layers.Layer:
  # First we create a pointwise configuration based on the Dense layer's configuration
  dense_cfg = wrapped_layer.layer.get_config()

@@ -94,7 +93,7 @@ class INT8TFLiteExporter(FakelyQuantKerasExporter):

  # Create the point-wise layer
  pw_layer = Conv2D(**pw_cfg)
- pw_layer.build(wrapped_layer.layer.input_shape)
+ pw_layer.build(wrapped_layer.input_shape)

  # Create and set the point-wise weights to assign
  dense_kernel = wrapped_layer.layer.kernel
@@ -110,7 +109,7 @@ class INT8TFLiteExporter(FakelyQuantKerasExporter):
  pw_layer.set_weights(pw_weights)

  # Now that we have the point-wise to replace the dense layer,
- # we need to wrap it using qi.KerasQuantizationWrapper with a new
+ # we need to wrap it using KerasQuantizationWrapper with a new
  # relevant quantizers.
  # Create new kernel quantizer
  pw_kernel_quantizer_cfg = wrapped_layer.weights_quantizers[KERNEL].get_config()
@@ -121,8 +120,10 @@ class INT8TFLiteExporter(FakelyQuantKerasExporter):
  # Unquantized weight to conv layer has 4 dimensions (unlike dense which varies)
  pw_kernel_quantizer_cfg[keras_inferable_constants.INPUT_RANK] = CONV_INPUT_CHANNELS_DIM

- assert isinstance(pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD], np.ndarray), f'Expected to find threshold which is a numpy array, but found: {type(pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD])}'
- pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD] = list(pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD])
+ assert isinstance(pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD],
+ np.ndarray), f'Expected to find threshold which is a numpy array, but found: {type(pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD])}'
+ pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD] = list(
+ pw_kernel_quantizer_cfg[keras_inferable_constants.THRESHOLD])

  # Now that we have the point-wise quantizer we can instantiate it
  quantizer_class = type(wrapped_layer.weights_quantizers[KERNEL])
@@ -131,21 +132,21 @@ class INT8TFLiteExporter(FakelyQuantKerasExporter):
  pw_weights_quantizers[KERNEL] = pw_quantizer

  # Wrap pw with the new quantizers (the activation is not affected thus we take the Dense quantizers)
- wrapped_pw = qi.KerasQuantizationWrapper(pw_layer,
- pw_weights_quantizers,
- wrapped_layer.activation_quantizers)
+ wrapped_pw = KerasQuantizationWrapper(pw_layer,
+ pw_weights_quantizers,
+ wrapped_layer.activation_quantizers)

  # Compute the shape that the input to the new layer should be reshaped into
  # Example: Dense kernel with the following shape (3, 20) expects to have input with the
  # next dimensions (BATCH_SIZE, x0, x1, ..., xn, 20).
  # Conv layer expects 4-rank input. Thus, the input is reshaped to (BATCH_SIZE, 1, x0*x1*...*xn, 20)
- dim = wrapped_layer.layer.input_shape[1:-1]
+ dim = wrapped_layer.input_shape[1:-1]
  target_shape = (1, int(np.prod(dim))) + (dense_kernel.get_shape()[0],)

  return Sequential([
  Reshape(target_shape=target_shape),
  wrapped_pw,
- Reshape(wrapped_layer.layer.output_shape[1:])
+ Reshape(wrapped_layer.output_shape[1:])
  ])

  def export(self) -> None:
@@ -153,17 +154,18 @@ class INT8TFLiteExporter(FakelyQuantKerasExporter):
  Export a fully quantized model to its int8 tflite model.
  """

- def _substitute_model(wrapped_layer: qi.KerasQuantizationWrapper) -> keras.layers.Layer:
+ def _substitute_model(layer_to_substitue: keras.layers.Layer) -> keras.layers.Layer:
  assert self.is_layer_exportable_fn(
- wrapped_layer), f'Layer {wrapped_layer.get_config()} did not pass validation'
+ layer_to_substitue), f'Layer {layer_to_substitue.get_config()} did not pass validation'

  # In order to support dense quantization using per-channel quantization (which is
  # unsupported in TFLITE int models) we substitute each dense layer to its equivalent
  # point-wise convolution.
- if isinstance(wrapped_layer.layer, Dense):
- return self._get_pointwise_layer_to_replace_dense(wrapped_layer)
+ if isinstance(layer_to_substitue, KerasQuantizationWrapper):
+ if isinstance(layer_to_substitue.layer, Dense):
+ return self._get_pointwise_layer_to_replace_dense(layer_to_substitue)

- return wrapped_layer
+ return layer_to_substitue

  # Transform the model to a new model that can be converted to int8 models.
  # For example: replace dense layers with point-wise layers (to support per-channel quantization)
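The Dense-to-pointwise substitution rests on a shape identity: a Dense kernel of shape (C_in, C_out) applied to input (B, x0, ..., xn, C_in) computes the same values as a 1x1 Conv2D applied to that input reshaped to (B, 1, x0*...*xn, C_in), which lets the exporter keep per-channel kernel quantization in int8 TFLite. A small worked example of the target_shape computation from the hunk above (hypothetical dimensions):

    import numpy as np

    # Hypothetical Dense input of shape (batch, 5, 7, 20) with a (20, 32) kernel:
    input_shape = (None, 5, 7, 20)
    dense_kernel_shape = (20, 32)

    dim = input_shape[1:-1]  # (5, 7): all dims between batch and channels
    target_shape = (1, int(np.prod(dim))) + (dense_kernel_shape[0],)
    print(target_shape)  # (1, 35, 20): the Conv2D then sees (batch, 1, 35, 20)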
model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py

@@ -21,8 +21,8 @@ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
  from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
  from packaging import version

- from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.constants import LAYER
+ from mct_quantizers import PytorchQuantizationWrapper
+ from mct_quantizers.common.constants import LAYER

  # ONNX opset version 16 is supported from PyTorch 1.12
  if version.parse(torch.__version__) < version.parse("1.12"):
@@ -68,7 +68,7 @@ class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
  Fake-quant PyTorch model.
  """
  for layer in self.model.children():
- assert self.is_layer_exportable_fn(layer), f'Layer {layer.name} is not exportable.'
+ self.is_layer_exportable_fn(layer)

  model_input = to_torch_tensor(next(self.repr_dataset())[0])

model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py

@@ -57,7 +57,7 @@ class FakelyQuantTorchScriptPyTorchExporter(BasePyTorchExporter):
  Fake-quant PyTorch model.
  """
  for layer in self.model.children():
- assert self.is_layer_exportable_fn(layer), f'Layer {layer} is not exportable.'
+ self.is_layer_exportable_fn(layer)

  torch_traced = torch.jit.trace(self.model,
  to_torch_tensor(next(self.repr_dataset())),
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py

@@ -12,36 +12,62 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import Tuple

-
- from model_compression_toolkit import quantizers_infrastructure as qi
+ from typing import Tuple, Callable
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common import Graph
  from model_compression_toolkit.constants import FOUND_TF
  from model_compression_toolkit.core.common.user_info import UserInformation
  from model_compression_toolkit.logger import Logger
+ from mct_quantizers import KerasActivationQuantizationHolder

  if FOUND_TF:
  import tensorflow as tf
  from tensorflow.keras.layers import Layer
  from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
  from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizers import get_quantization_quantizers
+ from mct_quantizers import KerasQuantizationWrapper

  def _get_wrapper(node: common.BaseNode,
- layer: Layer) -> qi.KerasQuantizationWrapper:
+ layer: Layer) -> Layer:
  """
  A function which takes a computational graph node and a keras layer and perform the quantization wrapping
  Args:
- n: A node of mct graph.
+ node: A node of mct graph.
  layer: A keras layer
- include_activation_quantizers: Whether to use the wrapper for the activation quantizer or not

  Returns: Wrapped layer with weights quantizers and activation quantizers

  """
- weights_quantizers, activation_quantizers = get_quantization_quantizers(node)
- return qi.KerasQuantizationWrapper(layer, weights_quantizers, activation_quantizers)
+ weights_quantizers, _ = get_quantization_quantizers(node)
+ if len(weights_quantizers) > 0:
+ return KerasQuantizationWrapper(layer,
+ weights_quantizers)
+ return layer
+
+
+ def get_activation_quantizer_holder(node: common.BaseNode) -> Callable:
+ """
+ Retrieve a ActivationQuantizationHolder layer to use for activation quantization for a node.
+
+ Args:
+ node: Node to get ActivationQuantizationHolder to attach in its output.
+
+ Returns:
+ A ActivationQuantizationHolder layer for the node activation quantization.
+ """
+ _, activation_quantizers = get_quantization_quantizers(node)
+
+ # Holder by definition uses a single quantizer for the activation quantization
+ # thus we make sure this is the only possible case (unless it's a node with no activation
+ # quantization, which in this case has an empty list).
+ if len(activation_quantizers) == 1:
+ return KerasActivationQuantizationHolder(activation_quantizers[0])
+
+ Logger.error(
+ f'ActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
+ f'were found for node {node}')
+


  def get_exportable_keras_model(graph: Graph) -> Tuple[tf.keras.models.Model, UserInformation]:
@@ -57,7 +83,8 @@ if FOUND_TF:
  Exportable Keras model and user information.
  """
  exportable_model, user_info = KerasModelBuilder(graph=graph,
- wrapper=_get_wrapper).build_model()
+ wrapper=_get_wrapper,
+ get_activation_quantizer_holder_fn=get_activation_quantizer_holder).build_model()
  exportable_model.trainable = False
  return exportable_model, user_info
  else:
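With this split, weight quantizers stay inside KerasQuantizationWrapper while each quantized activation gets its own KerasActivationQuantizationHolder layer, and the builder receives the two concerns through separate hooks. A minimal standalone sketch of the same selection logic (the quantizer lists stand in for what get_quantization_quantizers(node) returns):

    from mct_quantizers import KerasQuantizationWrapper, KerasActivationQuantizationHolder

    def wrap_for_export(layer, weights_quantizers: dict, activation_quantizers: list):
        # Mirrors _get_wrapper / get_activation_quantizer_holder above: wrap weights
        # only when there is something to quantize, and build a separate holder for
        # the (single) activation quantizer, if any.
        wrapped = KerasQuantizationWrapper(layer, weights_quantizers) if weights_quantizers else layer
        holder = None
        if len(activation_quantizers) == 1:
            holder = KerasActivationQuantizationHolder(activation_quantizers[0])
        elif len(activation_quantizers) > 1:
            raise ValueError('ActivationQuantizationHolder supports a single quantizer')
        return wrapped, holder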
model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py

@@ -19,10 +19,10 @@ from model_compression_toolkit.constants import THRESHOLD, RANGE_MIN, RANGE_MAX,

  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import QuantizationTarget
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.get_quantizers import get_inferable_quantizer_class
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers.base_keras_inferable_quantizer import BaseKerasInferableQuantizer
- from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.keras.quantizers import constants as qi_keras_consts
+ from mct_quantizers import QuantizationTarget
+ from mct_quantizers.common.get_quantizers import get_inferable_quantizer_class
+ from mct_quantizers.keras.quantizers import BaseKerasInferableQuantizer
+ from mct_quantizers import constants as qi_keras_consts

  def get_inferable_quantizer_kwargs(node: BaseNode,
  quantization_target: QuantizationTarget) -> Dict[str, Any]:
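Beyond the import swap, these exporter modules share one clone-and-retarget pattern for inferable quantizers: read the Keras-style config, adjust it, and re-instantiate the same class (seen in the int8 exporter hunk above). A hedged sketch, where existing_quantizer is assumed to be a BaseKerasInferableQuantizer whose config carries the INPUT_RANK and THRESHOLD keys:

    from mct_quantizers import constants as qi_consts

    def clone_quantizer_for_conv(existing_quantizer):
        # Clone a weights quantizer, retargeted from a Dense kernel to a rank-4 conv kernel.
        cfg = existing_quantizer.get_config()  # Keras-style config dict
        cfg[qi_consts.INPUT_RANK] = 4  # conv kernels are rank-4
        cfg[qi_consts.THRESHOLD] = list(cfg[qi_consts.THRESHOLD])  # constructor expects a list, not an ndarray
        return type(existing_quantizer)(**cfg)  # re-instantiate the same quantizer class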