mct-nightly 2.2.0.20250105.534__py3-none-any.whl → 2.2.0.20250107.15510__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. {mct_nightly-2.2.0.20250105.534.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20250105.534.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/RECORD +43 -78
  3. {mct_nightly-2.2.0.20250105.534.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +1 -1
  5. model_compression_toolkit/core/__init__.py +1 -1
  6. model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py +1 -1
  7. model_compression_toolkit/core/common/graph/memory_graph/cut.py +5 -2
  8. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +25 -25
  9. model_compression_toolkit/core/common/quantization/quantization_config.py +19 -1
  10. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -33
  11. model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py +2 -2
  12. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +11 -1
  13. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/matmul_decomposition.py +499 -0
  14. model_compression_toolkit/core/pytorch/pytorch_implementation.py +3 -0
  15. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +11 -3
  16. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -1
  17. model_compression_toolkit/gptq/pytorch/quantization_facade.py +10 -1
  18. model_compression_toolkit/pruning/keras/pruning_facade.py +8 -2
  19. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -2
  20. model_compression_toolkit/ptq/keras/quantization_facade.py +10 -1
  21. model_compression_toolkit/ptq/pytorch/quantization_facade.py +9 -1
  22. model_compression_toolkit/qat/__init__.py +5 -2
  23. model_compression_toolkit/qat/keras/quantization_facade.py +9 -1
  24. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -1
  25. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +1 -1
  26. model_compression_toolkit/target_platform_capabilities/schema/v1.py +63 -55
  27. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py +29 -18
  28. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py +78 -57
  29. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py +69 -54
  30. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  31. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +0 -10
  32. model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +93 -0
  33. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +46 -28
  34. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +6 -5
  35. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +51 -19
  36. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +8 -4
  37. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +19 -9
  38. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +7 -4
  39. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +46 -32
  40. model_compression_toolkit/xquant/keras/keras_report_utils.py +11 -3
  41. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -2
  42. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +0 -98
  43. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_keras.py +0 -129
  44. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_pytorch.py +0 -108
  45. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/__init__.py +0 -16
  46. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +0 -217
  47. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_keras.py +0 -130
  48. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_pytorch.py +0 -109
  49. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/__init__.py +0 -16
  50. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +0 -215
  51. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_keras.py +0 -130
  52. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_pytorch.py +0 -110
  53. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/__init__.py +0 -16
  54. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +0 -222
  55. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +0 -132
  56. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_pytorch.py +0 -110
  57. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/__init__.py +0 -16
  58. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +0 -219
  59. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +0 -132
  60. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py +0 -109
  61. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/__init__.py +0 -16
  62. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +0 -246
  63. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +0 -135
  64. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +0 -113
  65. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/__init__.py +0 -16
  66. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +0 -230
  67. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +0 -132
  68. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +0 -110
  69. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py +0 -16
  70. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +0 -332
  71. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py +0 -140
  72. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +0 -122
  73. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +0 -55
  74. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_keras.py +0 -89
  75. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +0 -78
  76. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +0 -55
  77. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_keras.py +0 -118
  78. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_pytorch.py +0 -100
  79. {mct_nightly-2.2.0.20250105.534.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/LICENSE.md +0 -0
  80. {mct_nightly-2.2.0.20250105.534.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/top_level.txt +0 -0
@@ -17,12 +17,12 @@ from typing import Callable, Tuple
 
  from model_compression_toolkit import get_target_platform_capabilities
  from model_compression_toolkit.constants import TENSORFLOW
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
  from model_compression_toolkit.verify_packages import FOUND_TF
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
  from model_compression_toolkit.core.common.pruning.pruner import Pruner
  from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
  from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
- from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
  from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
  from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
  from model_compression_toolkit.logger import Logger
@@ -35,6 +35,8 @@ if FOUND_TF:
  from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
  from tensorflow.keras.models import Model
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
+ AttachTpcToKeras
 
  DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
@@ -42,7 +44,7 @@ if FOUND_TF:
  target_resource_utilization: ResourceUtilization,
  representative_data_gen: Callable,
  pruning_config: PruningConfig = PruningConfig(),
- target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
+ target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
  """
  Perform structured pruning on a Keras model to meet a specified target resource utilization.
  This function prunes the provided model according to the target resource utilization by grouping and pruning
@@ -111,6 +113,10 @@ if FOUND_TF:
  # Instantiate the Keras framework implementation.
  fw_impl = PruningKerasImplementation()
 
+ # Attach tpc model to framework
+ attach2keras = AttachTpcToKeras()
+ target_platform_capabilities = attach2keras.attach(target_platform_capabilities)
+
  # Convert the original Keras model to an internal graph representation.
  float_graph = read_model_to_graph(model,
  representative_data_gen,
@@ -16,12 +16,12 @@
  from typing import Callable, Tuple
  from model_compression_toolkit import get_target_platform_capabilities
  from model_compression_toolkit.constants import PYTORCH
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
  from model_compression_toolkit.verify_packages import FOUND_TORCH
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
  from model_compression_toolkit.core.common.pruning.pruner import Pruner
  from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
  from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
- from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
  from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
  from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
  from model_compression_toolkit.logger import Logger
@@ -38,6 +38,8 @@ if FOUND_TORCH:
  PruningPytorchImplementation
  from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
  from torch.nn import Module
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
+ AttachTpcToPytorch
 
  # Set the default Target Platform Capabilities (TPC) for PyTorch.
  DEFAULT_PYOTRCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
@@ -46,7 +48,7 @@ if FOUND_TORCH:
  target_resource_utilization: ResourceUtilization,
  representative_data_gen: Callable,
  pruning_config: PruningConfig = PruningConfig(),
- target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYOTRCH_TPC) -> \
+ target_platform_capabilities: TargetPlatformModel = DEFAULT_PYOTRCH_TPC) -> \
  Tuple[Module, PruningInfo]:
  """
  Perform structured pruning on a Pytorch model to meet a specified target resource utilization.
@@ -117,6 +119,10 @@ if FOUND_TORCH:
  # Instantiate the Pytorch framework implementation.
  fw_impl = PruningPytorchImplementation()
 
+ # Attach TPC to framework
+ attach2pytorch = AttachTpcToPytorch()
+ target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities)
+
  # Convert the original Pytorch model to an internal graph representation.
  float_graph = read_model_to_graph(model,
  representative_data_gen,
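In both pruning facades the `target_platform_capabilities` argument changes type from the framework-bound `TargetPlatformCapabilities` to the framework-agnostic `TargetPlatformModel`, and the conversion now happens inside the facade through the new attach classes. A minimal sketch of that internal step, using the import paths shown in the hunks above (illustrative only, not a documented public API; "imx500" is assumed here to match `DEFAULT_TP_MODEL`):

```python
from model_compression_toolkit import get_target_platform_capabilities
from model_compression_toolkit.constants import PYTORCH
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
    AttachTpcToPytorch

# Framework-agnostic platform description; "imx500" is assumed to match DEFAULT_TP_MODEL.
tp_model = get_target_platform_capabilities(PYTORCH, "imx500")

# What the pruning facades now do internally before converting the model to a graph:
fw_tpc = AttachTpcToPytorch().attach(tp_model)
```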
@@ -22,6 +22,7 @@ from model_compression_toolkit.core.common.quantization.quantize_graph_weights i
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.constants import TENSORFLOW
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
  from model_compression_toolkit.verify_packages import FOUND_TF
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
@@ -41,6 +42,9 @@ if FOUND_TF:
 
  from model_compression_toolkit import get_target_platform_capabilities
  from mct_quantizers.keras.metadata import add_metadata
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
+ AttachTpcToKeras
+
  DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
 
@@ -48,7 +52,7 @@ if FOUND_TF:
  representative_data_gen: Callable,
  target_resource_utilization: ResourceUtilization = None,
  core_config: CoreConfig = CoreConfig(),
- target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
+ target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC):
  """
  Quantize a trained Keras model using post-training quantization. The model is quantized using a
  symmetric constraint quantization thresholds (power of two).
@@ -134,6 +138,11 @@ if FOUND_TF:
 
  fw_impl = KerasImplementation()
 
+ attach2keras = AttachTpcToKeras()
+ target_platform_capabilities = attach2keras.attach(
+ target_platform_capabilities,
+ custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
+
  # Ignore returned hessian service as PTQ does not use it
  tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_model,
  representative_data_gen=representative_data_gen,
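For callers, the visible change in the PTQ facade is the expected type of `target_platform_capabilities`: it now takes the `TargetPlatformModel` returned by `get_target_platform_capabilities` and attaches it to Keras internally, including any custom opset mapping carried in `core_config.quantization_config.custom_tpc_opset_to_layer`. A hedged usage sketch with a toy model, assuming `mct.ptq.keras_post_training_quantization` as the public entry point and "tensorflow"/"imx500" as the default platform names:

```python
import numpy as np
import tensorflow as tf
import model_compression_toolkit as mct

# Tiny stand-in model and calibration generator, only to make the call shape concrete.
float_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(4, 3, input_shape=(16, 16, 3)),
    tf.keras.layers.ReLU()])

def representative_data_gen():
    yield [np.random.randn(1, 16, 16, 3).astype(np.float32)]

# Now returns a framework-agnostic TargetPlatformModel rather than a Keras-bound TPC.
tpc = mct.get_target_platform_capabilities("tensorflow", "imx500")

quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(
    in_model=float_model,
    representative_data_gen=representative_data_gen,
    target_platform_capabilities=tpc)
```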
@@ -19,6 +19,7 @@ from typing import Callable
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.constants import PYTORCH
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
  from model_compression_toolkit.verify_packages import FOUND_TORCH
  from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
@@ -39,6 +40,8 @@ if FOUND_TORCH:
  from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
  from model_compression_toolkit import get_target_platform_capabilities
  from mct_quantizers.pytorch.metadata import add_metadata
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
+ AttachTpcToPytorch
 
  DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
 
@@ -46,7 +49,7 @@ if FOUND_TORCH:
  representative_data_gen: Callable,
  target_resource_utilization: ResourceUtilization = None,
  core_config: CoreConfig = CoreConfig(),
- target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
+ target_platform_capabilities: TargetPlatformModel = DEFAULT_PYTORCH_TPC):
  """
  Quantize a trained Pytorch module using post-training quantization.
  By default, the module is quantized using a symmetric constraint quantization thresholds
@@ -107,6 +110,11 @@ if FOUND_TORCH:
 
  fw_impl = PytorchImplementation()
 
+ # Attach tpc model to framework
+ attach2pytorch = AttachTpcToPytorch()
+ target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
+ core_config.quantization_config.custom_tpc_opset_to_layer)
+
  # Ignore hessian info service as it is not used here yet.
  tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_module,
  representative_data_gen=representative_data_gen,
@@ -13,6 +13,9 @@
  # limitations under the License.
  # ==============================================================================
  from model_compression_toolkit.qat.common.qat_config import QATConfig
+ from model_compression_toolkit.verify_packages import FOUND_TF, FOUND_TORCH
 
- from model_compression_toolkit.qat.keras.quantization_facade import keras_quantization_aware_training_init_experimental, keras_quantization_aware_training_finalize_experimental
- from model_compression_toolkit.qat.pytorch.quantization_facade import pytorch_quantization_aware_training_init_experimental, pytorch_quantization_aware_training_finalize_experimental
+ if FOUND_TF:
+ from model_compression_toolkit.qat.keras.quantization_facade import keras_quantization_aware_training_init_experimental, keras_quantization_aware_training_finalize_experimental
+ if FOUND_TORCH:
+ from model_compression_toolkit.qat.pytorch.quantization_facade import pytorch_quantization_aware_training_init_experimental, pytorch_quantization_aware_training_finalize_experimental
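The QAT package now guards its facade imports on the detected frameworks, so importing it with only TensorFlow or only PyTorch installed should no longer fail at import time. A small, hedged check of the resulting behavior:

```python
import model_compression_toolkit as mct

# Only the facades of the frameworks that are actually installed are expected to be exposed.
for name in ("keras_quantization_aware_training_init_experimental",
             "pytorch_quantization_aware_training_init_experimental"):
    print(name, "->", hasattr(mct.qat, name))
```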
@@ -19,6 +19,7 @@ from functools import partial
  from model_compression_toolkit.core import CoreConfig
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
  from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
  from model_compression_toolkit.verify_packages import FOUND_TF
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
@@ -54,6 +55,8 @@ if FOUND_TF:
  from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder, \
  get_activation_quantizer_holder
  from model_compression_toolkit.qat.common.qat_config import QATConfig
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
+ AttachTpcToKeras
 
  DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
@@ -90,7 +93,7 @@ if FOUND_TF:
  target_resource_utilization: ResourceUtilization = None,
  core_config: CoreConfig = CoreConfig(),
  qat_config: QATConfig = QATConfig(),
- target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
+ target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC):
  """
  Prepare a trained Keras model for quantization aware training. First the model quantization is optimized
  with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
@@ -186,6 +189,11 @@ if FOUND_TF:
 
  fw_impl = KerasImplementation()
 
+ attach2keras = AttachTpcToKeras()
+ target_platform_capabilities = attach2keras.attach(
+ target_platform_capabilities,
+ custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
+
  # Ignore hessian service since is not used in QAT at the moment
  tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
  representative_data_gen=representative_data_gen,
@@ -17,6 +17,9 @@ from typing import Callable
  from functools import partial
 
  from model_compression_toolkit.constants import PYTORCH
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
+ from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
+ AttachTpcToPytorch
  from model_compression_toolkit.verify_packages import FOUND_TORCH
 
  from model_compression_toolkit.core import CoreConfig
@@ -79,7 +82,7 @@ if FOUND_TORCH:
  target_resource_utilization: ResourceUtilization = None,
  core_config: CoreConfig = CoreConfig(),
  qat_config: QATConfig = QATConfig(),
- target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
+ target_platform_capabilities: TargetPlatformModel = DEFAULT_PYTORCH_TPC):
  """
  Prepare a trained Pytorch model for quantization aware training. First the model quantization is optimized
  with post-training quantization, then the model layers are wrapped with QuantizeWrappers. The model is
@@ -154,6 +157,11 @@ if FOUND_TORCH:
  tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
  fw_impl = PytorchImplementation()
 
+ # Attach tpc model to framework
+ attach2pytorch = AttachTpcToPytorch()
+ target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
+ core_config.quantization_config.custom_tpc_opset_to_layer)
+
  # Ignore hessian scores service as we do not use it here
  tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
  representative_data_gen=representative_data_gen,
@@ -7,6 +7,6 @@ OpQuantizationConfig = schema.OpQuantizationConfig
  QuantizationConfigOptions = schema.QuantizationConfigOptions
  OperatorsSetBase = schema.OperatorsSetBase
  OperatorsSet = schema.OperatorsSet
- OperatorSetConcat= schema.OperatorSetConcat
+ OperatorSetConcat = schema.OperatorSetConcat
  Fusing = schema.Fusing
  TargetPlatformModel = schema.TargetPlatformModel
@@ -13,66 +13,74 @@
  # limitations under the License.
  # ==============================================================================
  import pprint
-
  from enum import Enum
  from typing import Dict, Any, Union, Tuple, List, Optional, Literal, Annotated
+
+ from pydantic import BaseModel, Field, root_validator, validator, PositiveInt
+
  from mct_quantizers import QuantizationMethod
  from model_compression_toolkit.constants import FLOAT_BITWIDTH
  from model_compression_toolkit.logger import Logger
- from pydantic import BaseModel, Field, root_validator, validator, PositiveInt, PrivateAttr
-
-
- class OperatorSetNames(Enum):
- OPSET_CONV = "Conv"
- OPSET_DEPTHWISE_CONV = "DepthwiseConv2D"
- OPSET_CONV_TRANSPOSE = "ConvTranspose"
- OPSET_FULLY_CONNECTED = "FullyConnected"
- OPSET_CONCATENATE = "Concatenate"
- OPSET_STACK = "Stack"
- OPSET_UNSTACK = "Unstack"
- OPSET_GATHER = "Gather"
- OPSET_EXPAND = "Expend"
- OPSET_BATCH_NORM = "BatchNorm"
- OPSET_RELU = "ReLU"
- OPSET_RELU6 = "ReLU6"
- OPSET_LEAKY_RELU = "LEAKYReLU"
- OPSET_HARD_TANH = "HardTanh"
- OPSET_ADD = "Add"
- OPSET_SUB = "Sub"
- OPSET_MUL = "Mul"
- OPSET_DIV = "Div"
- OPSET_MIN = "Min"
- OPSET_MAX = "Max"
- OPSET_PRELU = "PReLU"
- OPSET_SWISH = "Swish"
- OPSET_SIGMOID = "Sigmoid"
- OPSET_TANH = "Tanh"
- OPSET_GELU = "Gelu"
- OPSET_HARDSIGMOID = "HardSigmoid"
- OPSET_HARDSWISH = "HardSwish"
- OPSET_FLATTEN = "Flatten"
- OPSET_GET_ITEM = "GetItem"
- OPSET_RESHAPE = "Reshape"
- OPSET_UNSQUEEZE = "Unsqueeze"
- OPSET_SQUEEZE = "Squeeze"
- OPSET_PERMUTE = "Permute"
- OPSET_TRANSPOSE = "Transpose"
- OPSET_DROPOUT = "Dropout"
- OPSET_SPLIT = "Split"
- OPSET_CHUNK = "Chunk"
- OPSET_MAXPOOL = "MaxPool"
- OPSET_SIZE = "Size"
- OPSET_SHAPE = "Shape"
- OPSET_EQUAL = "Equal"
- OPSET_ARGMAX = "ArgMax"
- OPSET_TOPK = "TopK"
- OPSET_FAKE_QUANT_WITH_MIN_MAX_VARS = "FakeQuantWithMinMaxVars"
- OPSET_COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression"
- OPSET_CROPPING2D = "Cropping2D"
- OPSET_ZERO_PADDING2d = "ZeroPadding2D"
- OPSET_CAST = "Cast"
- OPSET_STRIDED_SLICE = "StridedSlice"
- OPSET_SSD_POST_PROCESS = "SSDPostProcess"
+
+
+ class OperatorSetNames(str, Enum):
+ CONV = "Conv"
+ DEPTHWISE_CONV = "DepthwiseConv2D"
+ CONV_TRANSPOSE = "ConvTranspose"
+ FULLY_CONNECTED = "FullyConnected"
+ CONCATENATE = "Concatenate"
+ STACK = "Stack"
+ UNSTACK = "Unstack"
+ GATHER = "Gather"
+ EXPAND = "Expend"
+ BATCH_NORM = "BatchNorm"
+ L2NORM = "L2Norm"
+ RELU = "ReLU"
+ RELU6 = "ReLU6"
+ LEAKY_RELU = "LeakyReLU"
+ ELU = "Elu"
+ HARD_TANH = "HardTanh"
+ ADD = "Add"
+ SUB = "Sub"
+ MUL = "Mul"
+ DIV = "Div"
+ MIN = "Min"
+ MAX = "Max"
+ PRELU = "PReLU"
+ ADD_BIAS = "AddBias"
+ SWISH = "Swish"
+ SIGMOID = "Sigmoid"
+ SOFTMAX = "Softmax"
+ LOG_SOFTMAX = "LogSoftmax"
+ TANH = "Tanh"
+ GELU = "Gelu"
+ HARDSIGMOID = "HardSigmoid"
+ HARDSWISH = "HardSwish"
+ FLATTEN = "Flatten"
+ GET_ITEM = "GetItem"
+ RESHAPE = "Reshape"
+ UNSQUEEZE = "Unsqueeze"
+ SQUEEZE = "Squeeze"
+ PERMUTE = "Permute"
+ TRANSPOSE = "Transpose"
+ DROPOUT = "Dropout"
+ SPLIT_CHUNK = "SplitChunk"
+ MAXPOOL = "MaxPool"
+ AVGPOOL = "AvgPool"
+ SIZE = "Size"
+ SHAPE = "Shape"
+ EQUAL = "Equal"
+ ARGMAX = "ArgMax"
+ TOPK = "TopK"
+ FAKE_QUANT = "FakeQuant"
+ COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression"
+ ZERO_PADDING2D = "ZeroPadding2D"
+ CAST = "Cast"
+ RESIZE = "Resize"
+ PAD = "Pad"
+ FOLD = "Fold"
+ STRIDED_SLICE = "StridedSlice"
+ SSD_POST_PROCESS = "SSDPostProcess"
 
  @classmethod
  def get_values(cls):
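Two things change in the operator-set enum: the `OPSET_` prefix is dropped from the member names, and the class now mixes in `str`, so members compare equal to their raw string values. That is what allows the framework mapping tables further down to key on the members themselves instead of on `.value`. A minimal sketch of the behavior:

```python
from enum import Enum

class OperatorSetNames(str, Enum):  # trimmed to two members for illustration
    CONV = "Conv"
    RELU = "ReLU"

assert OperatorSetNames.CONV == "Conv"                    # value equality via the str mixin
assert OperatorSetNames("ReLU") is OperatorSetNames.RELU  # lookup by value still works
assert OperatorSetNames.CONV.value == "Conv"              # .value is unchanged, just no longer needed as a key
```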
@@ -1,12 +1,15 @@
- from typing import Dict, Tuple, List, Any, Optional
+ from typing import Dict, Optional
 
- from model_compression_toolkit import DefaultDict
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
+ from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, \
+ OperatorsSet
  from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, \
  OperationsSetToLayers
 
+ from model_compression_toolkit.core.common.quantization.quantization_config import CustomOpsetLayers
 
- class AttachTpModelToFw:
+
+ class AttachTpcToFramework:
 
  def __init__(self):
  self._opset2layer = None
@@ -17,7 +20,7 @@ class AttachTpModelToFw:
  self._opset2attr_mapping = None # Mapping of operation sets to their corresponding framework-specific layers
 
  def attach(self, tpc_model: TargetPlatformModel,
- custom_opset2layer: Dict[str, Tuple[List[Any], Optional[Dict[str, DefaultDict]]]] = None
+ custom_opset2layer: Optional[Dict[str, 'CustomOpsetLayers']] = None
  ) -> TargetPlatformCapabilities:
  """
  Attaching a TargetPlatformModel which includes a platform capabilities description to specific
@@ -35,22 +38,30 @@ class AttachTpModelToFw:
  """
 
  tpc = TargetPlatformCapabilities(tpc_model)
+ custom_opset2layer = custom_opset2layer if custom_opset2layer is not None else {}
 
  with tpc:
- for opset_name, operators in self._opset2layer.items():
- attr_mapping = self._opset2attr_mapping.get(opset_name)
- OperationsSetToLayers(opset_name, operators, attr_mapping=attr_mapping)
-
- if custom_opset2layer is not None:
- for opset_name, operators in custom_opset2layer.items():
- if len(operators) == 1:
- OperationsSetToLayers(opset_name, operators[0])
- elif len(operators) == 2:
- OperationsSetToLayers(opset_name, operators[0], attr_mapping=operators[1])
+ for opset in tpc_model.operator_set:
+ if isinstance(opset, OperatorsSet): # filter out OperatorsSetConcat
+ if opset.name in custom_opset2layer:
+ custom_opset_layers = custom_opset2layer[opset.name]
+ OperationsSetToLayers(opset.name,
+ layers=custom_opset_layers.operators,
+ attr_mapping=custom_opset_layers.attr_mapping)
+
+ elif opset.name in self._opset2layer:
+ # Note that if the user provided a custom operator set with a name that exists in our
+ # pre-defined set of operator sets, we prioritize the user's custom opset definition
+ layers = self._opset2layer[opset.name]
+ if len(layers) > 0:
+ # If the framework does not define any matching operators to a given operator set name that
+ # appears in the TPC, then we just skip it
+ attr_mapping = self._opset2attr_mapping.get(opset.name)
+ OperationsSetToLayers(opset.name, layers, attr_mapping=attr_mapping)
  else:
- raise ValueError(f"Custom operator set to layer mapping should include up to 2 elements - "
- f"a list of layers to attach to the operator and an optional mapping of "
- f"attributes names, but given a mapping contains {len(operators)} elements.")
+ Logger.critical(f'{opset.name} is defined in TargetPlatformModel, '
+ f'but is not defined in the framework set of operators or in the provided '
+ f'custom operator sets mapping.')
 
  return tpc
 
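The rewritten `attach` iterates the operator sets declared by the `TargetPlatformModel` itself: a user-supplied `custom_opset2layer` entry takes priority over the built-in framework table, an empty built-in mapping is silently skipped, and anything unresolved is reported through `Logger.critical`. A hedged sketch of passing a custom mapping; `CustomOpsetLayers(operators=..., attr_mapping=...)` mirrors the attribute accesses above but its exact constructor is assumed, and `tp_model` stands for a `TargetPlatformModel` that actually declares an operator set named "MyCustomOps":

```python
import tensorflow as tf
from model_compression_toolkit.core.common.quantization.quantization_config import CustomOpsetLayers
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
    AttachTpcToKeras

# "MyCustomOps" is a hypothetical operator-set name; it must exist in tp_model's operator_set
# for the mapping below to be picked up by attach().
custom_mapping = {
    "MyCustomOps": CustomOpsetLayers(operators=[tf.nn.silu],  # framework ops/layers for that set
                                     attr_mapping=None)       # no per-attribute renaming needed
}

fw_tpc = AttachTpcToKeras().attach(tp_model, custom_opset2layer=custom_mapping)
```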
@@ -23,12 +23,12 @@ if FOUND_SONY_CUSTOM_LAYERS:
 
  if version.parse(tf.__version__) >= version.parse("2.13"):
  from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
- MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
- Conv2DTranspose, Identity, Concatenate, BatchNormalization, Minimum, Maximum
+ MaxPooling2D, AveragePooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
+ Conv2DTranspose, Concatenate, BatchNormalization, Minimum, Maximum, Softmax
  else:
  from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
- MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
- Conv2DTranspose, Concatenate, BatchNormalization, Minimum, Maximum
+ MaxPooling2D, AveragePooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
+ Conv2DTranspose, Concatenate, BatchNormalization, Minimum, Maximum, Softmax
 
  from model_compression_toolkit import DefaultDict
  from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS, \
@@ -36,72 +36,93 @@ from model_compression_toolkit.target_platform_capabilities.constants import KER
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
  from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
- AttachTpModelToFw
+ AttachTpcToFramework
 
 
- class AttachTpModelToKeras(AttachTpModelToFw):
+ class AttachTpcToKeras(AttachTpcToFramework):
  def __init__(self):
  super().__init__()
 
  self._opset2layer = {
- OperatorSetNames.OPSET_CONV.value: [Conv2D, tf.nn.conv2d],
- OperatorSetNames.OPSET_DEPTHWISE_CONV.value: [DepthwiseConv2D, tf.nn.depthwise_conv2d],
- OperatorSetNames.OPSET_CONV_TRANSPOSE.value: [Conv2DTranspose, tf.nn.conv2d_transpose],
- OperatorSetNames.OPSET_FULLY_CONNECTED.value: [Dense],
- OperatorSetNames.OPSET_CONCATENATE.value: [tf.concat, Concatenate],
- OperatorSetNames.OPSET_STACK.value: [tf.stack],
- OperatorSetNames.OPSET_UNSTACK.value: [tf.unstack],
- OperatorSetNames.OPSET_GATHER.value: [tf.gather, tf.compat.v1.gather],
- OperatorSetNames.OPSET_EXPAND.value: [],
- OperatorSetNames.OPSET_BATCH_NORM.value: [BatchNormalization],
- OperatorSetNames.OPSET_RELU.value: [tf.nn.relu, ReLU],
- OperatorSetNames.OPSET_RELU6.value: [tf.nn.relu6],
- OperatorSetNames.OPSET_LEAKY_RELU.value: [tf.nn.leaky_relu, LeakyReLU],
- OperatorSetNames.OPSET_HARD_TANH.value: [LayerFilterParams(Activation, activation="hard_tanh")],
- OperatorSetNames.OPSET_ADD.value: [tf.add, Add],
- OperatorSetNames.OPSET_SUB.value: [tf.subtract, Subtract],
- OperatorSetNames.OPSET_MUL.value: [tf.math.multiply, Multiply],
- OperatorSetNames.OPSET_DIV.value: [tf.math.divide, tf.math.truediv],
- OperatorSetNames.OPSET_MIN.value: [tf.math.minimum, Minimum],
- OperatorSetNames.OPSET_MAX.value: [tf.math.maximum, Maximum],
- OperatorSetNames.OPSET_PRELU.value: [PReLU],
- OperatorSetNames.OPSET_SWISH.value: [tf.nn.swish, LayerFilterParams(Activation, activation="swish")],
- OperatorSetNames.OPSET_SIGMOID.value: [tf.nn.sigmoid, LayerFilterParams(Activation, activation="sigmoid")],
- OperatorSetNames.OPSET_TANH.value: [tf.nn.tanh, LayerFilterParams(Activation, activation="tanh")],
- OperatorSetNames.OPSET_GELU.value: [tf.nn.gelu, LayerFilterParams(Activation, activation="gelu")],
- OperatorSetNames.OPSET_HARDSIGMOID.value: [tf.keras.activations.hard_sigmoid,
- LayerFilterParams(Activation, activation="hard_sigmoid")],
- OperatorSetNames.OPSET_FLATTEN.value: [Flatten],
- OperatorSetNames.OPSET_GET_ITEM.value: [tf.__operators__.getitem],
- OperatorSetNames.OPSET_RESHAPE.value: [Reshape, tf.reshape],
- OperatorSetNames.OPSET_PERMUTE.value: [Permute],
- OperatorSetNames.OPSET_TRANSPOSE.value: [tf.transpose],
- OperatorSetNames.OPSET_DROPOUT.value: [Dropout],
- OperatorSetNames.OPSET_SPLIT.value: [tf.split],
- OperatorSetNames.OPSET_MAXPOOL.value: [MaxPooling2D],
- OperatorSetNames.OPSET_SHAPE.value: [tf.shape, tf.compat.v1.shape],
- OperatorSetNames.OPSET_EQUAL.value: [tf.math.equal],
- OperatorSetNames.OPSET_ARGMAX.value: [tf.math.argmax],
- OperatorSetNames.OPSET_TOPK.value: [tf.nn.top_k],
- OperatorSetNames.OPSET_FAKE_QUANT_WITH_MIN_MAX_VARS.value: [tf.quantization.fake_quant_with_min_max_vars],
- OperatorSetNames.OPSET_COMBINED_NON_MAX_SUPPRESSION.value: [tf.image.combined_non_max_suppression],
- OperatorSetNames.OPSET_CROPPING2D.value: [Cropping2D],
- OperatorSetNames.OPSET_ZERO_PADDING2d.value: [ZeroPadding2D],
- OperatorSetNames.OPSET_CAST.value: [tf.cast],
- OperatorSetNames.OPSET_STRIDED_SLICE.value: [tf.strided_slice]
+ OperatorSetNames.CONV: [Conv2D, tf.nn.conv2d],
+ OperatorSetNames.DEPTHWISE_CONV: [DepthwiseConv2D, tf.nn.depthwise_conv2d],
+ OperatorSetNames.CONV_TRANSPOSE: [Conv2DTranspose, tf.nn.conv2d_transpose],
+ OperatorSetNames.FULLY_CONNECTED: [Dense],
+ OperatorSetNames.CONCATENATE: [tf.concat, Concatenate],
+ OperatorSetNames.STACK: [tf.stack],
+ OperatorSetNames.UNSTACK: [tf.unstack],
+ OperatorSetNames.GATHER: [tf.gather, tf.compat.v1.gather],
+ OperatorSetNames.EXPAND: [],
+ OperatorSetNames.BATCH_NORM: [BatchNormalization, tf.nn.batch_normalization],
+ OperatorSetNames.RELU: [tf.nn.relu, ReLU, LayerFilterParams(Activation, activation="relu")],
+ OperatorSetNames.RELU6: [tf.nn.relu6],
+ OperatorSetNames.LEAKY_RELU: [tf.nn.leaky_relu, LeakyReLU, LayerFilterParams(Activation, activation="leaky_relu")],
+ OperatorSetNames.HARD_TANH: [LayerFilterParams(Activation, activation="hard_tanh")],
+ OperatorSetNames.ADD: [tf.add, Add],
+ OperatorSetNames.SUB: [tf.subtract, Subtract],
+ OperatorSetNames.MUL: [tf.math.multiply, Multiply],
+ OperatorSetNames.DIV: [tf.math.divide, tf.math.truediv],
+ OperatorSetNames.MIN: [tf.math.minimum, Minimum],
+ OperatorSetNames.MAX: [tf.math.maximum, Maximum],
+ OperatorSetNames.PRELU: [PReLU],
+ OperatorSetNames.SWISH: [tf.nn.swish, LayerFilterParams(Activation, activation="swish")],
+ OperatorSetNames.HARDSWISH: [LayerFilterParams(Activation, activation="hard_swish")],
+ OperatorSetNames.SIGMOID: [tf.nn.sigmoid, LayerFilterParams(Activation, activation="sigmoid")],
+ OperatorSetNames.TANH: [tf.nn.tanh, LayerFilterParams(Activation, activation="tanh")],
+ OperatorSetNames.GELU: [tf.nn.gelu, LayerFilterParams(Activation, activation="gelu")],
+ OperatorSetNames.HARDSIGMOID: [tf.keras.activations.hard_sigmoid,
+ LayerFilterParams(Activation, activation="hard_sigmoid")],
+ OperatorSetNames.FLATTEN: [Flatten],
+ OperatorSetNames.GET_ITEM: [tf.__operators__.getitem],
+ OperatorSetNames.RESHAPE: [Reshape, tf.reshape],
+ OperatorSetNames.PERMUTE: [Permute],
+ OperatorSetNames.TRANSPOSE: [tf.transpose],
+ OperatorSetNames.UNSQUEEZE: [tf.expand_dims],
+ OperatorSetNames.SQUEEZE: [tf.squeeze],
+ OperatorSetNames.DROPOUT: [Dropout],
+ OperatorSetNames.SPLIT_CHUNK: [tf.split],
+ OperatorSetNames.MAXPOOL: [MaxPooling2D, tf.nn.avg_pool2d],
+ OperatorSetNames.AVGPOOL: [AveragePooling2D],
+ OperatorSetNames.SIZE: [tf.size],
+ OperatorSetNames.RESIZE: [tf.image.resize],
+ OperatorSetNames.PAD: [tf.pad, Cropping2D],
+ OperatorSetNames.FOLD: [tf.space_to_batch_nd],
+ OperatorSetNames.SHAPE: [tf.shape, tf.compat.v1.shape],
+ OperatorSetNames.EQUAL: [tf.math.equal],
+ OperatorSetNames.ARGMAX: [tf.math.argmax],
+ OperatorSetNames.TOPK: [tf.nn.top_k],
+ OperatorSetNames.FAKE_QUANT: [tf.quantization.fake_quant_with_min_max_vars],
+ OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [tf.image.combined_non_max_suppression],
+ OperatorSetNames.ZERO_PADDING2D: [ZeroPadding2D],
+ OperatorSetNames.CAST: [tf.cast],
+ OperatorSetNames.STRIDED_SLICE: [tf.strided_slice],
+ OperatorSetNames.ELU: [tf.nn.elu, LayerFilterParams(Activation, activation="elu")],
+ OperatorSetNames.SOFTMAX: [tf.nn.softmax, Softmax,
+ LayerFilterParams(Activation, activation="softmax")],
+ OperatorSetNames.LOG_SOFTMAX: [tf.nn.log_softmax],
+ OperatorSetNames.ADD_BIAS: [tf.nn.bias_add],
+ OperatorSetNames.L2NORM: [tf.math.l2_normalize],
  }
 
  if FOUND_SONY_CUSTOM_LAYERS:
- self._opset2layer[OperatorSetNames.OPSET_POST_PROCESS] = [SSDPostProcess]
+ self._opset2layer[OperatorSetNames.SSD_POST_PROCESS] = [SSDPostProcess]
+ else:
+ # If Custom layers is not installed then we don't want the user to fail, but just ignore custom layers
+ # in the initialized framework TPC
+ self._opset2layer[OperatorSetNames.SSD_POST_PROCESS] = []
 
- self._opset2attr_mapping = {OperatorSetNames.OPSET_CONV.value: {
- KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
- BIAS_ATTR: DefaultDict(default_value=BIAS)},
- OperatorSetNames.OPSET_DEPTHWISE_CONV.value: {
+ self._opset2attr_mapping = {
+ OperatorSetNames.CONV: {
+ KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
+ BIAS_ATTR: DefaultDict(default_value=BIAS)},
+ OperatorSetNames.CONV_TRANSPOSE: {
+ KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
+ BIAS_ATTR: DefaultDict(default_value=BIAS)},
+ OperatorSetNames.DEPTHWISE_CONV: {
  KERNEL_ATTR: DefaultDict({
  DepthwiseConv2D: KERAS_DEPTHWISE_KERNEL,
  tf.nn.depthwise_conv2d: KERAS_DEPTHWISE_KERNEL}, default_value=KERAS_KERNEL),
  BIAS_ATTR: DefaultDict(default_value=BIAS)},
- OperatorSetNames.OPSET_FULLY_CONNECTED.value: {
+ OperatorSetNames.FULLY_CONNECTED: {
  KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
  BIAS_ATTR: DefaultDict(default_value=BIAS)}}
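Several entries in the tables above rely on `LayerFilterParams` to match a layer class together with specific constructor arguments, e.g. a Keras `Activation` layer configured with a particular activation string. A rough sketch of the condition such a filter expresses (the actual matching is performed inside MCT and may differ in detail):

```python
import tensorflow as tf

# The "Swish" entry, LayerFilterParams(Activation, activation="swish"), conceptually matches
# Activation layers whose configuration carries activation="swish":
layer = tf.keras.layers.Activation("swish")
matches_swish_set = (isinstance(layer, tf.keras.layers.Activation)
                     and layer.get_config().get("activation") == "swish")
```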