mct-nightly 0.0.0__py3-none-any.whl → 1.1.0.01122021-003325__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/METADATA +3 -2
  2. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/RECORD +31 -38
  3. model_compression_toolkit/__init__.py +2 -6
  4. model_compression_toolkit/common/base_substitutions.py +1 -0
  5. model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +9 -12
  6. model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +8 -21
  7. model_compression_toolkit/common/collectors/histogram_collector.py +1 -1
  8. model_compression_toolkit/common/graph/base_graph.py +2 -4
  9. model_compression_toolkit/common/graph/graph_matchers.py +3 -1
  10. model_compression_toolkit/common/graph/graph_searches.py +3 -1
  11. model_compression_toolkit/common/mixed_precision/bit_width_setter.py +1 -2
  12. model_compression_toolkit/common/network_editors/node_filters.py +1 -0
  13. model_compression_toolkit/common/quantization/quantization_params_generation/lp_selection.py +1 -1
  14. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +3 -5
  15. model_compression_toolkit/common/quantization/quantize_graph_weights.py +4 -7
  16. model_compression_toolkit/common/quantization/quantize_node.py +3 -5
  17. model_compression_toolkit/keras/__init__.py +2 -0
  18. model_compression_toolkit/keras/back2framework/model_builder.py +24 -1
  19. model_compression_toolkit/{common → keras/back2framework}/model_collector.py +9 -18
  20. model_compression_toolkit/keras/default_framework_info.py +0 -1
  21. model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +57 -10
  22. model_compression_toolkit/keras/graph_substitutions/substituter.py +171 -0
  23. model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +26 -6
  24. model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +12 -5
  25. model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +3 -4
  26. model_compression_toolkit/keras/quantization_facade.py +524 -188
  27. model_compression_toolkit/keras/reader/connectivity_handler.py +4 -1
  28. model_compression_toolkit/keras/visualization/nn_visualizer.py +1 -2
  29. model_compression_toolkit/common/framework_implementation.py +0 -239
  30. model_compression_toolkit/common/gptq/__init__.py +0 -14
  31. model_compression_toolkit/common/gptq/gptq_config.py +0 -65
  32. model_compression_toolkit/common/model_builder_mode.py +0 -34
  33. model_compression_toolkit/common/post_training_quantization.py +0 -459
  34. model_compression_toolkit/common/substitutions/__init__.py +0 -14
  35. model_compression_toolkit/common/substitutions/apply_substitutions.py +0 -40
  36. model_compression_toolkit/keras/keras_implementation.py +0 -256
  37. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/LICENSE +0 -0
  38. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/WHEEL +0 -0
  39. {mct_nightly-0.0.0.dist-info → mct_nightly-1.1.0.1122021.post3325.dist-info}/top_level.txt +0 -0
