mct-nightly 2.2.0.20241201.617__py3-none-any.whl → 2.2.0.20241202.131715__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/RECORD +58 -58
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/constants.py +0 -3
- model_compression_toolkit/core/common/graph/base_node.py +7 -5
- model_compression_toolkit/core/common/graph/functional_node.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +2 -1
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +1 -1
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -1
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +2 -2
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -1
- model_compression_toolkit/metadata.py +14 -5
- model_compression_toolkit/target_platform_capabilities/schema/__init__.py +14 -0
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +11 -0
- model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +37 -0
- model_compression_toolkit/target_platform_capabilities/{target_platform/op_quantization_config.py → schema/v1.py} +377 -24
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +3 -5
- model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +2 -214
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +1 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +6 -10
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +39 -32
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_pytorch.py +3 -5
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +36 -31
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +37 -32
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +39 -32
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +36 -31
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +45 -38
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +37 -32
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +70 -62
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +22 -17
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_keras.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +56 -51
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/fusing.py +0 -85
- model_compression_toolkit/target_platform_capabilities/target_platform/operators.py +0 -87
- model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model_component.py +0 -40
- {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/top_level.txt +0 -0
@@ -15,12 +15,11 @@
|
|
15
15
|
from typing import List, Tuple
|
16
16
|
|
17
17
|
import model_compression_toolkit as mct
|
18
|
+
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
|
18
19
|
from model_compression_toolkit.constants import FLOAT_BITWIDTH
|
19
|
-
from model_compression_toolkit.target_platform_capabilities.constants import BIAS_ATTR, KERNEL_ATTR
|
20
|
-
from model_compression_toolkit.target_platform_capabilities.
|
21
|
-
|
22
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import \
|
23
|
-
QuantizationMethod, AttributeQuantizationConfig
|
20
|
+
from model_compression_toolkit.target_platform_capabilities.constants import BIAS_ATTR, KERNEL_ATTR, TFLITE_TP_MODEL
|
21
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, Signedness, \
|
22
|
+
AttributeQuantizationConfig, OpQuantizationConfig
|
24
23
|
|
25
24
|
tp = mct.target_platform
|
26
25
|
|
@@ -83,7 +82,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
|
|
83
82
|
|
84
83
|
# We define a default config for operations without a kernel attribute.
|
85
84
|
# This is the default config that should be used for non-linear operations.
|
86
|
-
eight_bits_default =
|
85
|
+
eight_bits_default = schema.OpQuantizationConfig(
|
87
86
|
default_weight_attr_config=default_weight_attr_config,
|
88
87
|
attr_weights_configs_mapping={},
|
89
88
|
activation_quantization_method=tp.QuantizationMethod.POWER_OF_TWO,
|
@@ -97,8 +96,8 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
|
|
97
96
|
signedness=Signedness.AUTO)
|
98
97
|
|
99
98
|
# We define an 8-bit config for quantizing linear operations, which include kernel and bias attributes.
|
100
|
-
linear_eight_bits =
|
101
|
-
activation_quantization_method=QuantizationMethod.UNIFORM,
|
99
|
+
linear_eight_bits = schema.OpQuantizationConfig(
|
100
|
+
activation_quantization_method=tp.QuantizationMethod.UNIFORM,
|
102
101
|
default_weight_attr_config=default_weight_attr_config,
|
103
102
|
attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config, BIAS_ATTR: bias_config},
|
104
103
|
activation_n_bits=8,
|
@@ -137,12 +136,18 @@ def generate_tp_model(default_config: OpQuantizationConfig,
|
|
137
136
|
# of possible configurations to consider when quantizing a set of operations (in mixed-precision, for example).
|
138
137
|
# If the QuantizationConfigOptions contains only one configuration,
|
139
138
|
# this configuration will be used for the operation quantization:
|
140
|
-
default_configuration_options =
|
139
|
+
default_configuration_options = schema.QuantizationConfigOptions([default_config])
|
141
140
|
|
142
141
|
# Create a TargetPlatformModel and set its default quantization config.
|
143
142
|
# This default configuration will be used for all operations
|
144
143
|
# unless specified otherwise (see OperatorsSet, for example):
|
145
|
-
generated_tpc =
|
144
|
+
generated_tpc = schema.TargetPlatformModel(
|
145
|
+
default_configuration_options,
|
146
|
+
tpc_minor_version=1,
|
147
|
+
tpc_patch_version=0,
|
148
|
+
tpc_platform_type=TFLITE_TP_MODEL,
|
149
|
+
add_metadata=False,
|
150
|
+
name=name)
|
146
151
|
|
147
152
|
# To start defining the model's components (such as operator sets, and fusing patterns),
|
148
153
|
# use 'with' the TargetPlatformModel instance, and create them as below:
|
@@ -150,52 +155,52 @@ def generate_tp_model(default_config: OpQuantizationConfig,
|
|
150
155
|
# In TFLite, the quantized operator specifications constrain operator quantization
|
151
156
|
# differently. For more details:
|
152
157
|
# https://www.tensorflow.org/lite/performance/quantization_spec#int8_quantized_operator_specifications
|
153
|
-
|
154
|
-
|
155
|
-
|
158
|
+
schema.OperatorsSet("NoQuantization",
|
159
|
+
tp.get_default_quantization_config_options().clone_and_edit(
|
160
|
+
quantization_preserving=True))
|
156
161
|
|
157
162
|
fc_qco = tp.get_default_quantization_config_options()
|
158
|
-
fc =
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
conv2d =
|
178
|
-
kernel =
|
179
|
-
|
180
|
-
relu =
|
181
|
-
elu =
|
182
|
-
activations_to_fuse =
|
183
|
-
|
184
|
-
batch_norm =
|
185
|
-
bias_add =
|
186
|
-
add =
|
187
|
-
squeeze =
|
188
|
-
|
189
|
-
|
163
|
+
fc = schema.OperatorsSet("FullyConnected",
|
164
|
+
fc_qco.clone_and_edit_weight_attribute(weights_per_channel_threshold=False))
|
165
|
+
|
166
|
+
schema.OperatorsSet("L2Normalization",
|
167
|
+
tp.get_default_quantization_config_options().clone_and_edit(
|
168
|
+
fixed_zero_point=0, fixed_scale=1 / 128))
|
169
|
+
schema.OperatorsSet("LogSoftmax",
|
170
|
+
tp.get_default_quantization_config_options().clone_and_edit(
|
171
|
+
fixed_zero_point=127, fixed_scale=16 / 256))
|
172
|
+
schema.OperatorsSet("Tanh",
|
173
|
+
tp.get_default_quantization_config_options().clone_and_edit(
|
174
|
+
fixed_zero_point=0, fixed_scale=1 / 128))
|
175
|
+
schema.OperatorsSet("Softmax",
|
176
|
+
tp.get_default_quantization_config_options().clone_and_edit(
|
177
|
+
fixed_zero_point=-128, fixed_scale=1 / 256))
|
178
|
+
schema.OperatorsSet("Logistic",
|
179
|
+
tp.get_default_quantization_config_options().clone_and_edit(
|
180
|
+
fixed_zero_point=-128, fixed_scale=1 / 256))
|
181
|
+
|
182
|
+
conv2d = schema.OperatorsSet("Conv2d")
|
183
|
+
kernel = schema.OperatorSetConcat(conv2d, fc)
|
184
|
+
|
185
|
+
relu = schema.OperatorsSet("Relu")
|
186
|
+
elu = schema.OperatorsSet("Elu")
|
187
|
+
activations_to_fuse = schema.OperatorSetConcat(relu, elu)
|
188
|
+
|
189
|
+
batch_norm = schema.OperatorsSet("BatchNorm")
|
190
|
+
bias_add = schema.OperatorsSet("BiasAdd")
|
191
|
+
add = schema.OperatorsSet("Add")
|
192
|
+
squeeze = schema.OperatorsSet("Squeeze",
|
193
|
+
qc_options=tp.get_default_quantization_config_options().clone_and_edit(
|
194
|
+
quantization_preserving=True))
|
190
195
|
# ------------------- #
|
191
196
|
# Fusions
|
192
197
|
# ------------------- #
|
193
198
|
# Source: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/remapper
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
199
|
+
schema.Fusing([kernel, bias_add])
|
200
|
+
schema.Fusing([kernel, bias_add, activations_to_fuse])
|
201
|
+
schema.Fusing([conv2d, batch_norm, activations_to_fuse])
|
202
|
+
schema.Fusing([conv2d, squeeze, activations_to_fuse])
|
203
|
+
schema.Fusing([batch_norm, activations_to_fuse])
|
204
|
+
schema.Fusing([batch_norm, add, activations_to_fuse])
|
200
205
|
|
201
206
|
return generated_tpc
|
model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_keras.py
CHANGED
@@ -15,6 +15,7 @@
|
|
15
15
|
import tensorflow as tf
|
16
16
|
from packaging import version
|
17
17
|
|
18
|
+
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
|
18
19
|
from model_compression_toolkit.defaultdict import DefaultDict
|
19
20
|
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, KERAS_KERNEL, BIAS_ATTR, BIAS
|
20
21
|
|
@@ -46,7 +47,7 @@ def get_keras_tpc() -> tp.TargetPlatformCapabilities:
|
|
46
47
|
return generate_keras_tpc(name='tflite_keras', tp_model=tflite_tp_model)
|
47
48
|
|
48
49
|
|
49
|
-
def generate_keras_tpc(name: str, tp_model:
|
50
|
+
def generate_keras_tpc(name: str, tp_model: schema.TargetPlatformModel):
|
50
51
|
"""
|
51
52
|
Generates a TargetPlatformCapabilities object with default operation sets to layers mapping.
|
52
53
|
|
@@ -57,9 +58,7 @@ def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel):
|
|
57
58
|
Returns: a TargetPlatformCapabilities object for the given TargetPlatformModel.
|
58
59
|
"""
|
59
60
|
|
60
|
-
keras_tpc = tp.TargetPlatformCapabilities(tp_model
|
61
|
-
name=name,
|
62
|
-
version=TPC_VERSION)
|
61
|
+
keras_tpc = tp.TargetPlatformCapabilities(tp_model)
|
63
62
|
|
64
63
|
with keras_tpc:
|
65
64
|
tp.OperationsSetToLayers("NoQuantization", [AveragePooling2D,
|
model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_pytorch.py
CHANGED
@@ -16,6 +16,7 @@ import torch
|
|
16
16
|
from torch.nn import AvgPool2d, MaxPool2d
|
17
17
|
from torch.nn.functional import avg_pool2d, max_pool2d, interpolate
|
18
18
|
|
19
|
+
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
|
19
20
|
from model_compression_toolkit.defaultdict import DefaultDict
|
20
21
|
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, PYTORCH_KERNEL, BIAS_ATTR, \
|
21
22
|
BIAS
|
@@ -37,7 +38,7 @@ def get_pytorch_tpc() -> tp.TargetPlatformCapabilities:
|
|
37
38
|
return generate_pytorch_tpc(name='tflite_torch', tp_model=tflite_tp_model)
|
38
39
|
|
39
40
|
|
40
|
-
def generate_pytorch_tpc(name: str, tp_model:
|
41
|
+
def generate_pytorch_tpc(name: str, tp_model: schema.TargetPlatformModel):
|
41
42
|
"""
|
42
43
|
Generates a TargetPlatformCapabilities object with default operation sets to layers mapping.
|
43
44
|
Args:
|
@@ -46,9 +47,7 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
|
|
46
47
|
Returns: a TargetPlatformCapabilities object for the given TargetPlatformModel.
|
47
48
|
"""
|
48
49
|
|
49
|
-
pytorch_tpc = tp.TargetPlatformCapabilities(tp_model
|
50
|
-
name=name,
|
51
|
-
version=TPC_VERSION)
|
50
|
+
pytorch_tpc = tp.TargetPlatformCapabilities(tp_model)
|
52
51
|
|
53
52
|
with pytorch_tpc:
|
54
53
|
tp.OperationsSetToLayers("NoQuantization", [AvgPool2d,
|
@@ -1,85 +0,0 @@
|
|
1
|
-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
from typing import Any, List, Union
|
18
|
-
|
19
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorSetConcat, \
|
20
|
-
OperatorsSet
|
21
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import TargetPlatformModelComponent
|
22
|
-
|
23
|
-
|
24
|
-
class Fusing(TargetPlatformModelComponent):
|
25
|
-
"""
|
26
|
-
Fusing defines a list of operators that should be combined and treated as a single operator,
|
27
|
-
hence no quantization is applied between them.
|
28
|
-
"""
|
29
|
-
|
30
|
-
def __init__(self,
|
31
|
-
operator_groups_list: List[Union[OperatorsSet, OperatorSetConcat]],
|
32
|
-
name: str = None):
|
33
|
-
"""
|
34
|
-
Args:
|
35
|
-
operator_groups_list (List[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups, each being either an OperatorSetConcat or an OperatorsSet.
|
36
|
-
name (str): The name for the Fusing instance. If not provided, it's generated from the operator groups' names.
|
37
|
-
"""
|
38
|
-
assert isinstance(operator_groups_list,
|
39
|
-
list), f'List of operator groups should be of type list but is {type(operator_groups_list)}'
|
40
|
-
assert len(operator_groups_list) >= 2, f'Fusing can not be created for a single operators group'
|
41
|
-
|
42
|
-
# Generate a name from the operator groups if no name is provided
|
43
|
-
if name is None:
|
44
|
-
name = '_'.join([x.name for x in operator_groups_list])
|
45
|
-
|
46
|
-
super().__init__(name)
|
47
|
-
self.operator_groups_list = operator_groups_list
|
48
|
-
|
49
|
-
def contains(self, other: Any) -> bool:
|
50
|
-
"""
|
51
|
-
Determines if the current Fusing instance contains another Fusing instance.
|
52
|
-
|
53
|
-
Args:
|
54
|
-
other: The other Fusing instance to check against.
|
55
|
-
|
56
|
-
Returns:
|
57
|
-
A boolean indicating whether the other instance is contained within this one.
|
58
|
-
"""
|
59
|
-
if not isinstance(other, Fusing):
|
60
|
-
return False
|
61
|
-
|
62
|
-
# Check for containment by comparing operator groups
|
63
|
-
for i in range(len(self.operator_groups_list) - len(other.operator_groups_list) + 1):
|
64
|
-
for j in range(len(other.operator_groups_list)):
|
65
|
-
if self.operator_groups_list[i + j] != other.operator_groups_list[j] and not (
|
66
|
-
isinstance(self.operator_groups_list[i + j], OperatorSetConcat) and (
|
67
|
-
other.operator_groups_list[j] in self.operator_groups_list[i + j].op_set_list)):
|
68
|
-
break
|
69
|
-
else:
|
70
|
-
# If all checks pass, the other Fusing instance is contained
|
71
|
-
return True
|
72
|
-
# Other Fusing instance is not contained
|
73
|
-
return False
|
74
|
-
|
75
|
-
def get_info(self):
|
76
|
-
"""
|
77
|
-
Retrieves information about the Fusing instance, including its name and the sequence of operator groups.
|
78
|
-
|
79
|
-
Returns:
|
80
|
-
A dictionary with the Fusing instance's name as the key and the sequence of operator groups as the value,
|
81
|
-
or just the sequence of operator groups if no name is set.
|
82
|
-
"""
|
83
|
-
if self.name is not None:
|
84
|
-
return {self.name: ' -> '.join([x.name for x in self.operator_groups_list])}
|
85
|
-
return ' -> '.join([x.name for x in self.operator_groups_list])
|
@@ -1,87 +0,0 @@
|
|
1
|
-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
from typing import Dict, Any
|
16
|
-
|
17
|
-
from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST
|
18
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import TargetPlatformModelComponent
|
19
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import _current_tp_model
|
20
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import QuantizationConfigOptions
|
21
|
-
|
22
|
-
|
23
|
-
class OperatorsSetBase(TargetPlatformModelComponent):
|
24
|
-
"""
|
25
|
-
Base class to represent a set of operators.
|
26
|
-
"""
|
27
|
-
def __init__(self, name: str):
|
28
|
-
"""
|
29
|
-
|
30
|
-
Args:
|
31
|
-
name: Name of OperatorsSet.
|
32
|
-
"""
|
33
|
-
super().__init__(name=name)
|
34
|
-
|
35
|
-
|
36
|
-
class OperatorsSet(OperatorsSetBase):
|
37
|
-
def __init__(self,
|
38
|
-
name: str,
|
39
|
-
qc_options: QuantizationConfigOptions = None):
|
40
|
-
"""
|
41
|
-
Set of operators that are represented by a unique label.
|
42
|
-
|
43
|
-
Args:
|
44
|
-
name (str): Set's label (must be unique in a TargetPlatformModel).
|
45
|
-
qc_options (QuantizationConfigOptions): Configuration options to use for this set of operations.
|
46
|
-
"""
|
47
|
-
|
48
|
-
super().__init__(name)
|
49
|
-
self.qc_options = qc_options
|
50
|
-
is_fusing_set = qc_options is None
|
51
|
-
self.is_default = _current_tp_model.get().default_qco == self.qc_options or is_fusing_set
|
52
|
-
|
53
|
-
|
54
|
-
def get_info(self) -> Dict[str,Any]:
|
55
|
-
"""
|
56
|
-
|
57
|
-
Returns: Info about the set as a dictionary.
|
58
|
-
|
59
|
-
"""
|
60
|
-
return {"name": self.name,
|
61
|
-
"is_default_qc": self.is_default}
|
62
|
-
|
63
|
-
|
64
|
-
class OperatorSetConcat(OperatorsSetBase):
|
65
|
-
"""
|
66
|
-
Concatenate a list of operator sets to treat them similarly in different places (like fusing).
|
67
|
-
"""
|
68
|
-
def __init__(self, *opsets: OperatorsSet):
|
69
|
-
"""
|
70
|
-
Group a list of operation sets.
|
71
|
-
|
72
|
-
Args:
|
73
|
-
*opsets (OperatorsSet): List of operator sets to group.
|
74
|
-
"""
|
75
|
-
name = "_".join([a.name for a in opsets])
|
76
|
-
super().__init__(name=name)
|
77
|
-
self.op_set_list = opsets
|
78
|
-
self.qc_options = None # Concat have no qc options
|
79
|
-
|
80
|
-
def get_info(self) -> Dict[str,Any]:
|
81
|
-
"""
|
82
|
-
|
83
|
-
Returns: Info about the sets group as a dictionary.
|
84
|
-
|
85
|
-
"""
|
86
|
-
return {"name": self.name,
|
87
|
-
OPS_SET_LIST: [s.name for s in self.op_set_list]}
|
@@ -1,40 +0,0 @@
|
|
1
|
-
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
from typing import Any, Dict
|
16
|
-
|
17
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import _current_tp_model
|
18
|
-
|
19
|
-
|
20
|
-
class TargetPlatformModelComponent:
|
21
|
-
"""
|
22
|
-
Component of TargetPlatformModel (Fusing, OperatorsSet, etc.)
|
23
|
-
"""
|
24
|
-
def __init__(self, name: str):
|
25
|
-
"""
|
26
|
-
|
27
|
-
Args:
|
28
|
-
name: Name of component.
|
29
|
-
"""
|
30
|
-
self.name = name
|
31
|
-
_current_tp_model.get().append_component(self)
|
32
|
-
|
33
|
-
def get_info(self) -> Dict[str, Any]:
|
34
|
-
"""
|
35
|
-
|
36
|
-
Returns: Get information about the component to display (return an empty dictionary.
|
37
|
-
the actual component should fill it with info).
|
38
|
-
|
39
|
-
"""
|
40
|
-
return {}
|
{mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/LICENSE.md
RENAMED
File without changes
|
{mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/WHEEL
RENAMED
File without changes
|
File without changes
|