mct-nightly 0.0.0__py3-none-any.whl → 1.1.0.02122021-003117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/METADATA +3 -2
  2. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/RECORD +31 -38
  3. model_compression_toolkit/__init__.py +2 -6
  4. model_compression_toolkit/common/base_substitutions.py +1 -0
  5. model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +9 -12
  6. model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +8 -21
  7. model_compression_toolkit/common/collectors/histogram_collector.py +1 -1
  8. model_compression_toolkit/common/graph/base_graph.py +2 -4
  9. model_compression_toolkit/common/graph/graph_matchers.py +3 -1
  10. model_compression_toolkit/common/graph/graph_searches.py +3 -1
  11. model_compression_toolkit/common/mixed_precision/bit_width_setter.py +1 -2
  12. model_compression_toolkit/common/network_editors/node_filters.py +1 -0
  13. model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py +1 -1
  14. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +3 -5
  15. model_compression_toolkit/common/quantization/quantize_graph_weights.py +4 -7
  16. model_compression_toolkit/common/quantization/quantize_node.py +3 -5
  17. model_compression_toolkit/keras/__init__.py +2 -0
  18. model_compression_toolkit/keras/back2framework/model_builder.py +24 -1
  19. model_compression_toolkit/{common → keras/back2framework}/model_collector.py +9 -18
  20. model_compression_toolkit/keras/default_framework_info.py +0 -1
  21. model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +57 -10
  22. model_compression_toolkit/keras/graph_substitutions/substituter.py +171 -0
  23. model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +26 -6
  24. model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +12 -5
  25. model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +3 -4
  26. model_compression_toolkit/keras/quantization_facade.py +524 -188
  27. model_compression_toolkit/keras/reader/connectivity_handler.py +4 -1
  28. model_compression_toolkit/keras/visualization/nn_visualizer.py +1 -2
  29. model_compression_toolkit/common/framework_implementation.py +0 -239
  30. model_compression_toolkit/common/gptq/__init__.py +0 -14
  31. model_compression_toolkit/common/gptq/gptq_config.py +0 -65
  32. model_compression_toolkit/common/model_builder_mode.py +0 -34
  33. model_compression_toolkit/common/post_training_quantization.py +0 -459
  34. model_compression_toolkit/common/substitutions/__init__.py +0 -14
  35. model_compression_toolkit/common/substitutions/apply_substitutions.py +0 -40
  36. model_compression_toolkit/keras/keras_implementation.py +0 -256
  37. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/LICENSE +0 -0
  38. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/WHEEL +0 -0
  39. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.2122021.post3117.dist-info}/top_level.txt +0 -0
@@ -14,12 +14,13 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
+ from collections import namedtuple
18
+
17
19
  import tensorflow as tf
18
20
  from tensorflow.python.keras.engine.node import Node as KerasNode
19
21
  from tensorflow.python.util.object_identity import Reference as TFReference
20
22
  from typing import List, Tuple
21
23
 
22
- from model_compression_toolkit.common.graph.base_graph import OutTensor
23
24
  from model_compression_toolkit.common.graph.node import Node
24
25
  from model_compression_toolkit.keras.reader.common import is_node_an_input_layer
25
26
  from model_compression_toolkit.keras.reader.node_builder import build_node
@@ -27,6 +28,8 @@ from model_compression_toolkit.keras.reader.node_builder import build_node
27
28
  keras = tf.keras
28
29
  layers = keras.layers
29
30
 
31
+ OutTensor = namedtuple('OutTensor', 'node node_out_index')
32
+
30
33
 
31
34
  class ConnectivityHandler(object):
