mct-nightly 2.0.0.20240508.145608__py3-none-any.whl → 2.0.0.20240509.406__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.0.0.20240508.145608
3
+ Version: 2.0.0.20240509.406
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -72,9 +72,6 @@ for hands-on learning. For example:
72
72
  * [Post training quantization with PyTorch](tutorials/notebooks/pytorch/ptq/example_pytorch_quantization_mnist.ipynb)
73
73
  * [Data Generation for ResNet18 with PyTorch](tutorials/notebooks/pytorch/data_generation/example_pytorch_data_generation.ipynb).
74
74
 
75
- Additionally, for quick quantization of a variety of models from well-known collections,
76
- visit the [quick-start page](tutorials/quick_start/README.md) and the
77
- [results CSV](tutorials/quick_start/results/model_quantization_results.csv).
78
75
 
79
76
  ### Supported Versions
80
77
 
@@ -1,5 +1,5 @@
1
- model_compression_toolkit/__init__.py,sha256=BUMAk1BC9peC1TCFqoOqMrOyFT_LMr7kLtuYqQZTp8A,1573
2
- model_compression_toolkit/constants.py,sha256=b63Jk_bC7VXEX3Qn9TZ3wUvrNKD8Mkz8zIuayoyF5eU,3828
1
+ model_compression_toolkit/__init__.py,sha256=ChRp1KQR5GCWN6Py2srdeALrqbdAT6VHfoe2LCWSqJc,1573
2
+ model_compression_toolkit/constants.py,sha256=yIJyJ-e1WrDeKD9kG15qkqfYnoj7J1J2CxnJDt008ik,3756
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
5
5
  model_compression_toolkit/metadata.py,sha256=IyoON37lBv3TI0rZGCP4K5t3oYI4TOmYy-LRXOwHGpE,1136
@@ -145,7 +145,7 @@ model_compression_toolkit/core/common/substitutions/weights_activation_split.py,
145
145
  model_compression_toolkit/core/common/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
146
146
  model_compression_toolkit/core/common/visualization/final_config_visualizer.py,sha256=6I10jKLesB-RQKaXA75Xgz2wPvylQUrnPtCcQZIynGo,6371
147
147
  model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=HOq7AObkmEZiDSZXUMJDAEJzUY-fSXUT0AMgwiyH7dg,7388
148
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256=lkQ5B3YKcojPfNdkPCZ9ViJ0zMOSsWmZ-ELmiBcNcqI,22510
148
+ model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256=4E4ZXZmqusGIJ4XQNH8FFt07htAHgT3gy5E7wPIaVBI,21951
149
149
  model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
150
150
  model_compression_toolkit/core/keras/constants.py,sha256=Uv3c0UdW55pIVQNW_1HQlgl-dHXREkltOLyzp8G1mTQ,3163
151
151
  model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
@@ -266,7 +266,7 @@ model_compression_toolkit/core/pytorch/reader/node_holders.py,sha256=TaolORuwBZE
266
266
  model_compression_toolkit/core/pytorch/reader/reader.py,sha256=GEJE0QX8XJFWbYCkbRBtzttZtmmuoACLx8gw9KyAQCE,6015
267
267
  model_compression_toolkit/core/pytorch/statistics_correction/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
268
268
  model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py,sha256=VgU24J3jf7QComHH7jonOXSkg6mO4TOch3uFkOthZvM,3261
269
- model_compression_toolkit/data_generation/__init__.py,sha256=S8pRUqlRvpM5AFHpFWs3zb0H0rtY5nUwmeCQij01oi4,1507
269
+ model_compression_toolkit/data_generation/__init__.py,sha256=R_RnB8Evj4uq0WKiPWvBWfeePrbake7Z03ugJgK7jLo,1466
270
270
  model_compression_toolkit/data_generation/common/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
271
271
  model_compression_toolkit/data_generation/common/constants.py,sha256=21e3ZX9WVYojexG2acTgklrBk8ZO9DjJnKpP4KHZC44,1018
272
272
  model_compression_toolkit/data_generation/common/data_generation.py,sha256=fccGG6cTMScZwjnJDQKMugOLdgm9dKg5rRfcBD4EFYQ,6415