@@ -1,256 +0,0 @@
1
- from typing import List, Any, Tuple, Callable
2
-
3
- import numpy as np
4
- import tensorflow as tf
5
- from tensorflow.keras.models import Model
6
-
7
- from model_compression_toolkit import QuantizationConfig, FrameworkInfo, common, GradientPTQConfig, \
8
- MixedPrecisionQuantizationConfig
9
- from model_compression_toolkit.common import Graph, Node
10
- from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
11
- from model_compression_toolkit.common.model_builder_mode import ModelBuilderMode
12
- from model_compression_toolkit.common.user_info import UserInformation
13
- from model_compression_toolkit.keras.back2framework.model_builder import model_builder
14
- from model_compression_toolkit.keras.default_framework_info import DEFAULT_KERAS_INFO
15
- from model_compression_toolkit.keras.gradient_ptq.training_wrapper import gptq_training_wrapper
16
- from model_compression_toolkit.keras.graph_substitutions.substitutions.activation_decomposition import \
17
- ActivationDecomposition
18
- from model_compression_toolkit.keras.graph_substitutions.substitutions.batchnorm_folding import \
19
- BatchNormalizationFolding
20
- from model_compression_toolkit.keras.graph_substitutions.substitutions.input_scaling import InputScaling, \
21
- InputScalingWithPad
22
- from model_compression_toolkit.keras.graph_substitutions.substitutions.mark_activation import MarkActivation
23
- from model_compression_toolkit.keras.graph_substitutions.substitutions.relu_bound_correction import \
24
- ReLUBoundCorrection
25
- from model_compression_toolkit.keras.graph_substitutions.substitutions.remove_relu_upper_bound import \
26
- RemoveReLUUpperBound
27
- from model_compression_toolkit.keras.graph_substitutions.substitutions.scale_equalization import \
28
- ScaleEqualization, ScaleEqualizationWithPad, ScaleEqualizationMidActivation, ScaleEqualizationMidActivationWithPad
29
- from model_compression_toolkit.keras.graph_substitutions.substitutions.separableconv_decomposition import \
30
- SeparableConvDecomposition
31
- from model_compression_toolkit.keras.graph_substitutions.substitutions.shift_negative_activation import \
32
- apply_shift_negative_correction
33
- from model_compression_toolkit.keras.mixed_precision.sensitivity_evaluation import get_sensitivity_evaluation
34
- from model_compression_toolkit.keras.reader.reader import model_reader
35
- from model_compression_toolkit.keras.tensor_marking import get_node_stats_collector
36
- import model_compression_toolkit.keras.constants as keras_constants
37
-
38
-
39
class KerasImplementation(FrameworkImplementation):
    """
    A class implementing the FrameworkImplementation interface with the methods
    needed to support optimizing Keras models.
    """

    def __init__(self):
        super().__init__()

    @property
    def constants(self):
        """

        Returns: Module of Keras constants.

        """
        return keras_constants

    def model_reader(self, model: Model) -> Graph:
        """
        Convert a framework's model into a graph.

        Args:
            model: Framework's model.

        Returns:
            Graph representing the input model.
        """
        return model_reader(model)

    def to_numpy(self, tensor: tf.Tensor) -> np.ndarray:
        """
        Convert a framework's tensor to a Numpy array.

        Args:
            tensor: Framework's tensor (converted by calling its ``numpy()`` method).

        Returns:
            Numpy array converted from the input tensor.
        """
        return tensor.numpy()

    def model_builder(self,
                      graph: Graph,
                      mode: ModelBuilderMode,
                      append2output: List[Any] = None,
                      fw_info: FrameworkInfo = DEFAULT_KERAS_INFO) -> Tuple[Model, UserInformation]:
        """
        Build a Keras model from a graph.
        The mode determines how the model should be built. append2output is a list of Nodes
        to set as the model outputs.

        Args:
            graph: Graph to build the model from.
            mode: Mode for how to build the model.
            append2output: List of Nodes to set as the model's outputs (None by default).
            fw_info: FrameworkInfo object with information about the specific framework's model.

        Returns:
            A tuple of the Keras model that was built and a UserInformation object.
        """
        return model_builder(graph,
                             mode,
                             append2output,
                             fw_info)

    def shift_negative_correction(self,
                                  graph: Graph,
                                  qc: QuantizationConfig,
                                  fw_info: FrameworkInfo) -> Graph:
        """
        Apply shift negative correction (SNC) on a graph.

        Args:
            graph: Graph to apply SNC on.
            qc: Quantization configuration.
            fw_info: FrameworkInfo object with information about the specific framework's model.

        Returns:
            Graph after SNC.
        """
        return apply_shift_negative_correction(graph,
                                               qc,
                                               fw_info)

    def attach_sc_to_node(self,
                          node: Node,
                          fw_info: FrameworkInfo) -> common.statistics_collector.BaseStatsContainer:
        """
        Return a statistics collector that should be attached to a node's output
        during statistics collection.

        Args:
            node: Node to return its collector.
            fw_info: FrameworkInfo object with information about the specific framework's model.

        Returns:
            Statistics collector for the node.
        """
        return get_node_stats_collector(node,
                                        fw_info)

    def get_substitutions_marking(self) -> List[common.BaseSubstitution]:
        """

        Returns: A list of the framework substitutions used for marking
        points we fuse.

        """
        return [MarkActivation()]

    def get_substitutions_pre_statistics_collection(self) -> List[common.BaseSubstitution]:
        """

        Returns: A list of the framework substitutions used before statistics collection
        (separable-conv decomposition, activation decomposition and batch-normalization folding).

        """
        return [SeparableConvDecomposition(),
                ActivationDecomposition(),
                BatchNormalizationFolding()]

    def get_substitutions_post_statistics_collection(self,
                                                     quant_config: QuantizationConfig
                                                     ) -> List[common.BaseSubstitution]:
        """
        Return a list of the framework substitutions used after we collect statistics.

        Args:
            quant_config: QuantizationConfig to determine which substitutions to return.
                The input-scaling substitutions are added only when
                ``quant_config.input_scaling`` is set, and the ReLU bound correction
                only when ``quant_config.relu_unbound_correction`` is set.

        Returns:
            A list of the framework substitutions used after we collect statistics
            (possibly empty).
        """
        substitutions_list = []
        if quant_config.input_scaling:
            substitutions_list.extend([InputScaling(),
                                       InputScalingWithPad()])

        if quant_config.relu_unbound_correction:
            substitutions_list.append(ReLUBoundCorrection())
        return substitutions_list

    def get_substitutions_channel_equalization(self,
                                               quant_config: QuantizationConfig,
                                               fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
        """
        Return a list of the framework substitutions used for channel equalization.

        Args:
            quant_config: QuantizationConfig to determine which substitutions to return.
            fw_info: FrameworkInfo object with information about the specific framework's model.

        Returns:
            A list of the framework substitutions used for channel equalization. The four
            scale-equalization substitutions are returned when
            ``quant_config.activation_channel_equalization`` is set; otherwise an empty list.
        """
        substitutions_list = []
        if quant_config.activation_channel_equalization:
            substitutions_list.extend([ScaleEqualization(quant_config, fw_info),
                                       ScaleEqualizationWithPad(quant_config, fw_info),
                                       ScaleEqualizationMidActivation(quant_config, fw_info),
                                       ScaleEqualizationMidActivationWithPad(quant_config, fw_info)])
        return substitutions_list

    def get_substitutions_pre_build(self) -> List[common.BaseSubstitution]:
        """

        Returns: A list of the framework substitutions used before we build a quantized model.

        """
        return [RemoveReLUUpperBound()]

    def gptq_training(self,
                      graph: Graph,
                      representative_data_gen: Callable,
                      gptq_config: GradientPTQConfig,
                      fw_info: FrameworkInfo) -> Graph:
        """
        Update a graph using GPTQ after minimizing the loss between the float model's output
        and the quantized model's outputs.

        Args:
            graph: Graph to fine-tune.
            representative_data_gen: Dataset to use for inputs of the models.
            gptq_config: GradientPTQConfig with configuration for the fine-tuning process.
            fw_info: FrameworkInfo object with information about the specific framework's model.

        Returns:
            Updated graph after GPTQ.
        """
        return gptq_training_wrapper(graph,
                                     representative_data_gen,
                                     gptq_config,
                                     fw_info)

    def get_sensitivity_evaluation_fn(self,
                                      graph: Graph,
                                      quant_config: MixedPrecisionQuantizationConfig,
                                      metrics_weights: np.ndarray,
                                      representative_data_gen: Callable,
                                      fw_info: FrameworkInfo) -> Callable:
        """
        Create and return a function to compute a sensitivity metric for a mixed-precision
        configuration (comparing to the float Keras model).

        Args:
            graph: Graph to build its float and mixed-precision Keras models.
            quant_config: MixedPrecisionQuantizationConfig of how the model should be quantized.
            metrics_weights: Array of weights to weight the sensitivity among different layers.
            representative_data_gen: Dataset to use for retrieving images for the models inputs.
            fw_info: FrameworkInfo object with information about the specific framework's model.

        Returns:
            A function that computes the metric.
        """
        return get_sensitivity_evaluation(graph,
                                          quant_config,
                                          metrics_weights,
                                          representative_data_gen,
                                          fw_info)