32
35
  """
@@ -20,8 +20,7 @@ from matplotlib.figure import Figure
20
20
 
21
21
  from model_compression_toolkit.common import Graph
22
22
  from model_compression_toolkit.common.similarity_analyzer import compute_cs
23
- from model_compression_toolkit.keras.back2framework.model_builder import model_builder
24
- from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
23
+ from model_compression_toolkit.keras.back2framework.model_builder import model_builder, ModelBuilderMode
25
24
  from model_compression_toolkit.common.graph.node import Node
26
25
 
27
26
 
@@ -1,239 +0,0 @@
1
- # Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- from abc import ABC, abstractmethod
16
- from typing import Callable, Any, List, Tuple
17
-
18
- import numpy as np
19
-
20
- from model_compression_toolkit import common, GradientPTQConfig, MixedPrecisionQuantizationConfig
21
- from model_compression_toolkit.common import Node
22
- from model_compression_toolkit.common.framework_info import FrameworkInfo
23
- from model_compression_toolkit.common.graph.base_graph import Graph
24
- from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
25
- from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
26
- from model_compression_toolkit.common.user_info import UserInformation
27
-
28
-
29
- class FrameworkImplementation(ABC):
30
- """
31
- An abstract class with abstract methods that should be implemented when supporting a new
32
- framework in MCT.
33
- """
34
-
35
- @property
36
- def constants(self):
37
- """
38
-
39
- Returns: Module of the framework constants.
40
-
41
- """
42
- raise Exception(f'{self.__class__.__name__} did not supply a constants module.')
43
-
44
- @abstractmethod
45
- def model_builder(self,
46
- graph: Graph,
47
- mode: ModelBuilderMode,
48
- append2output: List[Any],
49
- fw_info: FrameworkInfo) -> Tuple[Any, UserInformation]:
50
- """
51
- Build a framework model from a graph.
52
- The mode determines how the model should be build. append2output is a list of Nodes
53
- to set as the model outputs.
54
-
55
- Args:
56
- graph: Graph to build the model from it.
57
- mode: Mode for how to build the model.
58
- append2output: List of Nodes to set as the model's outputs.
59
- fw_info: FrameworkInfo object with information about the specific framework's model
60
-
61
- Returns:
62
- A tuple of the model that was built and an UserInformation object.
63
- """
64
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
65
- f'framework\'s model_builder method.')
66
-
67
- @abstractmethod
68
- def shift_negative_correction(self,
69
- graph: Graph,
70
- qc: QuantizationConfig,
71
- fw_info: FrameworkInfo) -> Graph:
72
- """
73
- Apply shift negative correction (SNC) on a graph.
74
-
75
- Args:
76
- graph: Graph to apply SNC on.
77
- qc: Quantization configuration.
78
- fw_info: FrameworkInfo object with information about the specific framework's model.
79
-
80
- Returns:
81
- Graph after SNC.
82
- """
83
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
84
- f'framework\'s apply_shift_negative_correction method.')
85
-
86
- @abstractmethod
87
- def to_numpy(self, tensor: Any) -> np.ndarray:
88
- """
89
- Convert framework's tensor to a Numpy array.
90
- Args:
91
- tensor: Framework's tensor.
92
-
93
- Returns:
94
- Numpy array converted from the input tensor.
95
- """
96
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
97
- f'framework\'s to_numpy method.')
98
-
99
- @abstractmethod
100
- def model_reader(self, model: Any) -> Graph:
101
- """
102
- Convert a framework's model into a graph.
103
- Args:
104
- model: Framework's model.
105
-
106
- Returns:
107
- Graph representing the input model.
108
- """
109
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
110
- f'framework\'s model_reader method.')
111
-
112
- @abstractmethod
113
- def attach_sc_to_node(self, node:Node,
114
- fw_info:FrameworkInfo) -> common.statistics_collector.BaseStatsContainer:
115
- """
116
- Return a statistics collector that should be attached to a node's output
117
- during statistics collection.
118
-
119
- Args:
120
- node: Node to return its collector.
121
- fw_info: FrameworkInfo object with information about the specific framework's model
122
-
123
- Returns:
124
- Statistics collector for the node.
125
- """
126
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
127
- f'framework\'s attach_sc_to_node method.')
128
-
129
- @abstractmethod
130
- def get_substitutions_marking(self) -> List[common.BaseSubstitution]:
131
- """
132
-
133
- Returns: A list of the framework substitutions used for marking
134
- points we fuse.
135
-
136
- """
137
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
138
- f'framework\'s get_substitutions_marking method.')
139
-
140
- @abstractmethod
141
- def get_substitutions_pre_statistics_collection(self) -> List[common.BaseSubstitution]:
142
- """
143
-
144
- Returns: A list of the framework substitutions used before we collect statistics.
145
-
146
- """
147
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
148
- f'framework\'s get_substitutions_pre_statistics_collection method.')
149
-
150
-
151
- @abstractmethod
152
- def get_substitutions_pre_build(self) -> List[common.BaseSubstitution]:
153
- """
154
-
155
- Returns: A list of the framework substitutions used before we build a quantized model.
156
-
157
- """
158
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
159
- f'framework\'s get_substitutions_pre_build method.')
160
-
161
- @abstractmethod
162
- def get_substitutions_post_statistics_collection(self, quant_config:QuantizationConfig) -> List[common.BaseSubstitution]:
163
- """
164
- Return a list of the framework substitutions used after we collect statistics.
165
-
166
- Args:
167
- quant_config: QuantizationConfig to determine which substitutions to return.
168
-
169
- Returns:
170
- A list of the framework substitutions used after we collect statistics.
171
- """
172
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
173
- f'framework\'s get_substitutions_post_statistics_collection method.')
174
-
175
- @abstractmethod
176
- def get_substitutions_channel_equalization(self,
177
- quant_config: QuantizationConfig,
178
- fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
179
- """
180
- Return a list of the framework substitutions used for channel equalization.
181
-
182
- Args:
183
- quant_config: QuantizationConfig to determine which substitutions to return.
184
- fw_info: FrameworkInfo object with information about the specific framework's model.
185
-
186
- Returns:
187
- A list of the framework substitutions used after we collect statistics.
188
- """
189
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
190
- f'framework\'s get_substitutions_channel_equalization method.')
191
-
192
- @abstractmethod
193
- def gptq_training(self,
194
- graph: Graph,
195
- representative_data_gen: Callable,
196
- gptq_config: GradientPTQConfig,
197
- fw_info: FrameworkInfo) -> Graph:
198
- """
199
- Update a graph using GPTQ after minimizing the loss between the float model's output
200
- and the quantized model's outputs.
201
-
202
- Args:
203
- graph: Graph to fine-tune.
204
- representative_data_gen: Dataset to use for inputs of the models.
205
- gptq_config: GradientPTQConfig with configuration for the fine-tuning process.
206
- fw_info: FrameworkInfo object with information about the specific framework's model.
207
-
208
- Returns:
209
- Updated graph after GPTQ.
210
- """
211
-
212
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
213
- f'framework\'s gptq_training method.')
214
-
215
- @abstractmethod
216
- def get_sensitivity_evaluation_fn(self,
217
- graph: Graph,
218
- quant_config: MixedPrecisionQuantizationConfig,
219
- metrics_weights: np.ndarray,
220
- representative_data_gen: Callable,
221
- fw_info: FrameworkInfo) -> Callable:
222
- """
223
- Create and return a function to compute a sensitivity metric for a mixed-precision
224
- configuration (comparing to the float model).
225
-
226
- Args:
227
- graph: Graph to build it's float and mixed-precision models.
228
- quant_config: QuantizationConfig of how the model should be quantized.
229
- metrics_weights: Array of weights to weight the sensitivity among different layers.
230
- representative_data_gen: Dataset to use for retrieving images for the models inputs.
231
- fw_info: FrameworkInfo object with information about the specific framework's model.
232
-
233
- Returns:
234
- A function that computes the metric.
235
- """
236
- raise NotImplemented(f'{self.__class__.__name__} have to implement the '
237
- f'framework\'s get_sensitivity_evaluation_fn method.')
238
-
239
-
@@ -1,14 +0,0 @@
1
- # Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
@@ -1,65 +0,0 @@
1
- # Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from typing import Callable, Any
17
-
18
-
19
- class GradientPTQConfig:
20
- """
21
- Configuration to use for quantization with GradientPTQ (experimental).
22
- """
23
-
24
- def __init__(self,
25
- n_iter: int,
26
- optimizer: Any,
27
- loss: Callable = None,
28
- log_function: Callable = None,
29
- train_bias: bool = True,
30
- representative_data_gen: Callable = None):
31
- """
32
- Initialize a GradientPTQConfig.
33
-
34
- Args:
35
- n_iter (int): Number of iterations to train.
36
- optimizer (OptimizerV2): Optimizer to use.
37
- loss (Callable): the loss to use. should accept 2 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors.
38
- log_function (Callable): Function to log information about the GPTQ process.
39
- train_bias (bool): Whether to update the bias during the training or not.
40
- representative_data_gen (Callable): Dataset generator.
41
-
42
- Examples:
43
- Create a GradientPTQConfig to run for 5 iteration and uses a random dataset generator:
44
-
45
- >>> import numpy as np
46
- >>> def repr_datagen(): return [np.random.random((1,224,224,3))]
47
- >>> gptq_conf = GradientPTQConfig(n_iter=5, representative_data_gen=repr_datagen)
48
-
49
- An optimizer can be passed:
50
-
51
- >>> gptq_conf = GradientPTQConfig(n_iter=5, representative_data_gen=repr_datagen, optimizer=tf.keras.optimizers.Nadam(learning_rate=0.2))
52
-
53
- To disable the biases training, one may set train_bias to False (enabled by default):
54
-
55
- >>> gptq_conf = GradientPTQConfig(n_iter=5, representative_data_gen=repr_datagen, train_bias=False)
56
-
57
- The configuration can then be passed to :func:`~model_compression_toolkit.keras_post_training_quantization`.
58
-
59
- """
60
- self.n_iter = n_iter
61
- self.optimizer = optimizer
62
- self.loss = loss
63
- self.log_function = log_function
64
- self.train_bias = train_bias
65
- self.representative_data_gen = representative_data_gen
@@ -1,34 +0,0 @@
1
- # Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from enum import Enum
17
-
18
-
19
- class ModelBuilderMode(Enum):
20
- """
21
- Mode for building the model back from a graph:
22
- FLOAT - Build model for statistics collection. Model's outputs list contain all output tensors of all nodes
23
- in the graph.
24
- QUANTIZED - Build a quantized model using the nodes' quantization attributes for adding
25
- quantization nodes to the model.
26
- GPTQ - Build a quantized model using the nodes' quantization attributes for wrapping
27
- layers with QuantizeWrapper and output comparing points.
28
- MIXEDPRECISION - Build a quantized model where the layers that their weights should be quantized
29
- can use different quantized weights according to the possible bitwidths of each layer.
30
- """
31
- FLOAT = 0
32
- QUANTIZED = 1
33
- GPTQ = 2
34
- MIXEDPRECISION = 3