mct-nightly 2.2.0.20241201.617__py3-none-any.whl → 2.2.0.20241202.131715__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/RECORD +58 -58
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/constants.py +0 -3
  5. model_compression_toolkit/core/common/graph/base_node.py +7 -5
  6. model_compression_toolkit/core/common/graph/functional_node.py +1 -1
  7. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -2
  8. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -2
  9. model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
  10. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +2 -1
  11. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +1 -1
  12. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -1
  13. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
  14. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
  15. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +2 -2
  16. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -1
  17. model_compression_toolkit/metadata.py +14 -5
  18. model_compression_toolkit/target_platform_capabilities/schema/__init__.py +14 -0
  19. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +11 -0
  20. model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +37 -0
  21. model_compression_toolkit/target_platform_capabilities/{target_platform/op_quantization_config.py → schema/v1.py} +377 -24
  22. model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +3 -5
  23. model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +2 -214
  24. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +1 -2
  25. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +6 -10
  26. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +39 -32
  27. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_keras.py +3 -2
  28. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_pytorch.py +3 -5
  29. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +36 -31
  30. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_keras.py +3 -2
  31. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_pytorch.py +3 -4
  32. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +37 -32
  33. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_keras.py +3 -2
  34. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_pytorch.py +3 -4
  35. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +39 -32
  36. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +3 -2
  37. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_pytorch.py +3 -4
  38. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +36 -31
  39. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +3 -2
  40. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py +3 -4
  41. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +45 -38
  42. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +3 -2
  43. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +3 -4
  44. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +37 -32
  45. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +3 -2
  46. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +3 -4
  47. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +70 -62
  48. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py +3 -2
  49. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +3 -4
  50. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +22 -17
  51. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_keras.py +3 -4
  52. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +3 -4
  53. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +56 -51
  54. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -4
  55. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -4
  56. model_compression_toolkit/target_platform_capabilities/target_platform/fusing.py +0 -85
  57. model_compression_toolkit/target_platform_capabilities/target_platform/operators.py +0 -87
  58. model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model_component.py +0 -40
  59. {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/LICENSE.md +0 -0
  60. {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/WHEEL +0 -0
  61. {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/top_level.txt +0 -0
model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py

@@ -15,12 +15,11 @@
 from typing import List, Tuple

 import model_compression_toolkit as mct
+import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
 from model_compression_toolkit.constants import FLOAT_BITWIDTH
-from model_compression_toolkit.target_platform_capabilities.constants import BIAS_ATTR, KERNEL_ATTR
-from model_compression_toolkit.target_platform_capabilities.target_platform import OpQuantizationConfig, \
-    TargetPlatformModel, Signedness
-from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import \
-    QuantizationMethod, AttributeQuantizationConfig
+from model_compression_toolkit.target_platform_capabilities.constants import BIAS_ATTR, KERNEL_ATTR, TFLITE_TP_MODEL
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, Signedness, \
+    AttributeQuantizationConfig, OpQuantizationConfig

 tp = mct.target_platform

@@ -83,7 +82,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza

     # We define a default config for operation without kernel attribute.
     # This is the default config that should be used for non-linear operations.
-    eight_bits_default = tp.OpQuantizationConfig(
+    eight_bits_default = schema.OpQuantizationConfig(
         default_weight_attr_config=default_weight_attr_config,
         attr_weights_configs_mapping={},
         activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
@@ -97,8 +96,8 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
         signedness=Signedness.AUTO)

     # We define an 8-bit config for linear operations quantization, that include a kernel and bias attributes.
-    linear_eight_bits = tp.OpQuantizationConfig(
-        activation_quantization_method=QuantizationMethod.UNIFORM,
+    linear_eight_bits = schema.OpQuantizationConfig(
+        activation_quantization_method=tp.QuantizationMethod.UNIFORM,
         default_weight_attr_config=default_weight_attr_config,
         attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config},
         activation_n_bits=8,
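Note on the change above: config classes such as OpQuantizationConfig now come from the schema module, while the quantizer enums are still reached through mct.target_platform. A minimal sketch of that split, assuming the post-change nightly build from this diff is installed:

    # Minimal sketch (assumption: running against the 2.2.0.20241202.* build shown in this diff).
    import model_compression_toolkit as mct
    import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema

    tp = mct.target_platform

    # Quantizer method enums are still reached through mct.target_platform ...
    print(tp.QuantizationMethod.UNIFORM)
    # ... while the config/model classes are now reached through the schema module.
    print(schema.OpQuantizationConfig)
    print(schema.TargetPlatformModel)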
@@ -137,12 +136,18 @@ def generate_tp_model(default_config: OpQuantizationConfig,
     # of possible configurations to consider when quantizing a set of operations (in mixed-precision, for example).
     # If the QuantizationConfigOptions contains only one configuration,
     # this configuration will be used for the operation quantization:
-    default_configuration_options = tp.QuantizationConfigOptions([default_config])
+    default_configuration_options = schema.QuantizationConfigOptions([default_config])

     # Create a TargetPlatformModel and set its default quantization config.
     # This default configuration will be used for all operations
     # unless specified otherwise (see OperatorsSet, for example):
-    generated_tpc = tp.TargetPlatformModel(default_configuration_options, name=name)
+    generated_tpc = schema.TargetPlatformModel(
+        default_configuration_options,
+        tpc_minor_version=1,
+        tpc_patch_version=0,
+        tpc_platform_type=TFLITE_TP_MODEL,
+        add_metadata=False,
+        name=name)

     # To start defining the model's components (such as operator sets, and fusing patterns),
     # use 'with' the TargetPlatformModel instance, and create them as below:
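A hedged sketch of the constructor shape introduced above. The helper name build_versioned_tp_model is illustrative, and default_config is assumed to be an OpQuantizationConfig built as in get_op_quantization_configs():

    import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
    from model_compression_toolkit.target_platform_capabilities.constants import TFLITE_TP_MODEL

    def build_versioned_tp_model(default_config, name='example_tflite_tpc'):
        # Wrap the single default OpQuantizationConfig, exactly as the diff does.
        options = schema.QuantizationConfigOptions([default_config])
        # The schema constructor now carries explicit TPC version/platform metadata.
        return schema.TargetPlatformModel(options,
                                          tpc_minor_version=1,
                                          tpc_patch_version=0,
                                          tpc_platform_type=TFLITE_TP_MODEL,
                                          add_metadata=False,
                                          name=name)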
@@ -150,52 +155,52 @@ def generate_tp_model(default_config: OpQuantizationConfig,
         # In TFLite, the quantized operator specifications constraint operators quantization
         # differently. For more details:
         # https://www.tensorflow.org/lite/performance/quantization_spec#int8_quantized_operator_specifications
-        tp.OperatorsSet("NoQuantization",
-                        tp.get_default_quantization_config_options().clone_and_edit(
-                            quantization_preserving=True))
+        schema.OperatorsSet("NoQuantization",
+                            tp.get_default_quantization_config_options().clone_and_edit(
+                                quantization_preserving=True))

         fc_qco = tp.get_default_quantization_config_options()
-        fc = tp.OperatorsSet("FullyConnected",
-                             fc_qco.clone_and_edit_weight_attribute(weights_per_channel_threshold=False))
-
-        tp.OperatorsSet("L2Normalization",
-                        tp.get_default_quantization_config_options().clone_and_edit(
-                            fixed_zero_point=0, fixed_scale=1 / 128))
-        tp.OperatorsSet("LogSoftmax",
-                        tp.get_default_quantization_config_options().clone_and_edit(
-                            fixed_zero_point=127, fixed_scale=16 / 256))
-        tp.OperatorsSet("Tanh",
-                        tp.get_default_quantization_config_options().clone_and_edit(
-                            fixed_zero_point=0, fixed_scale=1 / 128))
-        tp.OperatorsSet("Softmax",
-                        tp.get_default_quantization_config_options().clone_and_edit(
-                            fixed_zero_point=-128, fixed_scale=1 / 256))
-        tp.OperatorsSet("Logistic",
-                        tp.get_default_quantization_config_options().clone_and_edit(
-                            fixed_zero_point=-128, fixed_scale=1 / 256))
-
-        conv2d = tp.OperatorsSet("Conv2d")
-        kernel = tp.OperatorSetConcat(conv2d, fc)
-
-        relu = tp.OperatorsSet("Relu")
-        elu = tp.OperatorsSet("Elu")
-        activations_to_fuse = tp.OperatorSetConcat(relu, elu)
-
-        batch_norm = tp.OperatorsSet("BatchNorm")
-        bias_add = tp.OperatorsSet("BiasAdd")
-        add = tp.OperatorsSet("Add")
-        squeeze = tp.OperatorsSet("Squeeze",
-                                  qc_options=tp.get_default_quantization_config_options().clone_and_edit(
-                                      quantization_preserving=True))
+        fc = schema.OperatorsSet("FullyConnected",
+                                 fc_qco.clone_and_edit_weight_attribute(weights_per_channel_threshold=False))
+
+        schema.OperatorsSet("L2Normalization",
+                            tp.get_default_quantization_config_options().clone_and_edit(
+                                fixed_zero_point=0, fixed_scale=1 / 128))
+        schema.OperatorsSet("LogSoftmax",
+                            tp.get_default_quantization_config_options().clone_and_edit(
+                                fixed_zero_point=127, fixed_scale=16 / 256))
+        schema.OperatorsSet("Tanh",
+                            tp.get_default_quantization_config_options().clone_and_edit(
+                                fixed_zero_point=0, fixed_scale=1 / 128))
+        schema.OperatorsSet("Softmax",
+                            tp.get_default_quantization_config_options().clone_and_edit(
+                                fixed_zero_point=-128, fixed_scale=1 / 256))
+        schema.OperatorsSet("Logistic",
+                            tp.get_default_quantization_config_options().clone_and_edit(
+                                fixed_zero_point=-128, fixed_scale=1 / 256))
+
+        conv2d = schema.OperatorsSet("Conv2d")
+        kernel = schema.OperatorSetConcat(conv2d, fc)
+
+        relu = schema.OperatorsSet("Relu")
+        elu = schema.OperatorsSet("Elu")
+        activations_to_fuse = schema.OperatorSetConcat(relu, elu)
+
+        batch_norm = schema.OperatorsSet("BatchNorm")
+        bias_add = schema.OperatorsSet("BiasAdd")
+        add = schema.OperatorsSet("Add")
+        squeeze = schema.OperatorsSet("Squeeze",
+                                      qc_options=tp.get_default_quantization_config_options().clone_and_edit(
+                                          quantization_preserving=True))
         # ------------------- #
         # Fusions
         # ------------------- #
         # Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/remapper
-        tp.Fusing([kernel, bias_add])
-        tp.Fusing([kernel, bias_add, activations_to_fuse])
-        tp.Fusing([conv2d, batch_norm, activations_to_fuse])
-        tp.Fusing([conv2d, squeeze, activations_to_fuse])
-        tp.Fusing([batch_norm, activations_to_fuse])
-        tp.Fusing([batch_norm, add, activations_to_fuse])
+        schema.Fusing([kernel, bias_add])
+        schema.Fusing([kernel, bias_add, activations_to_fuse])
+        schema.Fusing([conv2d, batch_norm, activations_to_fuse])
+        schema.Fusing([conv2d, squeeze, activations_to_fuse])
+        schema.Fusing([batch_norm, activations_to_fuse])
+        schema.Fusing([batch_norm, add, activations_to_fuse])

     return generated_tpc
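For readability, a condensed sketch of the declaration pattern shown above, using the same schema classes; this is an illustrative subset only, and generated_tpc is assumed to be the schema.TargetPlatformModel built earlier:

    import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema

    def add_example_fusing_patterns(generated_tpc):
        # 'generated_tpc' is assumed to be a schema.TargetPlatformModel, as in generate_tp_model().
        with generated_tpc:
            conv2d = schema.OperatorsSet("Conv2d")
            batch_norm = schema.OperatorsSet("BatchNorm")
            relu = schema.OperatorsSet("Relu")
            elu = schema.OperatorsSet("Elu")
            # Concatenate sets so they can be referenced together in fusing patterns.
            activations_to_fuse = schema.OperatorSetConcat(relu, elu)
            # No quantization is applied between the operators of a fusing pattern.
            schema.Fusing([conv2d, batch_norm, activations_to_fuse])
        return generated_tpc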
model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_keras.py

@@ -15,6 +15,7 @@
 import tensorflow as tf
 from packaging import version

+import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
 from model_compression_toolkit.defaultdict import DefaultDict
 from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, KERAS_KERNEL, BIAS_ATTR, BIAS

@@ -46,7 +47,7 @@ def get_keras_tpc() -> tp.TargetPlatformCapabilities:
     return generate_keras_tpc(name='tflite_keras', tp_model=tflite_tp_model)


-def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel):
+def generate_keras_tpc(name: str, tp_model: schema.TargetPlatformModel):
     """
     Generates a TargetPlatformCapabilities object with default operation sets to layers mapping.

@@ -57,9 +58,7 @@ def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel):
     Returns: a TargetPlatformCapabilities object for the given TargetPlatformModel.
     """

-    keras_tpc = tp.TargetPlatformCapabilities(tp_model,
-                                              name=name,
-                                              version=TPC_VERSION)
+    keras_tpc = tp.TargetPlatformCapabilities(tp_model)

     with keras_tpc:
         tp.OperationsSetToLayers("NoQuantization", [AveragePooling2D,
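A sketch of the slimmer attachment call after this change; attach_to_keras is an illustrative helper, and the layer list is deliberately shortened, standing in for the full mapping in tpc_keras.py:

    from tensorflow.keras.layers import AveragePooling2D

    import model_compression_toolkit as mct

    tp = mct.target_platform

    def attach_to_keras(tflite_tp_model):
        # The 'name' and 'version' keyword arguments were dropped from this constructor.
        keras_tpc = tp.TargetPlatformCapabilities(tflite_tp_model)
        with keras_tpc:
            # Layer list shortened for illustration; the real mapping covers more layers.
            tp.OperationsSetToLayers("NoQuantization", [AveragePooling2D])
        return keras_tpc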
model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_pytorch.py

@@ -16,6 +16,7 @@ import torch
 from torch.nn import AvgPool2d, MaxPool2d
 from torch.nn.functional import avg_pool2d, max_pool2d, interpolate

+import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
 from model_compression_toolkit.defaultdict import DefaultDict
 from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, PYTORCH_KERNEL, BIAS_ATTR, \
     BIAS

@@ -37,7 +38,7 @@ def get_pytorch_tpc() -> tp.TargetPlatformCapabilities:
     return generate_pytorch_tpc(name='tflite_torch', tp_model=tflite_tp_model)


-def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
+def generate_pytorch_tpc(name: str, tp_model: schema.TargetPlatformModel):
     """
     Generates a TargetPlatformCapabilities object with default operation sets to layers mapping.
     Args:

@@ -46,9 +47,7 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
     Returns: a TargetPlatformCapabilities object for the given TargetPlatformModel.
     """

-    pytorch_tpc = tp.TargetPlatformCapabilities(tp_model,
-                                                name=name,
-                                                version=TPC_VERSION)
+    pytorch_tpc = tp.TargetPlatformCapabilities(tp_model)

     with pytorch_tpc:
         tp.OperationsSetToLayers("NoQuantization", [AvgPool2d,
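The PyTorch side follows the same pattern; again an illustrative helper with a shortened layer list:

    from torch.nn import AvgPool2d

    import model_compression_toolkit as mct

    tp = mct.target_platform

    def attach_to_pytorch(tflite_tp_model):
        # Same simplification as on the Keras side: only the TargetPlatformModel is passed in.
        pytorch_tpc = tp.TargetPlatformCapabilities(tflite_tp_model)
        with pytorch_tpc:
            # Layer list shortened for illustration.
            tp.OperationsSetToLayers("NoQuantization", [AvgPool2d])
        return pytorch_tpc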
model_compression_toolkit/target_platform_capabilities/target_platform/fusing.py (file removed)

@@ -1,85 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-
-from typing import Any, List, Union
-
-from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorSetConcat, \
-    OperatorsSet
-from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import TargetPlatformModelComponent
-
-
-class Fusing(TargetPlatformModelComponent):
-    """
-    Fusing defines a list of operators that should be combined and treated as a single operator,
-    hence no quantization is applied between them.
-    """
-
-    def __init__(self,
-                 operator_groups_list: List[Union[OperatorsSet, OperatorSetConcat]],
-                 name: str = None):
-        """
-        Args:
-            operator_groups_list (List[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups, each being either an OperatorSetConcat or an OperatorsSet.
-            name (str): The name for the Fusing instance. If not provided, it's generated from the operator groups' names.
-        """
-        assert isinstance(operator_groups_list,
-                          list), f'List of operator groups should be of type list but is {type(operator_groups_list)}'
-        assert len(operator_groups_list) >= 2, f'Fusing can not be created for a single operators group'
-
-        # Generate a name from the operator groups if no name is provided
-        if name is None:
-            name = '_'.join([x.name for x in operator_groups_list])
-
-        super().__init__(name)
-        self.operator_groups_list = operator_groups_list
-
-    def contains(self, other: Any) -> bool:
-        """
-        Determines if the current Fusing instance contains another Fusing instance.
-
-        Args:
-            other: The other Fusing instance to check against.
-
-        Returns:
-            A boolean indicating whether the other instance is contained within this one.
-        """
-        if not isinstance(other, Fusing):
-            return False
-
-        # Check for containment by comparing operator groups
-        for i in range(len(self.operator_groups_list) - len(other.operator_groups_list) + 1):
-            for j in range(len(other.operator_groups_list)):
-                if self.operator_groups_list[i + j] != other.operator_groups_list[j] and not (
-                        isinstance(self.operator_groups_list[i + j], OperatorSetConcat) and (
-                        other.operator_groups_list[j] in self.operator_groups_list[i + j].op_set_list)):
-                    break
-            else:
-                # If all checks pass, the other Fusing instance is contained
-                return True
-        # Other Fusing instance is not contained
-        return False
-
-    def get_info(self):
-        """
-        Retrieves information about the Fusing instance, including its name and the sequence of operator groups.
-
-        Returns:
-            A dictionary with the Fusing instance's name as the key and the sequence of operator groups as the value,
-            or just the sequence of operator groups if no name is set.
-        """
-        if self.name is not None:
-            return {self.name: ' -> '.join([x.name for x in self.operator_groups_list])}
-        return ' -> '.join([x.name for x in self.operator_groups_list])
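The contains() method removed above is easy to misread because of the for/else construct. Here is a standalone sketch of the same sliding-window idea over plain names (a hypothetical helper with no MCT imports), documenting the behavior that moves into the schema package:

    def sequence_contains(outer, inner):
        """Return True if 'inner' appears as a contiguous sub-sequence of 'outer'.

        Each element of 'outer' may be a single name or a set of names; a set
        position matches if it includes the corresponding 'inner' name, mirroring
        how an OperatorSetConcat position matched any of its member sets.
        """
        for i in range(len(outer) - len(inner) + 1):
            for j, wanted in enumerate(inner):
                candidate = outer[i + j]
                matches = wanted in candidate if isinstance(candidate, set) else candidate == wanted
                if not matches:
                    break
            else:
                # Every position matched: 'inner' is contained starting at offset i.
                return True
        return False

    # Example: Conv2d -> BatchNorm -> {Relu, Elu} contains Conv2d -> BatchNorm -> Relu.
    assert sequence_contains(["Conv2d", "BatchNorm", {"Relu", "Elu"}], ["Conv2d", "BatchNorm", "Relu"])
    assert not sequence_contains(["Conv2d", "BatchNorm"], ["BatchNorm", "Relu"])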
model_compression_toolkit/target_platform_capabilities/target_platform/operators.py (file removed)

@@ -1,87 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from typing import Dict, Any
-
-from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST
-from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import TargetPlatformModelComponent
-from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import _current_tp_model
-from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import QuantizationConfigOptions
-
-
-class OperatorsSetBase(TargetPlatformModelComponent):
-    """
-    Base class to represent a set of operators.
-    """
-    def __init__(self, name: str):
-        """
-
-        Args:
-            name: Name of OperatorsSet.
-        """
-        super().__init__(name=name)
-
-
-class OperatorsSet(OperatorsSetBase):
-    def __init__(self,
-                 name: str,
-                 qc_options: QuantizationConfigOptions = None):
-        """
-        Set of operators that are represented by a unique label.
-
-        Args:
-            name (str): Set's label (must be unique in a TargetPlatformModel).
-            qc_options (QuantizationConfigOptions): Configuration options to use for this set of operations.
-        """
-
-        super().__init__(name)
-        self.qc_options = qc_options
-        is_fusing_set = qc_options is None
-        self.is_default = _current_tp_model.get().default_qco == self.qc_options or is_fusing_set
-
-
-    def get_info(self) -> Dict[str,Any]:
-        """
-
-        Returns: Info about the set as a dictionary.
-
-        """
-        return {"name": self.name,
-                "is_default_qc": self.is_default}
-
-
-class OperatorSetConcat(OperatorsSetBase):
-    """
-    Concatenate a list of operator sets to treat them similarly in different places (like fusing).
-    """
-    def __init__(self, *opsets: OperatorsSet):
-        """
-        Group a list of operation sets.
-
-        Args:
-            *opsets (OperatorsSet): List of operator sets to group.
-        """
-        name = "_".join([a.name for a in opsets])
-        super().__init__(name=name)
-        self.op_set_list = opsets
-        self.qc_options = None  # Concat have no qc options
-
-    def get_info(self) -> Dict[str,Any]:
-        """
-
-        Returns: Info about the sets group as a dictionary.
-
-        """
-        return {"name": self.name,
-                OPS_SET_LIST: [s.name for s in self.op_set_list]}
model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model_component.py (file removed)

@@ -1,40 +0,0 @@
-# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from typing import Any, Dict
-
-from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import _current_tp_model
-
-
-class TargetPlatformModelComponent:
-    """
-    Component of TargetPlatformModel (Fusing, OperatorsSet, etc.)
-    """
-    def __init__(self, name: str):
-        """
-
-        Args:
-            name: Name of component.
-        """
-        self.name = name
-        _current_tp_model.get().append_component(self)
-
-    def get_info(self) -> Dict[str, Any]:
-        """
-
-        Returns: Get information about the component to display (return an empty dictionary.
-        the actual component should fill it with info).
-
-        """
-        return {}
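The removed base class registered each component with the currently active TargetPlatformModel at construction time, which is what allowed bare OperatorsSet(...) calls inside a 'with model:' block. A minimal standalone sketch of that registration-on-init pattern, with hypothetical names and no MCT imports (not the MCT implementation itself):

    class _CurrentModel:
        """Tiny stand-in for the _current_tp_model holder used by the deleted code."""
        def __init__(self):
            self._model = None

        def set(self, model):
            self._model = model

        def reset(self):
            self._model = None

        def get(self):
            if self._model is None:
                raise RuntimeError("No model is currently active; use 'with model:'.")
            return self._model

    _current_model = _CurrentModel()

    class Model:
        def __init__(self):
            self.components = []

        def append_component(self, component):
            self.components.append(component)

        def __enter__(self):
            _current_model.set(self)
            return self

        def __exit__(self, *exc):
            _current_model.reset()
            return False

    class Component:
        def __init__(self, name):
            self.name = name
            # Constructing a component registers it with whichever model is active.
            _current_model.get().append_component(self)

    m = Model()
    with m:
        Component("Conv2d")
        Component("Relu")
    assert [c.name for c in m.components] == ["Conv2d", "Relu"]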