mct-nightly 2.2.0.20250106.546__py3-none-any.whl → 2.2.0.20250107.15510__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80) hide show
  1. {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/RECORD +43 -78
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/__init__.py +1 -1
  5. model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py +1 -1
  6. model_compression_toolkit/core/common/graph/memory_graph/cut.py +5 -2
  7. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +25 -25
  8. model_compression_toolkit/core/common/quantization/quantization_config.py +19 -1
  9. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -33
  10. model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py +2 -2
  11. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +11 -1
  12. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/matmul_decomposition.py +499 -0
  13. model_compression_toolkit/core/pytorch/pytorch_implementation.py +3 -0
  14. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +11 -3
  15. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -1
  16. model_compression_toolkit/gptq/pytorch/quantization_facade.py +10 -1
  17. model_compression_toolkit/pruning/keras/pruning_facade.py +8 -2
  18. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -2
  19. model_compression_toolkit/ptq/keras/quantization_facade.py +10 -1
  20. model_compression_toolkit/ptq/pytorch/quantization_facade.py +9 -1
  21. model_compression_toolkit/qat/__init__.py +5 -2
  22. model_compression_toolkit/qat/keras/quantization_facade.py +9 -1
  23. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -1
  24. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +1 -1
  25. model_compression_toolkit/target_platform_capabilities/schema/v1.py +63 -55
  26. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py +29 -18
  27. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py +78 -57
  28. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py +69 -54
  29. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  30. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +0 -10
  31. model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +93 -0
  32. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +46 -28
  33. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +6 -5
  34. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +51 -19
  35. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +8 -4
  36. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +19 -9
  37. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +7 -4
  38. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +46 -32
  39. model_compression_toolkit/xquant/keras/keras_report_utils.py +11 -3
  40. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -2
  41. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +0 -98
  42. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_keras.py +0 -129
  43. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_pytorch.py +0 -108
  44. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/__init__.py +0 -16
  45. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +0 -217
  46. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_keras.py +0 -130
  47. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_pytorch.py +0 -109
  48. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/__init__.py +0 -16
  49. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +0 -215
  50. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_keras.py +0 -130
  51. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_pytorch.py +0 -110
  52. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/__init__.py +0 -16
  53. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +0 -222
  54. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +0 -132
  55. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_pytorch.py +0 -110
  56. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/__init__.py +0 -16
  57. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +0 -219
  58. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +0 -132
  59. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py +0 -109
  60. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/__init__.py +0 -16
  61. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +0 -246
  62. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +0 -135
  63. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +0 -113
  64. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/__init__.py +0 -16
  65. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +0 -230
  66. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +0 -132
  67. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +0 -110
  68. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py +0 -16
  69. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +0 -332
  70. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py +0 -140
  71. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +0 -122
  72. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +0 -55
  73. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_keras.py +0 -89
  74. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +0 -78
  75. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +0 -55
  76. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_keras.py +0 -118
  77. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_pytorch.py +0 -100
  78. {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/LICENSE.md +0 -0
  79. {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/WHEEL +0 -0
  80. {mct_nightly-2.2.0.20250106.546.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/top_level.txt +0 -0
@@ -18,74 +18,89 @@ import operator
18
18
  import torch
19
19
  from torch import add, sub, mul, div, divide, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, \
20
20
  chunk, unbind, topk, gather, equal, transpose, permute, argmax, squeeze, multiply, subtract, minimum, \
21
- maximum
22
- from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d, BatchNorm2d
23
- from torch.nn import Dropout, Flatten, Hardtanh
24
- from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, Hardsigmoid, LeakyReLU, GELU
21
+ maximum, softmax, fake_quantize_per_channel_affine
22
+ from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d, BatchNorm2d, Dropout, Flatten, Hardtanh, ReLU, ReLU6, \
23
+ PReLU, SiLU, Sigmoid, Tanh, Hardswish, Hardsigmoid, LeakyReLU, GELU, LogSoftmax, Softmax, ELU, AvgPool2d, ZeroPad2d
24
+ from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, hardsigmoid, leaky_relu, gelu, fold
25
25
  import torch.nn.functional as F
26
- from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, hardsigmoid, leaky_relu, gelu
27
26
 
28
27
  from model_compression_toolkit import DefaultDict
29
28
  from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, PYTORCH_KERNEL, BIAS, \
30
29
  BIAS_ATTR
31
30
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
32
- from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
31
+ from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams, Eq
33
32
  from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
34
- AttachTpModelToFw
33
+ AttachTpcToFramework
35
34
 
36
35
 
37
- class AttachTpModelToPytorch(AttachTpModelToFw):
36
+ class AttachTpcToPytorch(AttachTpcToFramework):
38
37
  def __init__(self):
39
38
  super().__init__()
40
39
 
41
40
  self._opset2layer = {
42
- OperatorSetNames.OPSET_CONV.value: [Conv2d],
43
- OperatorSetNames.OPSET_CONV_TRANSPOSE.value: [ConvTranspose2d],
44
- OperatorSetNames.OPSET_FULLY_CONNECTED.value: [Linear],
45
- OperatorSetNames.OPSET_CONCATENATE.value: [torch.cat, torch.concat, torch.concatenate],
46
- OperatorSetNames.OPSET_STACK.value: [torch.stack],
47
- OperatorSetNames.OPSET_UNSTACK.value: [unbind],
48
- OperatorSetNames.OPSET_GATHER.value: [gather],
49
- OperatorSetNames.OPSET_EXPAND.value: [torch.Tensor.expand],
50
- OperatorSetNames.OPSET_BATCH_NORM.value: [BatchNorm2d],
51
- OperatorSetNames.OPSET_RELU.value: [torch.relu, ReLU, relu],
52
- OperatorSetNames.OPSET_RELU6.value: [ReLU6, relu6],
53
- OperatorSetNames.OPSET_LEAKY_RELU.value: [LeakyReLU, leaky_relu],
54
- OperatorSetNames.OPSET_HARD_TANH.value: [LayerFilterParams(Hardtanh, min_val=0),
55
- LayerFilterParams(hardtanh, min_val=0)],
56
- OperatorSetNames.OPSET_ADD.value: [operator.add, add],
57
- OperatorSetNames.OPSET_SUB.value: [operator.sub, sub, subtract],
58
- OperatorSetNames.OPSET_MUL.value: [operator.mul, mul, multiply],
59
- OperatorSetNames.OPSET_DIV.value: [operator.truediv, div, divide],
60
- OperatorSetNames.OPSET_MIN.value: [minimum],
61
- OperatorSetNames.OPSET_MAX.value: [maximum],
62
- OperatorSetNames.OPSET_PRELU.value: [PReLU, prelu],
63
- OperatorSetNames.OPSET_SWISH.value: [SiLU, silu],
64
- OperatorSetNames.OPSET_SIGMOID.value: [Sigmoid, sigmoid, F.sigmoid],
65
- OperatorSetNames.OPSET_TANH.value: [Tanh, tanh, F.tanh],
66
- OperatorSetNames.OPSET_GELU.value: [GELU, gelu],
67
- OperatorSetNames.OPSET_HARDSIGMOID.value: [Hardsigmoid, hardsigmoid],
68
- OperatorSetNames.OPSET_HARDSWISH.value: [Hardswish, hardswish],
69
- OperatorSetNames.OPSET_FLATTEN.value: [Flatten, flatten],
70
- OperatorSetNames.OPSET_GET_ITEM.value: [operator.getitem],
71
- OperatorSetNames.OPSET_RESHAPE.value: [reshape],
72
- OperatorSetNames.OPSET_UNSQUEEZE.value: [unsqueeze],
73
- OperatorSetNames.OPSET_SQUEEZE.value: [squeeze],
74
- OperatorSetNames.OPSET_PERMUTE.value: [permute],
75
- OperatorSetNames.OPSET_TRANSPOSE.value: [transpose],
76
- OperatorSetNames.OPSET_DROPOUT.value: [Dropout, dropout],
77
- OperatorSetNames.OPSET_SPLIT.value: [split],
78
- OperatorSetNames.OPSET_CHUNK.value: [chunk],
79
- OperatorSetNames.OPSET_MAXPOOL.value: [MaxPool2d],
80
- OperatorSetNames.OPSET_SIZE.value: [torch.Tensor.size],
81
- OperatorSetNames.OPSET_SHAPE.value: [torch.Tensor.shape],
82
- OperatorSetNames.OPSET_EQUAL.value: [equal],
83
- OperatorSetNames.OPSET_ARGMAX.value: [argmax],
84
- OperatorSetNames.OPSET_TOPK.value: [topk],
41
+ OperatorSetNames.CONV: [Conv2d],
42
+ OperatorSetNames.DEPTHWISE_CONV: [], # no specific operator for depthwise conv in pytorch
43
+ OperatorSetNames.CONV_TRANSPOSE: [ConvTranspose2d],
44
+ OperatorSetNames.FULLY_CONNECTED: [Linear],
45
+ OperatorSetNames.CONCATENATE: [torch.cat, torch.concat, torch.concatenate],
46
+ OperatorSetNames.STACK: [torch.stack],
47
+ OperatorSetNames.UNSTACK: [unbind],
48
+ OperatorSetNames.GATHER: [gather],
49
+ OperatorSetNames.EXPAND: [torch.Tensor.expand],
50
+ OperatorSetNames.BATCH_NORM: [BatchNorm2d],
51
+ OperatorSetNames.RELU: [torch.relu, ReLU, relu],
52
+ OperatorSetNames.RELU6: [ReLU6, relu6],
53
+ OperatorSetNames.LEAKY_RELU: [LeakyReLU, leaky_relu],
54
+ OperatorSetNames.HARD_TANH: [LayerFilterParams(Hardtanh, min_val=0),
55
+ LayerFilterParams(hardtanh, min_val=0)],
56
+ OperatorSetNames.ADD: [operator.add, add],
57
+ OperatorSetNames.SUB: [operator.sub, sub, subtract],
58
+ OperatorSetNames.MUL: [operator.mul, mul, multiply],
59
+ OperatorSetNames.DIV: [operator.truediv, div, divide],
60
+ OperatorSetNames.ADD_BIAS: [], # no specific operator for bias_add in pytorch
61
+ OperatorSetNames.MIN: [minimum],
62
+ OperatorSetNames.MAX: [maximum],
63
+ OperatorSetNames.PRELU: [PReLU, prelu],
64
+ OperatorSetNames.SWISH: [SiLU, silu],
65
+ OperatorSetNames.SIGMOID: [Sigmoid, sigmoid, F.sigmoid],
66
+ OperatorSetNames.TANH: [Tanh, tanh, F.tanh],
67
+ OperatorSetNames.GELU: [GELU, gelu],
68
+ OperatorSetNames.HARDSIGMOID: [Hardsigmoid, hardsigmoid],
69
+ OperatorSetNames.HARDSWISH: [Hardswish, hardswish],
70
+ OperatorSetNames.FLATTEN: [Flatten, flatten],
71
+ OperatorSetNames.GET_ITEM: [operator.getitem],
72
+ OperatorSetNames.RESHAPE: [reshape],
73
+ OperatorSetNames.UNSQUEEZE: [unsqueeze],
74
+ OperatorSetNames.SQUEEZE: [squeeze],
75
+ OperatorSetNames.PERMUTE: [permute],
76
+ OperatorSetNames.TRANSPOSE: [transpose],
77
+ OperatorSetNames.DROPOUT: [Dropout, dropout],
78
+ OperatorSetNames.SPLIT_CHUNK: [split, chunk],
79
+ OperatorSetNames.MAXPOOL: [MaxPool2d, F.max_pool2d],
80
+ OperatorSetNames.AVGPOOL: [AvgPool2d, F.avg_pool2d],
81
+ OperatorSetNames.SIZE: [torch.Tensor.size],
82
+ OperatorSetNames.RESIZE: [torch.Tensor.resize],
83
+ OperatorSetNames.PAD: [F.pad],
84
+ OperatorSetNames.FOLD: [fold],
85
+ OperatorSetNames.SHAPE: [torch.Tensor.shape],
86
+ OperatorSetNames.EQUAL: [equal],
87
+ OperatorSetNames.ARGMAX: [argmax],
88
+ OperatorSetNames.TOPK: [topk],
89
+ OperatorSetNames.FAKE_QUANT: [fake_quantize_per_channel_affine],
90
+ OperatorSetNames.ZERO_PADDING2D: [ZeroPad2d],
91
+ OperatorSetNames.CAST: [torch.Tensor.type],
92
+ OperatorSetNames.STRIDED_SLICE: [], # no such operator in pytorch, the equivalent is get_item which has a separate operator set
93
+ OperatorSetNames.ELU: [ELU, F.elu],
94
+ OperatorSetNames.SOFTMAX: [Softmax, softmax, F.softmax],
95
+ OperatorSetNames.LOG_SOFTMAX: [LogSoftmax],
96
+ OperatorSetNames.L2NORM: [LayerFilterParams(torch.nn.functional.normalize,
97
+ Eq('p', 2) | Eq('p', None))],
98
+ OperatorSetNames.SSD_POST_PROCESS: [], # no such operator in pytorch
99
+ OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [] # no such operator in pytorch
85
100
  }
86
101
 
87
102
  pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
88
103
  BIAS_ATTR: DefaultDict(default_value=BIAS)}
89
- self._opset2attr_mapping = {OperatorSetNames.OPSET_CONV.value: pytorch_linear_attr_mapping,
90
- OperatorSetNames.OPSET_CONV_TRANSPOSE.value: pytorch_linear_attr_mapping,
91
- OperatorSetNames.OPSET_FULLY_CONNECTED.value: pytorch_linear_attr_mapping}
104
+ self._opset2attr_mapping = {OperatorSetNames.CONV: pytorch_linear_attr_mapping,
105
+ OperatorSetNames.CONV_TRANSPOSE: pytorch_linear_attr_mapping,
106
+ OperatorSetNames.FULLY_CONNECTED: pytorch_linear_attr_mapping}
@@ -138,10 +138,8 @@ class OperationsToLayers:
138
138
  OperationsSetToLayers), f'Operators set should be of type OperationsSetToLayers but it ' \
139
139
  f'is of type {type(ops2layers)}'
140
140
 
141
- # Assert that opset in the current TargetPlatformCapabilities and has a unique name.
142
- opset_in_model = is_opset_in_model(_current_tpc.get().tp_model, ops2layers.name)
143
- assert opset_in_model, f'{ops2layers.name} is not defined in the target platform model that is associated with the target platform capabilities.'
144
- assert not (ops2layers.name in existing_opset_names), f'OperationsSetToLayers names should be unique, but {ops2layers.name} appears to violate it.'
141
+ # Assert that opset has a unique name.
142
+ assert ops2layers.name not in existing_opset_names, f'OperationsSetToLayers names should be unique, but {ops2layers.name} appears to violate it.'
145
143
  existing_opset_names.append(ops2layers.name)
146
144
 
147
145
  # Assert that a layer does not appear in more than a single OperatorsSet in the TargetPlatformModel.
@@ -156,7 +156,6 @@ class TargetPlatformCapabilities(ImmutableClass):
156
156
  if exc_value is not None:
157
157
  print(exc_value, exc_value.args)
158
158
  raise exc_value
159
- self.raise_warnings()
160
159
  self.layer2qco, self.filterlayer2qco = self._get_config_options_mapping()
161
160
  _current_tpc.reset()
162
161
  self.initialized_done()
@@ -226,15 +225,6 @@ class TargetPlatformCapabilities(ImmutableClass):
226
225
  if opset_to_remove in self.__tp_model_opsets_not_used:
227
226
  self.__tp_model_opsets_not_used.remove(opset_to_remove)
228
227
 
229
- def raise_warnings(self):
230
- """
231
-
232
- Log warnings regards unused opsets.
233
-
234
- """
235
- for op in self.__tp_model_opsets_not_used:
236
- Logger.warning(f'{op} is defined in TargetPlatformModel, but is not used in TargetPlatformCapabilities.')
237
-
238
228
  @property
239
229
  def is_simd_padding(self) -> bool:
240
230
  """
@@ -0,0 +1,93 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from pathlib import Path
16
+ from typing import Union
17
+
18
+ from model_compression_toolkit.logger import Logger
19
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
20
+ import json
21
+
22
+
23
+ def load_target_platform_model(tp_model_or_path: Union[TargetPlatformModel, str]) -> TargetPlatformModel:
24
+ """
25
+ Parses the tp_model input, which can be either a TargetPlatformModel object
26
+ or a string path to a JSON file.
27
+
28
+ Parameters:
29
+ tp_model_or_path (Union[TargetPlatformModel, str]): Input target platform model or path to .JSON file.
30
+
31
+ Returns:
32
+ TargetPlatformModel: The parsed TargetPlatformModel.
33
+
34
+ Raises:
35
+ FileNotFoundError: If the JSON file does not exist.
36
+ ValueError: If the JSON content is invalid or cannot initialize the TargetPlatformModel.
37
+ TypeError: If the input is neither a TargetPlatformModel nor a valid JSON file path.
38
+ """
39
+ if isinstance(tp_model_or_path, TargetPlatformModel):
40
+ return tp_model_or_path
41
+
42
+ if isinstance(tp_model_or_path, str):
43
+ path = Path(tp_model_or_path)
44
+
45
+ if not path.exists() or not path.is_file():
46
+ raise FileNotFoundError(f"The path '{tp_model_or_path}' is not a valid file.")
47
+ # Verify that the file has a .json extension
48
+ if path.suffix.lower() != '.json':
49
+ raise ValueError(f"The file '{path}' does not have a '.json' extension.")
50
+ try:
51
+ with path.open('r', encoding='utf-8') as file:
52
+ data = file.read()
53
+ except OSError as e:
54
+ raise ValueError(f"Error reading the file '{tp_model_or_path}': {e.strerror}.") from e
55
+
56
+ try:
57
+ return TargetPlatformModel.parse_raw(data)
58
+ except ValueError as e:
59
+ raise ValueError(f"Invalid JSON for loading TargetPlatformModel in '{tp_model_or_path}': {e}.") from e
60
+ except Exception as e:
61
+ raise ValueError(f"Unexpected error while initializing TargetPlatformModel: {e}.") from e
62
+
63
+ raise TypeError(
64
+ f"tp_model_or_path must be either a TargetPlatformModel instance or a string path to a JSON file, "
65
+ f"but received type '{type(tp_model_or_path).__name__}'."
66
+ )
67
+
68
+
69
+ def export_target_platform_model(model: TargetPlatformModel, export_path: Union[str, Path]) -> None:
70
+ """
71
+ Exports a TargetPlatformModel instance to a JSON file.
72
+
73
+ Parameters:
74
+ model (TargetPlatformModel): The TargetPlatformModel instance to export.
75
+ export_path (Union[str, Path]): The file path to export the model to.
76
+
77
+ Raises:
78
+ ValueError: If the model is not an instance of TargetPlatformModel.
79
+ OSError: If there is an issue writing to the file.
80
+ """
81
+ if not isinstance(model, TargetPlatformModel):
82
+ raise ValueError("The provided model is not a valid TargetPlatformModel instance.")
83
+
84
+ path = Path(export_path)
85
+ try:
86
+ # Ensure the parent directory exists
87
+ path.parent.mkdir(parents=True, exist_ok=True)
88
+
89
+ # Export the model to JSON and write to the file
90
+ with path.open('w', encoding='utf-8') as file:
91
+ file.write(model.json(indent=4))
92
+ except OSError as e:
93
+ raise OSError(f"Failed to write to file '{export_path}': {e.strerror}") from e
@@ -12,45 +12,63 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ from model_compression_toolkit.constants import TENSORFLOW, PYTORCH
16
+ from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, \
17
+ TFLITE_TP_MODEL, QNNPACK_TP_MODEL
18
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
15
19
 
16
- from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
20
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model as get_tp_model_imx500_v1
21
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tp_model import get_tp_model as get_tp_model_tflite_v1
22
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model as get_tp_model_qnnpack_v1
17
23
 
18
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.target_platform_capabilities import \
19
- get_tpc_dict_by_fw as get_imx500_tpc
20
- from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.target_platform_capabilities import \
21
- get_tpc_dict_by_fw as get_tflite_tpc
22
- from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.target_platform_capabilities import \
23
- get_tpc_dict_by_fw as get_qnnpack_tpc
24
- from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL, IMX500_TP_MODEL, TFLITE_TP_MODEL, QNNPACK_TP_MODEL, LATEST
25
-
26
- tpc_dict = {DEFAULT_TP_MODEL: get_imx500_tpc,
27
- IMX500_TP_MODEL: get_imx500_tpc,
28
- TFLITE_TP_MODEL: get_tflite_tpc,
29
- QNNPACK_TP_MODEL: get_qnnpack_tpc}
30
24
 
25
+ # TODO: These methods need to be replaced once modifying the TPC API.
31
26
 
32
27
  def get_target_platform_capabilities(fw_name: str,
33
28
  target_platform_name: str,
34
- target_platform_version: str = None) -> TargetPlatformCapabilities:
29
+ target_platform_version: str = None) -> TargetPlatformModel:
35
30
  """
36
- Get a TargetPlatformCapabilities by the target platform model name and the framework name.
37
- For now, it supports frameworks 'tensorflow' and 'pytorch'. For both of them
38
- the target platform model can be 'default', 'imx500', 'tflite', or 'qnnpack'.
31
+ This is a degenerated function that only returns the MCT default TargetPlatformModel object, to comply with the
32
+ existing TPC API.
39
33
 
40
34
  Args:
41
35
  fw_name: Framework name of the TargetPlatformCapabilities.
42
36
  target_platform_name: Target platform model name the model will use for inference.
43
37
  target_platform_version: Target platform capabilities version.
38
+
44
39
  Returns:
45
- A TargetPlatformCapabilities object that models the hardware and attaches
46
- a framework information to it.
40
+ A default TargetPlatformModel object.
47
41
  """
48
- assert target_platform_name in tpc_dict, f'Target platform {target_platform_name} is not defined!'
49
- fw_tpc = tpc_dict.get(target_platform_name)
50
- tpc_versions = fw_tpc(fw_name)
51
- if target_platform_version is None:
52
- target_platform_version = LATEST
53
- else:
54
- assert target_platform_version in tpc_versions, (f'TPC version {target_platform_version} is not supported for '
55
- f'framework {fw_name}.')
56
- return tpc_versions[target_platform_version]()
42
+
43
+ assert fw_name in [TENSORFLOW, PYTORCH], f"Unsupported framework {fw_name}."
44
+
45
+ if target_platform_name == DEFAULT_TP_MODEL:
46
+ return get_tp_model_imx500_v1()
47
+
48
+ assert target_platform_version == 'v1' or target_platform_version is None, \
49
+ "The usage of get_target_platform_capabilities API is supported only with the default TPC ('v1')."
50
+
51
+ if target_platform_name == IMX500_TP_MODEL:
52
+ return get_tp_model_imx500_v1()
53
+ elif target_platform_name == TFLITE_TP_MODEL:
54
+ return get_tp_model_tflite_v1()
55
+ elif target_platform_name == QNNPACK_TP_MODEL:
56
+ return get_tp_model_qnnpack_v1()
57
+
58
+ raise ValueError(f"Unsupported target platform name {target_platform_name}.")
59
+
60
+
61
+ def get_tpc_model(name: str, tp_model: TargetPlatformModel):
62
+ """
63
+ This is a utility method that just returns the TargetPlatformModel that it receives, to support existing TPC API.
64
+
65
+ Args:
66
+ name: the name of the TargetPlatformModel (not used in this function).
67
+ tp_model: a TargetPlatformModel to return.
68
+
69
+ Returns:
70
+ The given TargetPlatformModel object.
71
+
72
+ """
73
+
74
+ return tp_model
@@ -16,9 +16,10 @@ from model_compression_toolkit.verify_packages import FOUND_TORCH, FOUND_TF
16
16
  from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model, generate_tp_model, \
17
17
  get_op_quantization_configs
18
18
  if FOUND_TF:
19
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_latest
20
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import generate_keras_tpc
19
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model as get_keras_tpc_latest
20
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
21
+ get_tpc_model as generate_keras_tpc
21
22
  if FOUND_TORCH:
22
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_pytorch import get_pytorch_tpc as \
23
- get_pytorch_tpc_latest
24
- from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_pytorch import generate_pytorch_tpc
23
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model as get_pytorch_tpc_latest
24
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
25
+ get_tpc_model as generate_pytorch_tpc
@@ -167,38 +167,70 @@ def generate_tp_model(default_config: OpQuantizationConfig,
167
167
  operator_set = []
168
168
  fusing_patterns = []
169
169
 
170
- operator_set.append(schema.OperatorsSet(name="NoQuantization",
171
- qc_options=default_configuration_options.clone_and_edit(enable_activation_quantization=False)
172
- .clone_and_edit_weight_attribute(enable_weights_quantization=False)))
170
+ no_quantization_config = (default_configuration_options.clone_and_edit(enable_activation_quantization=False)
171
+ .clone_and_edit_weight_attribute(enable_weights_quantization=False))
172
+
173
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.STACK, qc_options=no_quantization_config))
174
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.UNSTACK, qc_options=no_quantization_config))
175
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.DROPOUT, qc_options=no_quantization_config))
176
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.FLATTEN, qc_options=no_quantization_config))
177
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.SPLIT_CHUNK, qc_options=no_quantization_config))
178
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.GET_ITEM, qc_options=no_quantization_config))
179
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.RESHAPE, qc_options=no_quantization_config))
180
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.UNSQUEEZE, qc_options=no_quantization_config))
181
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.SIZE, qc_options=no_quantization_config))
182
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.PERMUTE, qc_options=no_quantization_config))
183
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.TRANSPOSE, qc_options=no_quantization_config))
184
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.EQUAL, qc_options=no_quantization_config))
185
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.ARGMAX, qc_options=no_quantization_config))
186
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.GATHER, qc_options=no_quantization_config))
187
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.TOPK, qc_options=no_quantization_config))
188
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.SQUEEZE, qc_options=no_quantization_config))
189
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.MAXPOOL, qc_options=no_quantization_config))
190
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.PAD, qc_options=no_quantization_config))
191
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.ZERO_PADDING2D, qc_options=no_quantization_config))
192
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.CAST, qc_options=no_quantization_config))
193
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION, qc_options=no_quantization_config))
194
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.FAKE_QUANT, qc_options=no_quantization_config))
195
+ operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.SSD_POST_PROCESS, qc_options=no_quantization_config))
173
196
 
174
197
  # Define operator sets that use mixed_precision_configuration_options:
175
- conv = schema.OperatorsSet(name="Conv", qc_options=mixed_precision_configuration_options)
176
- fc = schema.OperatorsSet(name="FullyConnected", qc_options=mixed_precision_configuration_options)
198
+ conv = schema.OperatorsSet(name=schema.OperatorSetNames.CONV, qc_options=mixed_precision_configuration_options)
199
+ conv_transpose = schema.OperatorsSet(name=schema.OperatorSetNames.CONV_TRANSPOSE, qc_options=mixed_precision_configuration_options)
200
+ depthwise_conv = schema.OperatorsSet(name=schema.OperatorSetNames.DEPTHWISE_CONV, qc_options=mixed_precision_configuration_options)
201
+ fc = schema.OperatorsSet(name=schema.OperatorSetNames.FULLY_CONNECTED, qc_options=mixed_precision_configuration_options)
177
202
 
178
203
  # Define operations sets without quantization configuration
179
204
  # options (useful for creating fusing patterns, for example):
180
- any_relu = schema.OperatorsSet(name="AnyReLU")
181
- add = schema.OperatorsSet(name="Add")
182
- sub = schema.OperatorsSet(name="Sub")
183
- mul = schema.OperatorsSet(name="Mul")
184
- div = schema.OperatorsSet(name="Div")
185
- prelu = schema.OperatorsSet(name="PReLU")
186
- swish = schema.OperatorsSet(name="Swish")
187
- sigmoid = schema.OperatorsSet(name="Sigmoid")
188
- tanh = schema.OperatorsSet(name="Tanh")
189
-
190
- operator_set.extend([conv, fc, any_relu, add, sub, mul, div, prelu, swish, sigmoid, tanh])
205
+ relu = schema.OperatorsSet(name=schema.OperatorSetNames.RELU)
206
+ relu6 = schema.OperatorsSet(name=schema.OperatorSetNames.RELU6)
207
+ leaky_relu = schema.OperatorsSet(name=schema.OperatorSetNames.LEAKY_RELU)
208
+ prelu = schema.OperatorsSet(name=schema.OperatorSetNames.PRELU)
209
+ add = schema.OperatorsSet(name=schema.OperatorSetNames.ADD)
210
+ sub = schema.OperatorsSet(name=schema.OperatorSetNames.SUB)
211
+ mul = schema.OperatorsSet(name=schema.OperatorSetNames.MUL)
212
+ div = schema.OperatorsSet(name=schema.OperatorSetNames.DIV)
213
+ swish = schema.OperatorsSet(name=schema.OperatorSetNames.SWISH)
214
+ hard_swish = schema.OperatorsSet(name=schema.OperatorSetNames.HARDSWISH)
215
+ sigmoid = schema.OperatorsSet(name=schema.OperatorSetNames.SIGMOID)
216
+ tanh = schema.OperatorsSet(name=schema.OperatorSetNames.TANH)
217
+ hard_tanh = schema.OperatorsSet(name=schema.OperatorSetNames.HARD_TANH)
218
+
219
+ operator_set.extend([conv, conv_transpose, depthwise_conv, fc, relu, relu6, leaky_relu, add, sub, mul, div, prelu, swish,
220
+ hard_swish, sigmoid, tanh, hard_tanh])
221
+ any_relu = schema.OperatorSetConcat(operators_set=[relu, relu6, leaky_relu, hard_tanh])
191
222
  # Combine multiple operators into a single operator to avoid quantization between
192
223
  # them. To do this we define fusing patterns using the OperatorsSets that were created.
193
224
  # To group multiple sets with regard to fusing, an OperatorSetConcat can be created
194
- activations_after_conv_to_fuse = schema.OperatorSetConcat(operators_set=[any_relu, swish, prelu, sigmoid, tanh])
195
- activations_after_fc_to_fuse = schema.OperatorSetConcat(operators_set=[any_relu, swish, sigmoid])
225
+ activations_after_conv_to_fuse = schema.OperatorSetConcat(operators_set=[relu, relu6, leaky_relu, hard_tanh, swish, hard_swish, prelu, sigmoid, tanh])
226
+ conv_types = schema.OperatorSetConcat(operators_set=[conv, conv_transpose, depthwise_conv])
227
+ activations_after_fc_to_fuse = schema.OperatorSetConcat(operators_set=[relu, relu6, leaky_relu, hard_tanh, swish, hard_swish, sigmoid])
196
228
  any_binary = schema.OperatorSetConcat(operators_set=[add, sub, mul, div])
197
229
 
198
230
  # ------------------- #
199
231
  # Fusions
200
232
  # ------------------- #
201
- fusing_patterns.append(schema.Fusing(operator_groups=(conv, activations_after_conv_to_fuse)))
233
+ fusing_patterns.append(schema.Fusing(operator_groups=(conv_types, activations_after_conv_to_fuse)))
202
234
  fusing_patterns.append(schema.Fusing(operator_groups=(fc, activations_after_fc_to_fuse)))
203
235
  fusing_patterns.append(schema.Fusing(operator_groups=(any_binary, any_relu)))
204
236
 
@@ -15,8 +15,12 @@
15
15
  from model_compression_toolkit.verify_packages import FOUND_TORCH, FOUND_TF
16
16
  from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model, generate_tp_model, get_op_quantization_configs
17
17
  if FOUND_TF:
18
- from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_latest
19
- from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_keras import generate_keras_tpc
18
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model as \
19
+ get_keras_tpc_latest
20
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
21
+ get_tpc_model as generate_keras_tpc, get_tpc_model as generate_keras_tpc
20
22
  if FOUND_TORCH:
21
- from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_pytorch import get_pytorch_tpc as get_pytorch_tpc_latest
22
- from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tpc_pytorch import generate_pytorch_tpc
23
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.qnnpack_tpc.v1.tp_model import get_tp_model as \
24
+ get_pytorch_tpc_latest
25
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
26
+ get_tpc_model as generate_pytorch_tpc, get_tpc_model as generate_pytorch_tpc
@@ -148,19 +148,29 @@ def generate_tp_model(default_config: OpQuantizationConfig,
148
148
  operator_set = []
149
149
  fusing_patterns = []
150
150
 
151
- conv = schema.OperatorsSet(name="Conv")
152
- batchnorm = schema.OperatorsSet(name="BatchNorm")
153
- relu = schema.OperatorsSet(name="Relu")
154
- linear = schema.OperatorsSet(name="Linear")
151
+ conv = schema.OperatorsSet(name=schema.OperatorSetNames.CONV)
152
+ conv_depthwise = schema.OperatorsSet(name=schema.OperatorSetNames.DEPTHWISE_CONV)
153
+ conv_transpose = schema.OperatorsSet(name=schema.OperatorSetNames.CONV_TRANSPOSE)
154
+ batchnorm = schema.OperatorsSet(name=schema.OperatorSetNames.BATCH_NORM)
155
+ relu = schema.OperatorsSet(name=schema.OperatorSetNames.RELU)
156
+ relu6 = schema.OperatorsSet(name=schema.OperatorSetNames.RELU6)
157
+
158
+ hard_tanh = schema.OperatorsSet(name=schema.OperatorSetNames.HARD_TANH)
159
+ linear = schema.OperatorsSet(name=schema.OperatorSetNames.FULLY_CONNECTED)
160
+
161
+ operator_set.extend([conv, conv_depthwise, conv_transpose, batchnorm, relu, relu6, hard_tanh, linear])
162
+
163
+ conv_opset_concat = schema.OperatorSetConcat(operators_set=[conv, conv_transpose])
164
+ relu_opset_concat = schema.OperatorSetConcat(operators_set=[relu, relu6, hard_tanh])
155
165
 
156
- operator_set.extend([conv, batchnorm, relu, linear])
157
166
  # ------------------- #
158
167
  # Fusions
159
168
  # ------------------- #
160
- fusing_patterns.append(schema.Fusing(operator_groups=(conv, batchnorm, relu)))
161
- fusing_patterns.append(schema.Fusing(operator_groups=(conv, batchnorm)))
162
- fusing_patterns.append(schema.Fusing(operator_groups=(conv, relu)))
163
- fusing_patterns.append(schema.Fusing(operator_groups=(linear, relu)))
169
+ fusing_patterns.append(schema.Fusing(operator_groups=(conv_opset_concat, batchnorm, relu_opset_concat)))
170
+ fusing_patterns.append(schema.Fusing(operator_groups=(conv_opset_concat, batchnorm)))
171
+ fusing_patterns.append(schema.Fusing(operator_groups=(conv_opset_concat, relu_opset_concat)))
172
+ fusing_patterns.append(schema.Fusing(operator_groups=(linear, relu_opset_concat)))
173
+
164
174
  # Create a TargetPlatformModel and set its default quantization config.
165
175
  # This default configuration will be used for all operations
166
176
  # unless specified otherwise (see OperatorsSet, for example):
@@ -15,8 +15,11 @@
15
15
  from model_compression_toolkit.verify_packages import FOUND_TORCH, FOUND_TF
16
16
  from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tp_model import get_tp_model, generate_tp_model, get_op_quantization_configs
17
17
  if FOUND_TF:
18
- from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_latest
19
- from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_keras import generate_keras_tpc
18
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tp_model import get_keras_tpc as get_keras_tpc_latest
19
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
20
+ get_tpc_model as generate_keras_tpc
20
21
  if FOUND_TORCH:
21
- from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_pytorch import get_pytorch_tpc as get_pytorch_tpc_latest
22
- from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tpc_pytorch import generate_pytorch_tpc
22
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.tflite_tpc.v1.tp_model import \
23
+ get_tp_model as get_pytorch_tpc_latest
24
+ from model_compression_toolkit.target_platform_capabilities.tpc_models.get_target_platform_capabilities import \
25
+ get_tpc_model as generate_pytorch_tpc