@@ -292,7 +292,7 @@ model_compression_toolkit/data_generation/pytorch/constants.py,sha256=QWyreMImcf
292
292
  model_compression_toolkit/data_generation/pytorch/image_pipeline.py,sha256=6g7OpOuO3cU4TIuelaRjBKpCPgiMbe1a3iy9bZtdZUo,6617
293
293
  model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py,sha256=sO9tA03nIaeYnzOL4Egec5sVcSGU8H8k9-nNjhaLEbk,9690
294
294
  model_compression_toolkit/data_generation/pytorch/optimization_utils.py,sha256=AjYsO-lm06JOUMoKkS6VbyF4O_l_ffWXrgamqJm1ofE,19085
295
- model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py,sha256=UGX0J0lU1bY4ZI6qE1K0AnFWsDFs3clYPBC4GZf9KxA,21219
295
+ model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py,sha256=rZ-4YAcgEc9qZEs5FrK0OJaNtSsQC57Y61UdbXbQcE4,20937
296
296
  model_compression_toolkit/data_generation/pytorch/optimization_functions/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
297
297
  model_compression_toolkit/data_generation/pytorch/optimization_functions/batchnorm_alignment_functions.py,sha256=dMc4zz9XfYfAT4Cxns57VgvGZWPAMfaGlWLFyCyl8TA,1968
298
298
  model_compression_toolkit/data_generation/pytorch/optimization_functions/bn_layer_weighting_functions.py,sha256=i3ePEI8xDE3xZEtmzT5lCkLn9wpObUi_OgqnVDf7nj8,2597
@@ -315,9 +315,9 @@ model_compression_toolkit/exporter/model_exporter/keras/mctq_keras_exporter.py,s
315
315
  model_compression_toolkit/exporter/model_exporter/pytorch/__init__.py,sha256=uZ2RigbY9O2PJ0Il8wPpS_s7frgg9WUGd_SHeKGyl1A,699
316
316
  model_compression_toolkit/exporter/model_exporter/pytorch/base_pytorch_exporter.py,sha256=UPVkEUQCMZ4Lld6CRnEOPEmlfe5vcQZG0Q3FwRBodD4,4021
317
317
  model_compression_toolkit/exporter/model_exporter/pytorch/export_serialization_format.py,sha256=bPevy6OBqng41PqytBR55e6cBEuyrUS0H8dWX4zgjQ4,967
318
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py,sha256=b-qC60LiRtc52gIXdUbrdTBKUgCIaResDLXFE8zt_F4,6732
318
+ model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py,sha256=r2pOWFK-mSG8OzRiKGVOG4skzX0ZiM0eiRuBsL-ThoI,6067
319
319
  model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_torchscript_pytorch_exporter.py,sha256=ksWV2A-Njo-wAxQ_Ye2sLIZXBWJ_WNyjT7-qFFwvV2o,2897
320
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py,sha256=5G9dikFY4A66XVjpaOWVWX81Qr6ZdwnoyBzFDL_abi8,6242
320
+ model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py,sha256=yz5dPMX5r1d9LJV4rYFS1pXqCbVUxvUmV4LELWcRinQ,6350
321
321
  model_compression_toolkit/exporter/model_wrapper/__init__.py,sha256=7CF2zvpTrIEm8qnbuHnLZyTZkwBBxV24V8QA0oxGbh0,1187
322
322
  model_compression_toolkit/exporter/model_wrapper/fw_agnostic/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
323
323
  model_compression_toolkit/exporter/model_wrapper/fw_agnostic/get_inferable_quantizers.py,sha256=Bd3QhAR__YC9Xmobd5qHv9ofh_rPn_eTFV0sXizcBnY,2297
@@ -483,8 +483,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
483
483
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
484
484
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
485
485
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=MxylaVFPgN7zBiRBy6WV610EA4scLgRJFbMucKvvNDU,2896
486
- mct_nightly-2.0.0.20240508.145608.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
487
- mct_nightly-2.0.0.20240508.145608.dist-info/METADATA,sha256=5XKDpGXcxCMKbL1iHeQRupmzRYmVT1gTB1hAUPPSTJU,18798
488
- mct_nightly-2.0.0.20240508.145608.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
489
- mct_nightly-2.0.0.20240508.145608.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
490
- mct_nightly-2.0.0.20240508.145608.dist-info/RECORD,,
486
+ mct_nightly-2.0.0.20240509.406.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
487
+ mct_nightly-2.0.0.20240509.406.dist-info/METADATA,sha256=CnhcTwwsr7Ks92s0saVvRr0npvEJoGDpeUCCH3OcWfU,18559
488
+ mct_nightly-2.0.0.20240509.406.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
489
+ mct_nightly-2.0.0.20240509.406.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
490
+ mct_nightly-2.0.0.20240509.406.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.0.0.20240508.145608"
30
+ __version__ = "2.0.0.20240509.000406"
@@ -20,7 +20,6 @@ TENSORFLOW = 'tensorflow'
20
20
  PYTORCH = 'pytorch'
