mct-nightly 1.1.0.7012022.post2611__py3-none-any.whl → 1.1.0.07122021-002414__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/METADATA +3 -3
- {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/RECORD +72 -76
- {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/WHEEL +1 -1
- model_compression_toolkit/common/__init__.py +2 -2
- model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +2 -2
- model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +9 -9
- model_compression_toolkit/common/collectors/mean_collector.py +2 -3
- model_compression_toolkit/common/collectors/min_max_per_channel_collector.py +3 -6
- model_compression_toolkit/common/constants.py +0 -1
- model_compression_toolkit/common/framework_implementation.py +6 -22
- model_compression_toolkit/common/framework_info.py +7 -39
- model_compression_toolkit/common/graph/__init__.py +1 -1
- model_compression_toolkit/common/graph/base_graph.py +34 -34
- model_compression_toolkit/common/graph/edge.py +3 -3
- model_compression_toolkit/common/graph/graph_matchers.py +3 -3
- model_compression_toolkit/common/graph/graph_searches.py +4 -4
- model_compression_toolkit/common/graph/graph_vis.py +116 -0
- model_compression_toolkit/common/graph/{base_node.py → node.py} +27 -11
- model_compression_toolkit/common/mixed_precision/bit_width_setter.py +9 -10
- model_compression_toolkit/common/mixed_precision/mixed_precision_search_manager.py +1 -1
- model_compression_toolkit/common/model_collector.py +12 -14
- model_compression_toolkit/common/network_editors/actions.py +23 -19
- model_compression_toolkit/common/post_training_quantization.py +7 -20
- model_compression_toolkit/common/quantization/node_quantization_config.py +5 -13
- model_compression_toolkit/common/quantization/quantization_analyzer.py +7 -11
- model_compression_toolkit/common/quantization/quantization_config.py +6 -6
- model_compression_toolkit/common/quantization/quantization_params_fn_selection.py +3 -2
- model_compression_toolkit/common/quantization/quantization_params_generation/qparams_activations_computation.py +7 -13
- model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +20 -17
- model_compression_toolkit/common/quantization/quantize_node.py +2 -2
- model_compression_toolkit/common/quantization/set_node_quantization_config.py +36 -39
- model_compression_toolkit/common/{collectors/statistics_collector.py → statistics_collector.py} +30 -26
- model_compression_toolkit/common/visualization/tensorboard_writer.py +8 -11
- model_compression_toolkit/keras/back2framework/instance_builder.py +4 -4
- model_compression_toolkit/keras/back2framework/model_builder.py +34 -47
- model_compression_toolkit/keras/constants.py +0 -3
- model_compression_toolkit/keras/default_framework_info.py +7 -33
- model_compression_toolkit/keras/gradient_ptq/graph_info.py +2 -2
- model_compression_toolkit/keras/gradient_ptq/graph_update.py +1 -7
- model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +1 -0
- model_compression_toolkit/keras/graph_substitutions/substitutions/activation_decomposition.py +8 -10
- model_compression_toolkit/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -2
- model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +2 -2
- model_compression_toolkit/keras/graph_substitutions/substitutions/mark_activation.py +3 -3
- model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +4 -3
- model_compression_toolkit/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +2 -2
- model_compression_toolkit/keras/graph_substitutions/substitutions/scale_equalization.py +9 -9
- model_compression_toolkit/keras/graph_substitutions/substitutions/separableconv_decomposition.py +19 -19
- model_compression_toolkit/keras/graph_substitutions/substitutions/shift_negative_activation.py +45 -64
- model_compression_toolkit/keras/keras_implementation.py +8 -28
- model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/keras/quantization_facade.py +1 -5
- model_compression_toolkit/keras/quantizer/fake_quant_builder.py +4 -4
- model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer.py +2 -3
- model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer_gptq_config.py +4 -8
- model_compression_toolkit/keras/quantizer/gradient_ptq/activation_weight_quantizer_gptq_config.py +4 -9
- model_compression_toolkit/keras/quantizer/gradient_ptq/config_factory.py +10 -9
- model_compression_toolkit/keras/quantizer/gradient_ptq/weight_quantizer_gptq_config.py +1 -9
- model_compression_toolkit/keras/quantizer/mixed_precision/quantization_config_factory.py +1 -1
- model_compression_toolkit/keras/quantizer/mixed_precision/selective_weights_quantize_config.py +1 -6
- model_compression_toolkit/keras/reader/common.py +11 -9
- model_compression_toolkit/keras/reader/connectivity_handler.py +9 -15
- model_compression_toolkit/keras/reader/nested_model/edges_merger.py +6 -6
- model_compression_toolkit/keras/reader/nested_model/nested_model_handler.py +2 -2
- model_compression_toolkit/keras/reader/nested_model/nodes_merger.py +3 -3
- model_compression_toolkit/keras/reader/nested_model/outputs_merger.py +2 -2
- model_compression_toolkit/keras/reader/node_builder.py +15 -65
- model_compression_toolkit/keras/reader/reader.py +5 -5
- model_compression_toolkit/keras/tensor_marking.py +113 -0
- model_compression_toolkit/keras/visualization/nn_visualizer.py +2 -2
- model_compression_toolkit/common/collectors/statistics_collector_generator.py +0 -43
- model_compression_toolkit/common/graph/functional_node.py +0 -59
- model_compression_toolkit/common/model_validation.py +0 -43
- model_compression_toolkit/common/node_prior_info.py +0 -29
- model_compression_toolkit/keras/keras_model_validation.py +0 -38
- model_compression_toolkit/keras/keras_node_prior_info.py +0 -60
- {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/LICENSE +0 -0
- {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/top_level.txt +0 -0

model_compression_toolkit/common/graph/base_graph.py

@@ -21,12 +21,13 @@ import networkx as nx
 import numpy as np
 
 from networkx.algorithms.dag import topological_sort
+from model_compression_toolkit import common
 from model_compression_toolkit.common.graph.edge import EDGE_SINK_INDEX, EDGE_SOURCE_INDEX
 from model_compression_toolkit.common.graph.edge import Edge, convert_to_edge
 from model_compression_toolkit.common.graph.graph_searches import GraphSearches
-from model_compression_toolkit.common.graph.base_node import BaseNode
-from model_compression_toolkit.common.
-from model_compression_toolkit.common.
+from model_compression_toolkit.common.graph.node import Node
+from model_compression_toolkit.common.statistics_collector import BaseStatsContainer
+from model_compression_toolkit.common.statistics_collector import scale_statistics, shift_statistics
 from model_compression_toolkit.common.user_info import UserInformation
 from model_compression_toolkit.common.logger import Logger
 
@@ -39,8 +40,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
     """
 
     def __init__(self,
-                 nodes: List[BaseNode],
-                 input_nodes: List[BaseNode],
+                 nodes: List[Node],
+                 input_nodes: List[Node],
                  output_nodes: List[OutTensor],
                  edge_list: List[Edge],
                  **attr):
@@ -79,7 +80,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
 
         return np.unique([n.op for n in self.nodes()])
 
-    def get_inputs(self) -> List[BaseNode]:
+    def get_inputs(self) -> List[Node]:
         """
         Returns: List containing the model input nodes.
         """
@@ -94,7 +95,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         return self.output_nodes
 
     def set_inputs(self,
-                   input_nodes: List[BaseNode]):
+                   input_nodes: List[Node]):
         """
         Set the graph inputs dictionary.
         Args:
@@ -114,8 +115,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         self.output_nodes = output_nodes
 
     def set_out_stats_collector_to_node(self,
-                                        n: BaseNode,
-                                        stats_collector:
+                                        n: Node,
+                                        stats_collector: BaseStatsContainer):
         """
         Set an output statistics collector of a node in the graph, and set this statistics collector as an input
         statistics collector of nodes next to this given node.
@@ -163,7 +164,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
             self.node_to_in_stats_collector.update({oe.sink_node: stats_collector})
 
     def get_out_stats_collector(self,
-                                n:
+                                n: Node) -> BaseStatsContainer:
         """
         Get the output statistics collector of a node containing output statistics of the node.
         Args:
@@ -172,10 +173,11 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         Returns:
             Tensor containing output statistics of the node.
         """
+
         return self.node_to_out_stats_collector.get(n)
 
     def get_in_stats_collector(self,
-                               n:
+                               n: Node) -> BaseStatsContainer:
         """
         Get the input statistics collector of a node containing input statistics of the node.
         Args:
@@ -191,7 +193,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         return sc
 
     def scale_stats_collector(self,
-                              node: BaseNode,
+                              node: Node,
                               scale_factor: np.ndarray):
         """
         Scale the output statistics of a node in the graph by a given scaling factor.
@@ -209,7 +211,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         self.set_out_stats_collector_to_node(node, scaled_sc)
 
     def shift_stats_collector(self,
-                              node: BaseNode,
+                              node: Node,
                               shift_value: np.ndarray):
         """
         Shift the output statistics of a node in the graph by a given value.
@@ -227,7 +229,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         self.set_out_stats_collector_to_node(node, shifted_sc)
 
     def find_node_by_name(self,
-                          name: str) -> List[BaseNode]:
+                          name: str) -> List[Node]:
         """
         Find and return a list of nodes by a name.
 
@@ -241,7 +243,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         return [n for n in self.nodes if n.name == name]
 
     def get_next_nodes(self,
-                       node_obj: BaseNode) -> List[BaseNode]:
+                       node_obj: Node) -> List[Node]:
         """
         Get next nodes (in a topological order) of a node.
 
@@ -256,7 +258,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         return [edges_list.sink_node for edges_list in self.out_edges(node_obj)]
 
     def get_prev_nodes(self,
-                       node_obj: BaseNode) -> List[BaseNode]:
+                       node_obj: Node) -> List[Node]:
         """
         Get previous nodes (in a topological order) of a node.
 
@@ -271,8 +273,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         return [edges_list.source_node for edges_list in self.incoming_edges(node_obj)]
 
     def reconnect_out_edges(self,
-                            current_node: BaseNode,
-                            new_node: BaseNode):
+                            current_node: Node,
+                            new_node: Node):
         """
         Connect all outgoing edges of a node to be outgoing edges of a different node
         (useful when replacing a node during substitutions).
@@ -287,8 +289,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
             self.remove_edge(current_node, oe.sink_node)
 
     def reconnect_in_edges(self,
-                           current_node: BaseNode,
-                           new_node: BaseNode):
+                           current_node: Node,
+                           new_node: Node):
         """
         Connect all incoming edges of a node to be incoming edges of a different node
         (useful when replacing a node during substitutions).
@@ -303,8 +305,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
             self.remove_edge(ie.source_node, current_node)
 
     def replace_output_node(self,
-                            current_node: BaseNode,
-                            new_node: BaseNode):
+                            current_node: Node,
+                            new_node: Node):
         """
         If a node is being substituted with another node and it is an output node, the graph's outputs
         should be updated as well. This function takes care of it by going over the graph's outputs, and
@@ -325,8 +327,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         self.set_outputs(new_graph_outputs)
 
     def remove_node(self,
-                    node_to_remove: BaseNode,
-                    new_graph_inputs: List[BaseNode] = None,
+                    node_to_remove: Node,
+                    new_graph_inputs: List[Node] = None,
                     new_graph_outputs: List[OutTensor] = None):
         """
         Remove a node from the graph. A new inputs/outputs lists can be passed in case the node is currently an
@@ -365,7 +367,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         super().remove_node(node_to_remove)
 
     def incoming_edges(self,
-                       n: BaseNode,
+                       n: Node,
                        sort_by_attr: str = None) -> List[Edge]:
         """
         Get a list of incoming edges of a node. If sort_by_attr is passed, the returned list
@@ -386,7 +388,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         return input_edges
 
     def out_edges(self,
-                  n: BaseNode,
+                  n: Node,
                   sort_by_attr: str = None) -> List[Edge]:
         """
         Get a list of outgoing edges of a node. If sort_by_attr is passed, the returned list
@@ -416,8 +418,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
             memory += n.get_memory_bytes()
         return memory
 
-    def get_configurable_sorted_nodes_names(self,
-                                            include_reused_nodes: bool = False) -> List[str]:
+    def get_configurable_sorted_nodes_names(self, include_reused_nodes: bool = False) -> List[str]:
         """
         Get a list of nodes' names that can be configured (namely, has one or
         more weight qc candidate). The names are sorted according to the topological
@@ -430,15 +431,14 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         more weight qc candidate) sorted topology.
 
         """
-        sorted_names = [n.name for n in self.get_configurable_sorted_nodes(include_reused_nodes
+        sorted_names = [n.name for n in self.get_configurable_sorted_nodes(include_reused_nodes)]
         return sorted_names
 
-    def get_configurable_sorted_nodes(self,
-                                      include_reused_nodes: bool = False) -> List[BaseNode]:
+    def get_configurable_sorted_nodes(self, include_reused_nodes: bool = False) -> List[Node]:
         """
         Get a list of nodes that can be configured (namely, has one or
-        more weight qc candidate
-
+        more weight qc candidate). The nodes are sorted according to the topological
+        order of the graph.
 
         Args:
             include_reused_nodes: Whether or not to include reused nodes (False by default).
@@ -450,7 +450,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         sorted_configurable_nodes = []
         sorted_nodes = list(topological_sort(self))
         for n in sorted_nodes:
-            if n.
+            if n.candidates_weights_quantization_cfg is not None:
                 if not n.reuse or include_reused_nodes:
                     if len(n.candidates_weights_quantization_cfg) >= 1:
                         sorted_configurable_nodes.append(n)
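
For orientation, a small self-contained sketch of the configurable-node selection that the last hunk above rewrites. The stub class and graph below only mimic the node attributes the loop touches; they are not MCT classes.

# Self-contained sketch (not MCT code): how get_configurable_sorted_nodes() filters
# a topologically sorted graph down to nodes that carry weights-quantization candidates.
import networkx as nx
from networkx.algorithms.dag import topological_sort


class StubNode:
    def __init__(self, name, candidates=None, reuse=False):
        self.name = name
        self.candidates_weights_quantization_cfg = candidates
        self.reuse = reuse


def configurable_sorted_nodes(graph: nx.MultiDiGraph, include_reused_nodes: bool = False):
    sorted_configurable_nodes = []
    for n in topological_sort(graph):
        if n.candidates_weights_quantization_cfg is not None:
            if not n.reuse or include_reused_nodes:
                if len(n.candidates_weights_quantization_cfg) >= 1:
                    sorted_configurable_nodes.append(n)
    return sorted_configurable_nodes


g = nx.MultiDiGraph()
conv = StubNode('conv', candidates=[{'weights_n_bits': 8}, {'weights_n_bits': 4}])
relu = StubNode('relu')  # no candidates -> not configurable
g.add_edge(conv, relu)
print([n.name for n in configurable_sorted_nodes(g)])  # ['conv']
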
model_compression_toolkit/common/graph/edge.py

@@ -16,7 +16,7 @@
 
 from typing import Any, Dict
 
-from model_compression_toolkit.common.graph.base_node import BaseNode
+from model_compression_toolkit.common.graph.node import Node
 
 # Edge attributes:
 EDGE_SOURCE_INDEX = 'source_index'
@@ -29,8 +29,8 @@ class Edge(object):
     """
 
     def __init__(self,
-                 source_node: BaseNode,
-                 sink_node: BaseNode,
+                 source_node: Node,
+                 sink_node: Node,
                  source_index: int,
                  sink_index: int):
         """
model_compression_toolkit/common/graph/graph_matchers.py

@@ -18,7 +18,7 @@ from typing import Any, List
 
 from tensorflow.keras.layers import Layer
 
-from model_compression_toolkit.common.graph.base_node import BaseNode
+from model_compression_toolkit.common.graph.node import Node
 from model_compression_toolkit.common.matchers import node_matcher, walk_matcher, edge_matcher
 
 
@@ -92,7 +92,7 @@ class EdgeMatcher(edge_matcher.BaseEdgeMatcher):
     class EdgeMatcher to check if an edge matches an edge that EdgeMatcher contains.
     """
 
-    def __init__(self, source_matcher: BaseNode, target_matcher: BaseNode):
+    def __init__(self, source_matcher: Node, target_matcher: Node):
         """
         Init an EdgeMatcher object.
 
@@ -125,7 +125,7 @@ class WalkMatcher(walk_matcher.WalkMatcherList):
     Class WalkMatcher to check if a list of nodes matches another list of nodes.
     """
 
-    def __init__(self, matcher_list: List[BaseNode]):
+    def __init__(self, matcher_list: List[Node]):
        """
        Init a WalkMatcher object.
 
model_compression_toolkit/common/graph/graph_searches.py

@@ -16,7 +16,7 @@
 from abc import ABC
 from typing import List, Any
 
-from model_compression_toolkit.common.graph.base_node import BaseNode
+from model_compression_toolkit.common.graph.node import Node
 from model_compression_toolkit.common.matchers import node_matcher, base_graph_filter, edge_matcher
 from model_compression_toolkit.common.matchers.walk_matcher import WalkMatcherList
 
@@ -59,7 +59,7 @@ class GraphSearches(base_graph_filter.BaseGraphFilter, ABC):
 
         return edge_list
 
-    def _walk_filter(self, walk_matcher: WalkMatcherList) -> List[BaseNode]:
+    def _walk_filter(self, walk_matcher: WalkMatcherList) -> List[Node]:
         """
         Search for a list of nodes which match the list in walk_matcher.
         If one the nodes in the list (that was found in the graph) has more than one output,
@@ -72,8 +72,8 @@ class GraphSearches(base_graph_filter.BaseGraphFilter, ABC):
            A list of nodes which match the list in walk_matcher.
         """
 
-        def walk_match(node: BaseNode,
-                       node_list: List[BaseNode],
+        def walk_match(node: Node,
+                       node_list: List[Node],
                        index: int,
                        node_matcher_list: list) -> Any:
             """
model_compression_toolkit/common/graph/graph_vis.py (new file)

@@ -0,0 +1,116 @@
+# Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import json
+import os
+from typing import Any, Dict, Tuple
+
+from model_compression_toolkit import common
+from model_compression_toolkit.common.graph.node import Node
+
+
+def check_str(in_sting: str) -> bool:
+    """
+    Checks if an open bracket is in a string.
+
+    Args:
+        in_sting: String to check.
+
+    Returns:
+        Whether an open bracket is in the string or not.
+    """
+
+    return '(' not in in_sting and '[' not in in_sting and '{' not in in_sting
+
+
+def node_dict(n: Node) -> Dict[str, Any]:
+    """
+    Get a dictionary with a node's attributes for displaying.
+
+    Args:
+        n: Node to get its attributes to display.
+
+    Returns:
+        A dictionary with params of the node to display when visualizing the graph.
+    """
+
+    framework_attr = {k: str(v) for k, v in n.framework_attr.items() if check_str(str(v))}
+    framework_attr.update({k: str(v) for k, v in n.quantization_attr.items() if check_str(str(v))})
+    framework_attr.update({'op': n.layer_class.__name__})
+
+    if n.quantization_cfg is not None:
+        for k, v in n.quantization_cfg.activation_quantization_params.items():
+            framework_attr.update({k: str(v)})
+        framework_attr.update({'activation_is_signed': str(n.quantization_cfg.activation_is_signed)})
+
+    return {"id": n.name,
+            "group": n.layer_class.__name__,
+            "label": n.layer_class.__name__,
+            "title": "",
+            "properties": framework_attr}
+
+
+def edge_dict(i: int,
+              edge: Tuple[Node, Node],
+              graph):
+    """
+    Create a dictionary of attributes to visualize an edge in the graph.
+
+    Args:
+        i: Edge's ID.
+        edge: Tuple of two nodes (source and destination).
+        graph: Graph the edge is in.
+
+    Returns:
+        Dictionary of attributes to visualize an edge in a graph.
+    """
+
+    return {"id": str(i),
+            "from": edge[0].name,
+            "to": edge[1].name,
+            "label": "",
+            "title": "",
+            "properties": graph.get_edge_data(edge[0], edge[1])}
+
+
+def write_vis_graph(folder_path: str,
+                    file_name: str,
+                    graph):
+    """
+    Create and save a json file containing data about the graph to visualize (such as nodes and edges).
+    folder_path and file_name determine where the json file will be saved.
+
+    Args:
+        folder_path: Dir path to save the json file.
+        file_name: File name of the json file.
+        graph: Graph to visualize.
+
+    """
+
+    os.makedirs(folder_path, exist_ok=True)
+    file_path = os.path.join(folder_path, file_name + '.json')
+
+    nodes_list = [node_dict(n) for n in graph.nodes()]
+    edges_list = [edge_dict(i, e, graph) for i, e in enumerate(graph.edges())]
+    graph_dict = {'nodes': nodes_list,
+                  'edges': edges_list,
+                  "options": {}
+                  }
+
+    with open(file_path, 'w') as json_file:
+        json.dump(graph_dict, json_file)
+
+    common.Logger.info(f"Writing Vis Graph to:{file_path}")
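
A hypothetical driver for the new graph_vis helpers, showing the kind of JSON that write_vis_graph produces. The stub node and layer classes are assumptions made for this example and are not MCT's Graph/Node types; the import assumes the mct-nightly package (and its TensorFlow dependency) is installed.

# Illustrative only: feed write_vis_graph() a minimal networkx graph whose nodes
# expose just the attributes node_dict() reads.
import networkx as nx

from model_compression_toolkit.common.graph.graph_vis import write_vis_graph


class StubLayer:
    pass


class StubNode:
    def __init__(self, name):
        self.name = name
        self.layer_class = StubLayer
        self.framework_attr = {'units': 16}
        self.quantization_attr = {}
        self.quantization_cfg = None  # node_dict() skips activation params when this is None


g = nx.MultiDiGraph()
g.add_edge(StubNode('dense'), StubNode('relu'))

# Produces ./vis/model.json shaped like {"nodes": [...], "edges": [...], "options": {}}
write_vis_graph('./vis', 'model', g)
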
model_compression_toolkit/common/graph/{base_node.py → node.py}

@@ -21,7 +21,7 @@ import numpy as np
 from model_compression_toolkit.common.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE
 
 
-class BaseNode:
+class Node:
     """
     Class to represent a node in a graph that represents the model.
     """
@@ -35,6 +35,7 @@ class BaseNode:
                  layer_class: type,
                  reuse: bool = False,
                  reuse_group: str = None,
+                 op_call_args: Dict[str, Any] = {},
                  quantization_attr: Dict[str, Any] = None):
         """
         Init a Node object.
@@ -48,8 +49,11 @@ class BaseNode:
             layer_class: Class path of the layer this node represents.
             reuse: Whether this node was duplicated and represents a reused layer.
             reuse_group: Name of group of nodes from the same reused layer.
+            op_call_args: Arguments dictionary with values to pass when calling the layer.
             quantization_attr: Attributes the node holds regarding how it should be quantized.
         """
+
+
         self.name = name
         self.framework_attr = framework_attr
         self.quantization_attr = quantization_attr if quantization_attr is not None else dict()
@@ -62,25 +66,37 @@ class BaseNode:
         self.activation_quantization_cfg = None
         self.final_weights_quantization_cfg = None
         self.candidates_weights_quantization_cfg = None
-        self.
+        self.output_quantization = True
+        self.op_call_args = op_call_args
+
+    def no_quantization(self) -> bool:
+        """
+
+        Returns: Whether NodeQuantizationConfig does not have activation params.
+
+        """
+        return self.activation_quantization_cfg is None or \
+               (not self.activation_quantization_cfg.has_activation_quantization_params())
 
-    def
+    def weight_quantization(self) -> bool:
         """
 
-        Returns: Whether node
+        Returns: Whether node weights should be quantized
 
         """
-        return self.
+        return self.final_weights_quantization_cfg is not None and \
+               self.final_weights_quantization_cfg.has_weights_quantization_params() and \
+               self.final_weights_quantization_cfg.enable_weights_quantization
 
-    def
+    def activation_quantization(self) -> bool:
         """
 
-        Returns: Whether node
+        Returns: Whether node activation should be quantized
 
         """
-
-
-
+        return self.activation_quantization_cfg is not None and \
+               self.activation_quantization_cfg.has_activation_quantization_params() and \
+               self.activation_quantization_cfg.enable_activation_quantization
 
     def __repr__(self):
         """
@@ -172,7 +188,7 @@ class BaseNode:
         """
         shared_attributes = [CORRECTED_BIAS_ATTRIBUTE, WEIGHTS_NBITS_ATTRIBUTE]
         attr = dict()
-        if self.
+        if self.candidates_weights_quantization_cfg is not None:
             attr = copy.deepcopy(self.candidates_weights_quantization_cfg[0].__dict__)
             for shared_attr in shared_attributes:
                 if shared_attr in attr:
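
A standalone illustration of the three predicates introduced on the node class above. The dataclasses here are stand-ins for MCT's quantization-config objects, which expose richer APIs; only the decision pattern (config exists, has params, and is enabled) is mirrored.

# Illustrative stand-ins, not MCT classes.
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class StubQuantizationCfg:
    params: Dict[str, float] = field(default_factory=dict)
    enabled: bool = True

    def has_params(self) -> bool:
        return len(self.params) > 0


@dataclass
class StubNode:
    activation_quantization_cfg: StubQuantizationCfg = None
    final_weights_quantization_cfg: StubQuantizationCfg = None

    def no_quantization(self) -> bool:
        # Mirrors no_quantization(): no activation config, or one without activation params.
        return self.activation_quantization_cfg is None or \
               not self.activation_quantization_cfg.has_params()

    def weight_quantization(self) -> bool:
        # Mirrors weight_quantization(): config exists, has weight params, and is enabled.
        cfg = self.final_weights_quantization_cfg
        return cfg is not None and cfg.has_params() and cfg.enabled

    def activation_quantization(self) -> bool:
        # Mirrors activation_quantization(): same pattern for the activation config.
        cfg = self.activation_quantization_cfg
        return cfg is not None and cfg.has_params() and cfg.enabled


n = StubNode(final_weights_quantization_cfg=StubQuantizationCfg({'threshold': 2.0}))
assert n.no_quantization() and n.weight_quantization() and not n.activation_quantization()
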
model_compression_toolkit/common/mixed_precision/bit_width_setter.py

@@ -18,13 +18,14 @@ import copy
 from typing import Any, List
 
 from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
-from model_compression_toolkit.common import Graph, BaseNode
+from model_compression_toolkit.common import Graph, Node
 from model_compression_toolkit.common.framework_info import FrameworkInfo
 from model_compression_toolkit.common.logger import Logger
 from model_compression_toolkit.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfig
 
 
+
 def set_bit_widths(quant_config: QuantizationConfig,
                    graph_to_set_bit_widths: Graph,
                    fw_info: FrameworkInfo = None,
@@ -36,10 +37,9 @@ def set_bit_widths(quant_config: QuantizationConfig,
     Args:
         quant_config: MixedPrecisionQuantizationConfig the graph was computed according to.
         graph_to_set_bit_widths: A prepared for quantization graph to set its bit widths.
-        fw_info: Information needed for quantization about the specific framework (e.g., kernel
-
-        bit_widths_config: MP configuration (a list of indices: one for each node's candidate
-        quantization configuration).
+        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
+        groups of layers by how they should be quantized, etc.)
+        bit_widths_config: MP configuration (a list of indices: one for each node's candidate quantization configuration).
 
     """
     graph = copy.deepcopy(graph_to_set_bit_widths)
@@ -63,8 +63,7 @@ def set_bit_widths(quant_config: QuantizationConfig,
     # Get a list of nodes' names we need to finalize (that they have at least one weight qc candidate).
     sorted_nodes_names = graph.get_configurable_sorted_nodes_names()
     for node in graph.nodes:  # set a specific node qc for each node final weights qc
-        #
-        node_name = node.name if not node.reuse else '_'.join(node.name.split('_')[:-2])
+        node_name = node.name if not node.reuse else '_'.join(node.name.split('_')[:-2])  # if it's reused, take the configuration that the base node has
        if node_name in sorted_nodes_names:  # only configurable nodes are in this list
             node_index_in_graph = sorted_nodes_names.index(node_name)
             _set_node_qc(bit_widths_config,
@@ -83,7 +82,7 @@ def set_bit_widths(quant_config: QuantizationConfig,
     return graph
 
 
-def _get_node_qc_by_bit_widths(node: BaseNode,
+def _get_node_qc_by_bit_widths(node: Node,
                                bit_width_cfg: List[int],
                                node_index_in_graph: int) -> Any:
     """
@@ -100,7 +99,7 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
         Node quantization configuration if it was found, or None otherwise.
     """
 
-    if node.
+    if node.candidates_weights_quantization_cfg is not None:
         bit_index_in_cfg = bit_width_cfg[node_index_in_graph]
         qc = node.candidates_weights_quantization_cfg[bit_index_in_cfg]
         return qc
@@ -109,7 +108,7 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
 
 def _set_node_qc(bit_width_cfg: List[int],
                  fw_info: FrameworkInfo,
-                 node: BaseNode,
+                 node: Node,
                  node_index_in_graph: int):
     """
     Get the node's quantization configuration that
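
A runnable sketch of how a mixed-precision bit-width configuration (one candidate index per configurable node) is resolved, mirroring _get_node_qc_by_bit_widths above. The candidate dictionaries and node names are stand-ins for MCT's node quantization-config objects.

# Stand-in sketch, not MCT code.
from typing import Dict, List, Optional


def get_node_qc_by_bit_widths(candidates: Optional[List[Dict]],
                              bit_width_cfg: List[int],
                              node_index_in_graph: int) -> Optional[Dict]:
    # Mirrors _get_node_qc_by_bit_widths(): pick the candidate chosen by the MP search.
    if candidates is not None:
        bit_index_in_cfg = bit_width_cfg[node_index_in_graph]
        return candidates[bit_index_in_cfg]
    return None


sorted_nodes_names = ['conv1', 'dense2']   # configurable nodes in topological order
bit_width_cfg = [1, 0]                     # one candidate index per configurable node
candidates = {'conv1': [{'weights_n_bits': 8}, {'weights_n_bits': 4}],
              'dense2': [{'weights_n_bits': 8}]}

for name in sorted_nodes_names:
    qc = get_node_qc_by_bit_widths(candidates[name], bit_width_cfg, sorted_nodes_names.index(name))
    print(name, qc)  # conv1 -> 4-bit candidate, dense2 -> 8-bit candidate
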
model_compression_toolkit/common/mixed_precision/mixed_precision_search_manager.py

@@ -111,7 +111,7 @@ class MixedPrecisionSearchManager(object):
             if n.name in mp_nodes:
                 node_idx = mp_nodes.index(n.name)
                 node_nbits = n.candidates_weights_quantization_cfg[mp_model_config[node_idx]].weights_n_bits
-            elif n.
+            elif n.candidates_weights_quantization_cfg is not None:
                 # The only valid way to get here is if the node is reused (which means that we're not looking
                 # for its configuration), and we ignore it when computing the KPI (as the base node will acount
                 # for it).
model_compression_toolkit/common/model_collector.py

@@ -17,7 +17,7 @@
 import numpy as np
 from typing import List
 
-from model_compression_toolkit import FrameworkInfo
+from model_compression_toolkit import FrameworkInfo
 from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.common.graph.base_graph import Graph
 from model_compression_toolkit.common.logger import Logger
@@ -53,17 +53,15 @@ class ModelCollector(object):
 
         for n in self.graph.nodes():
             out_stats_container = self.graph.get_out_stats_collector(n)
+
             if isinstance(out_stats_container, list):  # If layer has multiple outputs
-
-
-
-
-
-                #
-
-                mark2fetch = False
-                node2fetch.append(n)  # Append node several times (as number of outputs it has)
-                stats_containers_list.append(out_stats_container)
+                mark2fetch = True
+                for sc in out_stats_container:
+                    # Output only if statistics should be gathered
+                    if sc.require_collection() and mark2fetch:
+                        mark2fetch = False
+                        node2fetch.append(n)  # Append node several times (as number of outputs it has)
+                stats_containers_list.append(out_stats_container)
 
             else:  # A single output
                 if out_stats_container.require_collection():
@@ -75,8 +73,8 @@ class ModelCollector(object):
         # Build a float model and output all layers' outputs
         # (that should be collected) as the model's outputs
         self.model, _ = self.fw_impl.model_builder(self.graph,
-
-
+                                                   mode=ModelBuilderMode.FLOAT,
+                                                   append2output=node2fetch,
                                                    fw_info=self.fw_info)
 
     def infer(self, inputs_list: List[np.ndarray]):
@@ -102,4 +100,4 @@ class ModelCollector(object):
                 for tdi, sci in zip(td, sc):
                     sci.update_statistics(self.fw_impl.to_numpy(tdi))
             else:
-                sc.update_statistics(self.fw_impl.to_numpy(td))
+                sc.update_statistics(self.fw_impl.to_numpy(td))
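
A stand-in sketch of the restructured multi-output branch in ModelCollector above: a node is fetched only if at least one of its output collectors requires statistics, and the mark2fetch flag keeps it from being appended more than once. The collector class and the single-output append are illustrative, not MCT's implementation.

# Illustrative only, not MCT code.
class StubCollector:
    def __init__(self, require: bool):
        self._require = require

    def require_collection(self) -> bool:
        return self._require


def select_nodes_to_fetch(node_to_collectors):
    node2fetch, stats_containers_list = [], []
    for n, out_stats_container in node_to_collectors.items():
        if isinstance(out_stats_container, list):  # layer with multiple outputs
            mark2fetch = True
            for sc in out_stats_container:
                if sc.require_collection() and mark2fetch:
                    mark2fetch = False
                    node2fetch.append(n)
            stats_containers_list.append(out_stats_container)
        else:  # a single output
            if out_stats_container.require_collection():
                node2fetch.append(n)
                stats_containers_list.append(out_stats_container)
    return node2fetch, stats_containers_list


fetch, _ = select_nodes_to_fetch({'split': [StubCollector(True), StubCollector(False)],
                                  'relu': StubCollector(False)})
print(fetch)  # ['split']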