21
21
  FOUND_TF = importlib.util.find_spec(TENSORFLOW) is not None
22
22
  FOUND_TORCH = importlib.util.find_spec("torch") is not None
23
- FOUND_TORCHVISION = importlib.util.find_spec("torchvision") is not None
24
23
  FOUND_ONNX = importlib.util.find_spec("onnx") is not None
25
24
  FOUND_ONNXRUNTIME = importlib.util.find_spec("onnxruntime") is not None
26
25
  FOUND_SONY_CUSTOM_LAYERS = importlib.util.find_spec('sony_custom_layers') is not None
@@ -72,19 +72,10 @@ def get_node_properties(node_dict_to_log: dict,
72
72
  # Create protobuf for the node's output shapes
73
73
  if output_shapes is not None:
74
74
  tshape_protos = []
75
- is_tf_combined_non_max_suppression = len(output_shapes) == 1 and 'function' in node_dict_to_log and node_dict_to_log['function'] == 'image.combined_non_max_suppression'
76
-
77
- if is_tf_combined_non_max_suppression:
78
- combined_nms_output = output_shapes[0]
79
- output_shapes = [combined_nms_output.nmsed_boxes,
80
- combined_nms_output.nmsed_scores,
81
- combined_nms_output.nmsed_classes,
82
- combined_nms_output.valid_detections]
83
-
84
75
  for output_shape in output_shapes: # create protobuf for each output shape
85
76
  proto_dims_list = []
86
77
  for dim in output_shape:
87
- proto_dims_list.append(TensorShapeProto.Dim(size=dim)) # dim shold ne an integer
78
+ proto_dims_list.append(TensorShapeProto.Dim(size=dim))
88
79
  tshape_proto = TensorShapeProto(dim=proto_dims_list)
89
80
  tshape_protos.append(tshape_proto)
90
81
  node_properties['_output_shapes'] = AttrValue(list=AttrValue.ListValue(shape=tshape_protos))
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.constants import FOUND_TORCH, FOUND_TF, FOUND_TORCHVISION
16
+ from model_compression_toolkit.constants import FOUND_TORCH, FOUND_TF
17
17
  from model_compression_toolkit.data_generation.common.data_generation_config import DataGenerationConfig
18
18
  from model_compression_toolkit.data_generation.common.enums import ImageGranularity, DataInitType, SchedulerType, BNLayerWeightingType, OutputLossType, BatchNormAlignemntLossType, ImagePipelineType, ImageNormalizationType
19
19
 
@@ -21,6 +21,6 @@ if FOUND_TF:
21
21
  from model_compression_toolkit.data_generation.keras.keras_data_generation import (
22
22
  keras_data_generation_experimental, get_keras_data_generation_config)
23
23
 
24
- if FOUND_TORCH and FOUND_TORCHVISION:
24
+ if FOUND_TORCH:
25
25
  from model_compression_toolkit.data_generation.pytorch.pytorch_data_generation import (
26
26
  pytorch_data_generation_experimental, get_pytorch_data_generation_config)
@@ -17,7 +17,7 @@ from typing import Callable, Any, Tuple, List
17
17
 
18
18
  from tqdm import tqdm
19
19
 
20
- from model_compression_toolkit.constants import FOUND_TORCH, FOUND_TORCHVISION
20
+ from model_compression_toolkit.constants import FOUND_TORCH
21
21
  from model_compression_toolkit.core.pytorch.utils import set_model
22
22
  from model_compression_toolkit.data_generation.common.constants import DEFAULT_N_ITER, DEFAULT_DATA_GEN_BS
23
23
  from model_compression_toolkit.data_generation.common.data_generation import get_data_generation_classes
@@ -44,7 +44,7 @@ from model_compression_toolkit.data_generation.pytorch.optimization_functions.sc
44
44
  from model_compression_toolkit.data_generation.pytorch.optimization_utils import PytorchImagesOptimizationHandler
45
45
  from model_compression_toolkit.logger import Logger
46
46
 
47
- if FOUND_TORCH and FOUND_TORCHVISION:
47
+ if FOUND_TORCH:
48
48
  # Importing necessary libraries
49
49
  import torch
50
50
  from torch import Tensor
@@ -354,9 +354,10 @@ else:
354
354
  # If torch is not installed,
355
355
  # we raise an exception when trying to use these functions.
356
356
  def get_pytorch_data_generation_config(*args, **kwargs):
357
- msg = f"torch and torchvision must be installed to use get_pytorch_data_generation_config. " + ("" if FOUND_TORCH else "'torch' package is missing. ") + ("" if FOUND_TORCHVISION else "'torchvision' package is missing. ") # pragma: no cover
358
- Logger.critical(msg) # pragma: no cover
357
+ Logger.critical('PyTorch must be installed to use get_pytorch_data_generation_config. '
358
+ "The 'torch' package is missing.") # pragma: no cover
359
+
359
360
 
360
361
  def pytorch_data_generation_experimental(*args, **kwargs):
361
- msg = f"torch and torchvision must be installed to use pytorch_data_generation_experimental. " + ("" if FOUND_TORCH else "'torch' package is missing. ") + ("" if FOUND_TORCHVISION else "'torchvision' package is missing. ") # pragma: no cover
362
- Logger.critical(msg) # pragma: no cover
362
+ Logger.critical("PyTorch must be installed to use 'pytorch_data_generation_experimental'. "
363
+ "The 'torch' package is missing.") # pragma: no cover
@@ -16,124 +16,117 @@ from typing import Callable
16
16
  from io import BytesIO
17
17
 
18
18
  import torch.nn
19
+ import onnx
19
20
 
20
21
  from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
21
- from model_compression_toolkit.constants import FOUND_ONNX
22
22
  from model_compression_toolkit.logger import Logger
23
23
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
24
24
  from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
25
25
  from mct_quantizers import pytorch_quantizers
26
+ from mct_quantizers.pytorch.metadata import add_onnx_metadata
27
+
28
+ DEFAULT_ONNX_OPSET_VERSION=15
29
+
30
+
31
+ class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
32
+ """
33
+ Exporter for fakely-quant PyTorch models.
34
+ The exporter expects to receive an exportable model (where each layer's full quantization parameters
35
+ can be retrieved), and convert it into a fakely-quant model (namely, weights that are in fake-quant
36
+ format) and fake-quant layers for the activations.
37
+ """
38
+
39
+ def __init__(self,
40
+ model: torch.nn.Module,
41
+ is_layer_exportable_fn: Callable,
42
+ save_model_path: str,
43
+ repr_dataset: Callable,
44
+ use_onnx_custom_quantizer_ops: bool = False,
45
+ onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION):
46
+ """
47
+
48
+ Args:
49
+ model: Model to export.
50
+ is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
51
+ save_model_path: Path to save the exported model.
52
+ repr_dataset: Representative dataset (needed for creating torch script).
53
+ use_onnx_custom_quantizer_ops: Whether to export quantizers custom ops in ONNX or not.
54
+ onnx_opset_version: ONNX opset version to use for exported ONNX model.
55
+ """
26
56
 
57
+ super().__init__(model,
58
+ is_layer_exportable_fn,
59
+ save_model_path,
60
+ repr_dataset)
27
61
 
28
- if FOUND_ONNX:
29
- import onnx
30
- from mct_quantizers.pytorch.metadata import add_onnx_metadata
62
+ self._use_onnx_custom_quantizer_ops = use_onnx_custom_quantizer_ops
63
+ self._onnx_opset_version = onnx_opset_version
31
64
 
32
- class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
65
+ def export(self) -> None:
33
66
  """
34
- Exporter for fakely-quant PyTorch models.
35
- The exporter expects to receive an exportable model (where each layer's full quantization parameters
36
- can be retrieved), and convert it into a fakely-quant model (namely, weights that are in fake-quant
37
- format) and fake-quant layers for the activations.
67
+ Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model
68
+ (namely, weights that are in fake-quant format) and fake-quant layers for the activations.
69
+
70
+ Returns:
71
+ Fake-quant PyTorch model.
38
72
  """
73
+ for layer in self.model.children():
74
+ self.is_layer_exportable_fn(layer)
75
+
76
+ # Set forward that is used during onnx export.
77
+ # If _use_onnx_custom_quantizer_ops is set to True, the quantizer forward function will use
78
+ # the custom implementation when exporting the operator into onnx model. If not, it removes the
79
+ # wraps and quantizes the ops in place (for weights, for activation torch quantization function is
80
+ # exported since it's used during forward).
81
+ if self._use_onnx_custom_quantizer_ops:
82
+ self._enable_onnx_custom_ops_export()
83
+ else:
84
+ self._substitute_fully_quantized_model()
85
+
86
+ if self._use_onnx_custom_quantizer_ops:
87
+ Logger.info(f"Exporting onnx model with MCTQ quantizers: {self.save_model_path}")
88
+ else:
89
+ Logger.info(f"Exporting fake-quant onnx model: {self.save_model_path}")
90
+
91
+ model_input = to_torch_tensor(next(self.repr_dataset())[0])
92
+
93
+ if hasattr(self.model, 'metadata'):
94
+ onnx_bytes = BytesIO()
95
+ torch.onnx.export(self.model,
96
+ model_input,
97
+ onnx_bytes,
98
+ opset_version=self._onnx_opset_version,
99
+ verbose=False,
100
+ input_names=['input'],
101
+ output_names=['output'],
102
+ dynamic_axes={'input': {0: 'batch_size'},
103
+ 'output': {0: 'batch_size'}})
104
+ onnx_model = onnx.load_from_string(onnx_bytes.getvalue())
105
+ onnx_model = add_onnx_metadata(onnx_model, self.model.metadata)
106
+ onnx.save_model(onnx_model, self.save_model_path)
107
+ else:
108
+ torch.onnx.export(self.model,
109
+ model_input,
110
+ self.save_model_path,
111
+ opset_version=self._onnx_opset_version,
112
+ verbose=False,
113
+ input_names=['input'],
114
+ output_names=['output'],
115
+ dynamic_axes={'input': {0: 'batch_size'},
116
+ 'output': {0: 'batch_size'}})
117
+
118
+ def _enable_onnx_custom_ops_export(self):
119
+ """
120
+ Enable the custom implementation forward in quantizers, so it is exported
121
+ with custom quantizers.
122
+ """
123
+
124
+ for n, m in self.model.named_modules():
125
+ if isinstance(m, PytorchActivationQuantizationHolder):
126
+ assert isinstance(m.activation_holder_quantizer, pytorch_quantizers.BasePyTorchInferableQuantizer)
127
+ m.activation_holder_quantizer.enable_custom_impl()
39
128
 
40
- def __init__(self,
41
- model: torch.nn.Module,
42
- is_layer_exportable_fn: Callable,
43
- save_model_path: str,
44
- repr_dataset: Callable,
45
- onnx_opset_version: int,
46
- use_onnx_custom_quantizer_ops: bool = False,):
47
- """
48
-
49
- Args:
50
- model: Model to export.
51
- is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
52
- save_model_path: Path to save the exported model.
53
- repr_dataset: Representative dataset (needed for creating torch script).
54
- onnx_opset_version: ONNX opset version to use for exported ONNX model.
55
- use_onnx_custom_quantizer_ops: Whether to export quantizers custom ops in ONNX or not.
56
- """
57
-
58
- super().__init__(model,
59
- is_layer_exportable_fn,
60
- save_model_path,
61
- repr_dataset)
62
-
63
- self._use_onnx_custom_quantizer_ops = use_onnx_custom_quantizer_ops
64
- self._onnx_opset_version = onnx_opset_version
65
-
66
- def export(self) -> None:
67
- """
68
- Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model
69
- (namely, weights that are in fake-quant format) and fake-quant layers for the activations.
70
-
71
- Returns:
72
- Fake-quant PyTorch model.
73
- """
74
- for layer in self.model.children():
75
- self.is_layer_exportable_fn(layer)
76
-
77
- # Set forward that is used during onnx export.
78
- # If _use_onnx_custom_quantizer_ops is set to True, the quantizer forward function will use
79
- # the custom implementation when exporting the operator into onnx model. If not, it removes the
80
- # wraps and quantizes the ops in place (for weights, for activation torch quantization function is
81
- # exported since it's used during forward).
82
- if self._use_onnx_custom_quantizer_ops:
83
- self._enable_onnx_custom_ops_export()
84
- else:
85
- self._substitute_fully_quantized_model()
86
-
87
- if self._use_onnx_custom_quantizer_ops:
88
- Logger.info(f"Exporting onnx model with MCTQ quantizers: {self.save_model_path}")
89
- else:
90
- Logger.info(f"Exporting fake-quant onnx model: {self.save_model_path}")
91
-
92
- model_input = to_torch_tensor(next(self.repr_dataset())[0])
93
-
94
- if hasattr(self.model, 'metadata'):
95
- onnx_bytes = BytesIO()
96
- torch.onnx.export(self.model,
97
- model_input,
98
- onnx_bytes,
99
- opset_version=self._onnx_opset_version,
100
- verbose=False,
101
- input_names=['input'],
102
- output_names=['output'],
103
- dynamic_axes={'input': {0: 'batch_size'},
104
- 'output': {0: 'batch_size'}})
105
- onnx_model = onnx.load_from_string(onnx_bytes.getvalue())
106
- onnx_model = add_onnx_metadata(onnx_model, self.model.metadata)
107
- onnx.save_model(onnx_model, self.save_model_path)
108
- else:
109
- torch.onnx.export(self.model,
110
- model_input,
111
- self.save_model_path,
112
- opset_version=self._onnx_opset_version,
113
- verbose=False,
114
- input_names=['input'],
115
- output_names=['output'],
116
- dynamic_axes={'input': {0: 'batch_size'},
117
- 'output': {0: 'batch_size'}})
118
-
119
- def _enable_onnx_custom_ops_export(self):
120
- """
121
- Enable the custom implementation forward in quantizers, so it is exported
122
- with custom quantizers.
123
- """
124
-
125
- for n, m in self.model.named_modules():
126
- if isinstance(m, PytorchActivationQuantizationHolder):
127
- assert isinstance(m.activation_holder_quantizer, pytorch_quantizers.BasePyTorchInferableQuantizer)
128
- m.activation_holder_quantizer.enable_custom_impl()
129
-
130
- if isinstance(m, PytorchQuantizationWrapper):
131
- for wq in m.weights_quantizers.values():
132
- assert isinstance(wq, pytorch_quantizers.BasePyTorchInferableQuantizer)
133
- wq.enable_custom_impl()
134
-
135
- else:
136
- def FakelyQuantONNXPyTorchExporter(*args, **kwargs):
137
- Logger.critical('Installing onnx is mandatory '
138
- 'when using FakelyQuantONNXPyTorchExporter. '
139
- 'Could not find onnx package.') # pragma: no cover
129
+ if isinstance(m, PytorchQuantizationWrapper):
130
+ for wq in m.weights_quantizers.values():
131
+ assert isinstance(wq, pytorch_quantizers.BasePyTorchInferableQuantizer)
132
+ wq.enable_custom_impl()
@@ -14,20 +14,20 @@
14
14
  # ==============================================================================
15
15
  from typing import Callable
16
16
 
17
- from model_compression_toolkit.constants import FOUND_TORCH, FOUND_ONNX
17
+ from model_compression_toolkit.constants import FOUND_TORCH
18
18
  from model_compression_toolkit.exporter.model_exporter.fw_agonstic.quantization_format import QuantizationFormat
19
19
  from model_compression_toolkit.exporter.model_exporter.pytorch.export_serialization_format import \
20
20
  PytorchExportSerializationFormat
21
21
  from model_compression_toolkit.logger import Logger
22
-
23
-
24
- DEFAULT_ONNX_OPSET_VERSION = 15
22
+ from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
25
23
 
26
24
 
27
25
  if FOUND_TORCH:
28
26
  import torch.nn
29
- from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import FakelyQuantONNXPyTorchExporter
30
- from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import FakelyQuantTorchScriptPyTorchExporter
27
+ from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import \
28
+ FakelyQuantONNXPyTorchExporter, DEFAULT_ONNX_OPSET_VERSION
29
+ from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import \
30
+ FakelyQuantTorchScriptPyTorchExporter
31
31
  from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
32
32
 
33
33
  supported_serialization_quantization_export_dict = {