mct-nightly 1.1.0.7012022.post2611__py3-none-any.whl → 1.1.0.07122021-002414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/METADATA +3 -3
  2. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/RECORD +72 -76
  3. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/common/__init__.py +2 -2
  5. model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +2 -2
  6. model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +9 -9
  7. model_compression_toolkit/common/collectors/mean_collector.py +2 -3
  8. model_compression_toolkit/common/collectors/min_max_per_channel_collector.py +3 -6
  9. model_compression_toolkit/common/constants.py +0 -1
  10. model_compression_toolkit/common/framework_implementation.py +6 -22
  11. model_compression_toolkit/common/framework_info.py +7 -39
  12. model_compression_toolkit/common/graph/__init__.py +1 -1
  13. model_compression_toolkit/common/graph/base_graph.py +34 -34
  14. model_compression_toolkit/common/graph/edge.py +3 -3
  15. model_compression_toolkit/common/graph/graph_matchers.py +3 -3
  16. model_compression_toolkit/common/graph/graph_searches.py +4 -4
  17. model_compression_toolkit/common/graph/graph_vis.py +116 -0
  18. model_compression_toolkit/common/graph/{base_node.py → node.py} +27 -11
  19. model_compression_toolkit/common/mixed_precision/bit_width_setter.py +9 -10
  20. model_compression_toolkit/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  21. model_compression_toolkit/common/model_collector.py +12 -14
  22. model_compression_toolkit/common/network_editors/actions.py +23 -19
  23. model_compression_toolkit/common/post_training_quantization.py +7 -20
  24. model_compression_toolkit/common/quantization/node_quantization_config.py +5 -13
  25. model_compression_toolkit/common/quantization/quantization_analyzer.py +7 -11
  26. model_compression_toolkit/common/quantization/quantization_config.py +6 -6
  27. model_compression_toolkit/common/quantization/quantization_params_fn_selection.py +3 -2
  28. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_activations_computation.py +7 -13
  29. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +20 -17
  30. model_compression_toolkit/common/quantization/quantize_node.py +2 -2
  31. model_compression_toolkit/common/quantization/set_node_quantization_config.py +36 -39
  32. model_compression_toolkit/common/{collectors/statistics_collector.py → statistics_collector.py} +30 -26
  33. model_compression_toolkit/common/visualization/tensorboard_writer.py +8 -11
  34. model_compression_toolkit/keras/back2framework/instance_builder.py +4 -4
  35. model_compression_toolkit/keras/back2framework/model_builder.py +34 -47
  36. model_compression_toolkit/keras/constants.py +0 -3
  37. model_compression_toolkit/keras/default_framework_info.py +7 -33
  38. model_compression_toolkit/keras/gradient_ptq/graph_info.py +2 -2
  39. model_compression_toolkit/keras/gradient_ptq/graph_update.py +1 -7
  40. model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +1 -0
  41. model_compression_toolkit/keras/graph_substitutions/substitutions/activation_decomposition.py +8 -10
  42. model_compression_toolkit/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -2
  43. model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +2 -2
  44. model_compression_toolkit/keras/graph_substitutions/substitutions/mark_activation.py +3 -3
  45. model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +4 -3
  46. model_compression_toolkit/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +2 -2
  47. model_compression_toolkit/keras/graph_substitutions/substitutions/scale_equalization.py +9 -9
  48. model_compression_toolkit/keras/graph_substitutions/substitutions/separableconv_decomposition.py +19 -19
  49. model_compression_toolkit/keras/graph_substitutions/substitutions/shift_negative_activation.py +45 -64
  50. model_compression_toolkit/keras/keras_implementation.py +8 -28
  51. model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +2 -2
  52. model_compression_toolkit/keras/quantization_facade.py +1 -5
  53. model_compression_toolkit/keras/quantizer/fake_quant_builder.py +4 -4
  54. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer.py +2 -3
  55. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer_gptq_config.py +4 -8
  56. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_weight_quantizer_gptq_config.py +4 -9
  57. model_compression_toolkit/keras/quantizer/gradient_ptq/config_factory.py +10 -9
  58. model_compression_toolkit/keras/quantizer/gradient_ptq/weight_quantizer_gptq_config.py +1 -9
  59. model_compression_toolkit/keras/quantizer/mixed_precision/quantization_config_factory.py +1 -1
  60. model_compression_toolkit/keras/quantizer/mixed_precision/selective_weights_quantize_config.py +1 -6
  61. model_compression_toolkit/keras/reader/common.py +11 -9
  62. model_compression_toolkit/keras/reader/connectivity_handler.py +9 -15
  63. model_compression_toolkit/keras/reader/nested_model/edges_merger.py +6 -6
  64. model_compression_toolkit/keras/reader/nested_model/nested_model_handler.py +2 -2
  65. model_compression_toolkit/keras/reader/nested_model/nodes_merger.py +3 -3
  66. model_compression_toolkit/keras/reader/nested_model/outputs_merger.py +2 -2
  67. model_compression_toolkit/keras/reader/node_builder.py +15 -65
  68. model_compression_toolkit/keras/reader/reader.py +5 -5
  69. model_compression_toolkit/keras/tensor_marking.py +113 -0
  70. model_compression_toolkit/keras/visualization/nn_visualizer.py +2 -2
  71. model_compression_toolkit/common/collectors/statistics_collector_generator.py +0 -43
  72. model_compression_toolkit/common/graph/functional_node.py +0 -59
  73. model_compression_toolkit/common/model_validation.py +0 -43
  74. model_compression_toolkit/common/node_prior_info.py +0 -29
  75. model_compression_toolkit/keras/keras_model_validation.py +0 -38
  76. model_compression_toolkit/keras/keras_node_prior_info.py +0 -60
  77. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/LICENSE +0 -0
  78. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/top_level.txt +0 -0
@@ -21,12 +21,13 @@ import networkx as nx
21
21
  import numpy as np
22
22
 
23
23
  from networkx.algorithms.dag import topological_sort
24
+ from model_compression_toolkit import common
24
25
  from model_compression_toolkit.common.graph.edge import EDGE_SINK_INDEX, EDGE_SOURCE_INDEX
25
26
  from model_compression_toolkit.common.graph.edge import Edge, convert_to_edge
26
27
  from model_compression_toolkit.common.graph.graph_searches import GraphSearches
27
- from model_compression_toolkit.common.graph.base_node import BaseNode
28
- from model_compression_toolkit.common.collectors.statistics_collector import BaseStatsCollector
29
- from model_compression_toolkit.common.collectors.statistics_collector import scale_statistics, shift_statistics
28
+ from model_compression_toolkit.common.graph.node import Node
29
+ from model_compression_toolkit.common.statistics_collector import BaseStatsContainer
30
+ from model_compression_toolkit.common.statistics_collector import scale_statistics, shift_statistics
30
31
  from model_compression_toolkit.common.user_info import UserInformation
31
32
  from model_compression_toolkit.common.logger import Logger
32
33
 
@@ -39,8 +40,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
39
40
  """
40
41
 
41
42
  def __init__(self,
42
- nodes: List[BaseNode],
43
- input_nodes: List[BaseNode],
43
+ nodes: List[Node],
44
+ input_nodes: List[Node],
44
45
  output_nodes: List[OutTensor],
45
46
  edge_list: List[Edge],
46
47
  **attr):
@@ -79,7 +80,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
79
80
 
80
81
  return np.unique([n.op for n in self.nodes()])
81
82
 
82
- def get_inputs(self) -> List[BaseNode]:
83
+ def get_inputs(self) -> List[Node]:
83
84
  """
84
85
  Returns: List containing the model input nodes.
85
86
  """
@@ -94,7 +95,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
94
95
  return self.output_nodes
95
96
 
96
97
  def set_inputs(self,
97
- input_nodes: List[BaseNode]):
98
+ input_nodes: List[Node]):
98
99
  """
99
100
  Set the graph inputs dictionary.
100
101
  Args:
@@ -114,8 +115,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
114
115
  self.output_nodes = output_nodes
115
116
 
116
117
  def set_out_stats_collector_to_node(self,
117
- n: BaseNode,
118
- stats_collector: BaseStatsCollector):
118
+ n: Node,
119
+ stats_collector: BaseStatsContainer):
119
120
  """
120
121
  Set an output statistics collector of a node in the graph, and set this statistics collector as an input
121
122
  statistics collector of nodes next to this given node.
@@ -163,7 +164,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
163
164
  self.node_to_in_stats_collector.update({oe.sink_node: stats_collector})
164
165
 
165
166
  def get_out_stats_collector(self,
166
- n: BaseNode) -> BaseStatsCollector:
167
+ n: Node) -> BaseStatsContainer:
167
168
  """
168
169
  Get the output statistics collector of a node containing output statistics of the node.
169
170
  Args:
@@ -172,10 +173,11 @@ class Graph(nx.MultiDiGraph, GraphSearches):
172
173
  Returns:
173
174
  Tensor containing output statistics of the node.
174
175
  """
176
+
175
177
  return self.node_to_out_stats_collector.get(n)
176
178
 
177
179
  def get_in_stats_collector(self,
178
- n: BaseNode) -> BaseStatsCollector:
180
+ n: Node) -> BaseStatsContainer:
179
181
  """
180
182
  Get the input statistics collector of a node containing input statistics of the node.
181
183
  Args:
@@ -191,7 +193,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
191
193
  return sc
192
194
 
193
195
  def scale_stats_collector(self,
194
- node: BaseNode,
196
+ node: Node,
195
197
  scale_factor: np.ndarray):
196
198
  """
197
199
  Scale the output statistics of a node in the graph by a given scaling factor.
@@ -209,7 +211,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
209
211
  self.set_out_stats_collector_to_node(node, scaled_sc)
210
212
 
211
213
  def shift_stats_collector(self,
212
- node: BaseNode,
214
+ node: Node,
213
215
  shift_value: np.ndarray):
214
216
  """
215
217
  Shift the output statistics of a node in the graph by a given value.
@@ -227,7 +229,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
227
229
  self.set_out_stats_collector_to_node(node, shifted_sc)
228
230
 
229
231
  def find_node_by_name(self,
230
- name: str) -> List[BaseNode]:
232
+ name: str) -> List[Node]:
231
233
  """
232
234
  Find and return a list of nodes by a name.
233
235
 
@@ -241,7 +243,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
241
243
  return [n for n in self.nodes if n.name == name]
242
244
 
243
245
  def get_next_nodes(self,
244
- node_obj: BaseNode) -> List[BaseNode]:
246
+ node_obj: Node) -> List[Node]:
245
247
  """
246
248
  Get next nodes (in a topological order) of a node.
247
249
 
@@ -256,7 +258,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
256
258
  return [edges_list.sink_node for edges_list in self.out_edges(node_obj)]
257
259
 
258
260
  def get_prev_nodes(self,
259
- node_obj: BaseNode) -> List[BaseNode]:
261
+ node_obj: Node) -> List[Node]:
260
262
  """
261
263
  Get previous nodes (in a topological order) of a node.
262
264
 
@@ -271,8 +273,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
271
273
  return [edges_list.source_node for edges_list in self.incoming_edges(node_obj)]
272
274
 
273
275
  def reconnect_out_edges(self,
274
- current_node: BaseNode,
275
- new_node: BaseNode):
276
+ current_node: Node,
277
+ new_node: Node):
276
278
  """
277
279
  Connect all outgoing edges of a node to be outgoing edges of a different node
278
280
  (useful when replacing a node during substitutions).
@@ -287,8 +289,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
287
289
  self.remove_edge(current_node, oe.sink_node)
288
290
 
289
291
  def reconnect_in_edges(self,
290
- current_node: BaseNode,
291
- new_node: BaseNode):
292
+ current_node: Node,
293
+ new_node: Node):
292
294
  """
293
295
  Connect all incoming edges of a node to be incoming edges of a different node
294
296
  (useful when replacing a node during substitutions).
@@ -303,8 +305,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
303
305
  self.remove_edge(ie.source_node, current_node)
304
306
 
305
307
  def replace_output_node(self,
306
- current_node: BaseNode,
307
- new_node: BaseNode):
308
+ current_node: Node,
309
+ new_node: Node):
308
310
  """
309
311
  If a node is being substituted with another node and it is an output node, the graph's outputs
310
312
  should be updated as well. This function takes care of it by going over the graph's outputs, and
@@ -325,8 +327,8 @@ class Graph(nx.MultiDiGraph, GraphSearches):
325
327
  self.set_outputs(new_graph_outputs)
326
328
 
327
329
  def remove_node(self,
328
- node_to_remove: BaseNode,
329
- new_graph_inputs: List[BaseNode] = None,
330
+ node_to_remove: Node,
331
+ new_graph_inputs: List[Node] = None,
330
332
  new_graph_outputs: List[OutTensor] = None):
331
333
  """
332
334
  Remove a node from the graph. A new inputs/outputs lists can be passed in case the node is currently an
@@ -365,7 +367,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
365
367
  super().remove_node(node_to_remove)
366
368
 
367
369
  def incoming_edges(self,
368
- n: BaseNode,
370
+ n: Node,
369
371
  sort_by_attr: str = None) -> List[Edge]:
370
372
  """
371
373
  Get a list of incoming edges of a node. If sort_by_attr is passed, the returned list
@@ -386,7 +388,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
386
388
  return input_edges
387
389
 
388
390
  def out_edges(self,
389
- n: BaseNode,
391
+ n: Node,
390
392
  sort_by_attr: str = None) -> List[Edge]:
391
393
  """
392
394
  Get a list of outgoing edges of a node. If sort_by_attr is passed, the returned list
@@ -416,8 +418,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
416
418
  memory += n.get_memory_bytes()
417
419
  return memory
418
420
 
419
- def get_configurable_sorted_nodes_names(self,
420
- include_reused_nodes: bool = False) -> List[str]:
421
+ def get_configurable_sorted_nodes_names(self, include_reused_nodes: bool = False) -> List[str]:
421
422
  """
422
423
  Get a list of nodes' names that can be configured (namely, has one or
423
424
  more weight qc candidate). The names are sorted according to the topological
@@ -430,15 +431,14 @@ class Graph(nx.MultiDiGraph, GraphSearches):
430
431
  more weight qc candidate) sorted topology.
431
432
 
432
433
  """
433
- sorted_names = [n.name for n in self.get_configurable_sorted_nodes(include_reused_nodes=include_reused_nodes)]
434
+ sorted_names = [n.name for n in self.get_configurable_sorted_nodes(include_reused_nodes)]
434
435
  return sorted_names
435
436
 
436
- def get_configurable_sorted_nodes(self,
437
- include_reused_nodes: bool = False) -> List[BaseNode]:
437
+ def get_configurable_sorted_nodes(self, include_reused_nodes: bool = False) -> List[Node]:
438
438
  """
439
439
  Get a list of nodes that can be configured (namely, has one or
440
- more weight qc candidate and their weights should be quantized).
441
- The nodes are sorted according to the topological order of the graph.
440
+ more weight qc candidate). The nodes are sorted according to the topological
441
+ order of the graph.
442
442
 
443
443
  Args:
444
444
  include_reused_nodes: Whether or not to include reused nodes (False by default).
@@ -450,7 +450,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
450
450
  sorted_configurable_nodes = []
451
451
  sorted_nodes = list(topological_sort(self))
452
452
  for n in sorted_nodes:
453
- if n.is_weights_quantization_enabled():
453
+ if n.candidates_weights_quantization_cfg is not None:
454
454
  if not n.reuse or include_reused_nodes:
455
455
  if len(n.candidates_weights_quantization_cfg) >= 1:
456
456
  sorted_configurable_nodes.append(n)
@@ -16,7 +16,7 @@
16
16
 
17
17
  from typing import Any, Dict
18
18
 
19
- from model_compression_toolkit.common.graph.base_node import BaseNode
19
+ from model_compression_toolkit.common.graph.node import Node
20
20
 
21
21
  # Edge attributes:
22
22
  EDGE_SOURCE_INDEX = 'source_index'
@@ -29,8 +29,8 @@ class Edge(object):
29
29
  """
30
30
 
31
31
  def __init__(self,
32
- source_node: BaseNode,
33
- sink_node: BaseNode,
32
+ source_node: Node,
33
+ sink_node: Node,
34
34
  source_index: int,
35
35
  sink_index: int):
36
36
  """
@@ -18,7 +18,7 @@ from typing import Any, List
18
18
 
19
19
  from tensorflow.keras.layers import Layer
20
20
 
21
- from model_compression_toolkit.common.graph.base_node import BaseNode
21
+ from model_compression_toolkit.common.graph.node import Node
22
22
  from model_compression_toolkit.common.matchers import node_matcher, walk_matcher, edge_matcher
23
23
 
24
24
 
@@ -92,7 +92,7 @@ class EdgeMatcher(edge_matcher.BaseEdgeMatcher):
92
92
  class EdgeMatcher to check if an edge matches an edge that EdgeMatcher contains.
93
93
  """
94
94
 
95
- def __init__(self, source_matcher: BaseNode, target_matcher: BaseNode):
95
+ def __init__(self, source_matcher: Node, target_matcher: Node):
96
96
  """
97
97
  Init an EdgeMatcher object.
98
98
 
@@ -125,7 +125,7 @@ class WalkMatcher(walk_matcher.WalkMatcherList):
125
125
  Class WalkMatcher to check if a list of nodes matches another list of nodes.
126
126
  """
127
127
 
128
- def __init__(self, matcher_list: List[BaseNode]):
128
+ def __init__(self, matcher_list: List[Node]):
129
129
  """
130
130
  Init a WalkMatcher object.
131
131
 
@@ -16,7 +16,7 @@
16
16
  from abc import ABC
17
17
  from typing import List, Any
18
18
 
19
- from model_compression_toolkit.common.graph.base_node import BaseNode
19
+ from model_compression_toolkit.common.graph.node import Node
20
20
  from model_compression_toolkit.common.matchers import node_matcher, base_graph_filter, edge_matcher
21
21
  from model_compression_toolkit.common.matchers.walk_matcher import WalkMatcherList
22
22
 
@@ -59,7 +59,7 @@ class GraphSearches(base_graph_filter.BaseGraphFilter, ABC):
59
59
 
60
60
  return edge_list
61
61
 
62
- def _walk_filter(self, walk_matcher: WalkMatcherList) -> List[BaseNode]:
62
+ def _walk_filter(self, walk_matcher: WalkMatcherList) -> List[Node]:
63
63
  """
64
64
  Search for a list of nodes which match the list in walk_matcher.
65
65
  If one the nodes in the list (that was found in the graph) has more than one output,
@@ -72,8 +72,8 @@ class GraphSearches(base_graph_filter.BaseGraphFilter, ABC):
72
72
  A list of nodes which match the list in walk_matcher.
73
73
  """
74
74
 
75
- def walk_match(node: BaseNode,
76
- node_list: List[BaseNode],
75
+ def walk_match(node: Node,
76
+ node_list: List[Node],
77
77
  index: int,
78
78
  node_matcher_list: list) -> Any:
79
79
  """
@@ -0,0 +1,116 @@
1
+ # Copyright 2021 Sony Semiconductors Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ import json
18
+ import os
19
+ from typing import Any, Dict, Tuple
20
+
21
+ from model_compression_toolkit import common
22
+ from model_compression_toolkit.common.graph.node import Node
23
+
24
+
25
+ def check_str(in_sting: str) -> bool:
26
+ """
27
+ Checks if an open bracket is in a string.
28
+
29
+ Args:
30
+ in_sting: String to check.
31
+
32
+ Returns:
33
+ Whether an open bracket is in the string or not.
34
+ """
35
+
36
+ return '(' not in in_sting and '[' not in in_sting and '{' not in in_sting
37
+
38
+
39
+ def node_dict(n: Node) -> Dict[str, Any]:
40
+ """
41
+ Get a dictionary with a node's attributes for displaying.
42
+
43
+ Args:
44
+ n: Node to get its attributes to display.
45
+
46
+ Returns:
47
+ A dictionary with params of the node to display when visualizing the graph.
48
+ """
49
+
50
+ framework_attr = {k: str(v) for k, v in n.framework_attr.items() if check_str(str(v))}
51
+ framework_attr.update({k: str(v) for k, v in n.quantization_attr.items() if check_str(str(v))})
52
+ framework_attr.update({'op': n.layer_class.__name__})
53
+
54
+ if n.quantization_cfg is not None:
55
+ for k, v in n.quantization_cfg.activation_quantization_params.items():
56
+ framework_attr.update({k: str(v)})
57
+ framework_attr.update({'activation_is_signed': str(n.quantization_cfg.activation_is_signed)})
58
+
59
+ return {"id": n.name,
60
+ "group": n.layer_class.__name__,
61
+ "label": n.layer_class.__name__,
62
+ "title": "",
63
+ "properties": framework_attr}
64
+
65
+
66
+ def edge_dict(i: int,
67
+ edge: Tuple[Node, Node],
68
+ graph):
69
+ """
70
+ Create a dictionary of attributes to visualize an edge in the graph.
71
+
72
+ Args:
73
+ i: Edge's ID.
74
+ edge: Tuple of two nodes (source and destination).
75
+ graph: Graph the edge is in.
76
+
77
+ Returns:
78
+ Dictionary of attributes to visualize an edge in a graph.
79
+ """
80
+
81
+ return {"id": str(i),
82
+ "from": edge[0].name,
83
+ "to": edge[1].name,
84
+ "label": "",
85
+ "title": "",
86
+ "properties": graph.get_edge_data(edge[0], edge[1])}
87
+
88
+
89
+ def write_vis_graph(folder_path: str,
90
+ file_name: str,
91
+ graph):
92
+ """
93
+ Create and save a json file containing data about the graph to visualize (such as nodes and edges).
94
+ folder_path and file_name determine where the json file will be saved.
95
+
96
+ Args:
97
+ folder_path: Dir path to save the json file.
98
+ file_name: File name of the json file.
99
+ graph: Graph to visualize.
100
+
101
+ """
102
+
103
+ os.makedirs(folder_path, exist_ok=True)
104
+ file_path = os.path.join(folder_path, file_name + '.json')
105
+
106
+ nodes_list = [node_dict(n) for n in graph.nodes()]
107
+ edges_list = [edge_dict(i, e, graph) for i, e in enumerate(graph.edges())]
108
+ graph_dict = {'nodes': nodes_list,
109
+ 'edges': edges_list,
110
+ "options": {}
111
+ }
112
+
113
+ with open(file_path, 'w') as json_file:
114
+ json.dump(graph_dict, json_file)
115
+
116
+ common.Logger.info(f"Writing Vis Graph to:{file_path}")
@@ -21,7 +21,7 @@ import numpy as np
21
21
  from model_compression_toolkit.common.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE
22
22
 
23
23
 
24
- class BaseNode:
24
+ class Node:
25
25
  """
26
26
  Class to represent a node in a graph that represents the model.
27
27
  """
@@ -35,6 +35,7 @@ class BaseNode:
35
35
  layer_class: type,
36
36
  reuse: bool = False,
37
37
  reuse_group: str = None,
38
+ op_call_args: Dict[str, Any] = {},
38
39
  quantization_attr: Dict[str, Any] = None):
39
40
  """
40
41
  Init a Node object.
@@ -48,8 +49,11 @@ class BaseNode:
48
49
  layer_class: Class path of the layer this node represents.
49
50
  reuse: Whether this node was duplicated and represents a reused layer.
50
51
  reuse_group: Name of group of nodes from the same reused layer.
52
+ op_call_args: Arguments dictionary with values to pass when calling the layer.
51
53
  quantization_attr: Attributes the node holds regarding how it should be quantized.
52
54
  """
55
+
56
+
53
57
  self.name = name
54
58
  self.framework_attr = framework_attr
55
59
  self.quantization_attr = quantization_attr if quantization_attr is not None else dict()
@@ -62,25 +66,37 @@ class BaseNode:
62
66
  self.activation_quantization_cfg = None
63
67
  self.final_weights_quantization_cfg = None
64
68
  self.candidates_weights_quantization_cfg = None
65
- self.prior_info = None
69
+ self.output_quantization = True
70
+ self.op_call_args = op_call_args
71
+
72
+ def no_quantization(self) -> bool:
73
+ """
74
+
75
+ Returns: Whether NodeQuantizationConfig does not have activation params.
76
+
77
+ """
78
+ return self.activation_quantization_cfg is None or \
79
+ (not self.activation_quantization_cfg.has_activation_quantization_params())
66
80
 
67
- def is_activation_quantization_enabled(self) -> bool:
81
+ def weight_quantization(self) -> bool:
68
82
  """
69
83
 
70
- Returns: Whether node activation quantization is enabled or not.
84
+ Returns: Whether node weights should be quantized
71
85
 
72
86
  """
73
- return self.activation_quantization_cfg.enable_activation_quantization
87
+ return self.final_weights_quantization_cfg is not None and \
88
+ self.final_weights_quantization_cfg.has_weights_quantization_params() and \
89
+ self.final_weights_quantization_cfg.enable_weights_quantization
74
90
 
75
- def is_weights_quantization_enabled(self) -> bool:
91
+ def activation_quantization(self) -> bool:
76
92
  """
77
93
 
78
- Returns: Whether node weights quantization is enabled or not.
94
+ Returns: Whether node activation should be quantized
79
95
 
80
96
  """
81
- for qc in self.candidates_weights_quantization_cfg:
82
- assert self.candidates_weights_quantization_cfg[0].enable_weights_quantization == qc.enable_weights_quantization
83
- return self.candidates_weights_quantization_cfg[0].enable_weights_quantization
97
+ return self.activation_quantization_cfg is not None and \
98
+ self.activation_quantization_cfg.has_activation_quantization_params() and \
99
+ self.activation_quantization_cfg.enable_activation_quantization
84
100
 
85
101
  def __repr__(self):
86
102
  """
@@ -172,7 +188,7 @@ class BaseNode:
172
188
  """
173
189
  shared_attributes = [CORRECTED_BIAS_ATTRIBUTE, WEIGHTS_NBITS_ATTRIBUTE]
174
190
  attr = dict()
175
- if self.is_weights_quantization_enabled():
191
+ if self.candidates_weights_quantization_cfg is not None:
176
192
  attr = copy.deepcopy(self.candidates_weights_quantization_cfg[0].__dict__)
177
193
  for shared_attr in shared_attributes:
178
194
  if shared_attr in attr:
@@ -18,13 +18,14 @@ import copy
18
18
  from typing import Any, List
19
19
 
20
20
  from model_compression_toolkit.common.quantization.quantization_config import QuantizationConfig
21
- from model_compression_toolkit.common import Graph, BaseNode
21
+ from model_compression_toolkit.common import Graph, Node
22
22
  from model_compression_toolkit.common.framework_info import FrameworkInfo
23
23
  from model_compression_toolkit.common.logger import Logger
24
24
  from model_compression_toolkit.common.mixed_precision.mixed_precision_quantization_config import \
25
25
  MixedPrecisionQuantizationConfig
26
26
 
27
27
 
28
+
28
29
  def set_bit_widths(quant_config: QuantizationConfig,
29
30
  graph_to_set_bit_widths: Graph,
30
31
  fw_info: FrameworkInfo = None,
@@ -36,10 +37,9 @@ def set_bit_widths(quant_config: QuantizationConfig,
36
37
  Args:
37
38
  quant_config: MixedPrecisionQuantizationConfig the graph was computed according to.
38
39
  graph_to_set_bit_widths: A prepared for quantization graph to set its bit widths.
39
- fw_info: Information needed for quantization about the specific framework (e.g., kernel
40
- channels indices, groups of layers by how they should be quantized, etc.)
41
- bit_widths_config: MP configuration (a list of indices: one for each node's candidate
42
- quantization configuration).
40
+ fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
41
+ groups of layers by how they should be quantized, etc.)
42
+ bit_widths_config: MP configuration (a list of indices: one for each node's candidate quantization configuration).
43
43
 
44
44
  """
45
45
  graph = copy.deepcopy(graph_to_set_bit_widths)
@@ -63,8 +63,7 @@ def set_bit_widths(quant_config: QuantizationConfig,
63
63
  # Get a list of nodes' names we need to finalize (that they have at least one weight qc candidate).
64
64
  sorted_nodes_names = graph.get_configurable_sorted_nodes_names()
65
65
  for node in graph.nodes: # set a specific node qc for each node final weights qc
66
- # If it's reused, take the configuration that the base node has
67
- node_name = node.name if not node.reuse else '_'.join(node.name.split('_')[:-2])
66
+ node_name = node.name if not node.reuse else '_'.join(node.name.split('_')[:-2]) # if it's reused, take the configuration that the base node has
68
67
  if node_name in sorted_nodes_names: # only configurable nodes are in this list
69
68
  node_index_in_graph = sorted_nodes_names.index(node_name)
70
69
  _set_node_qc(bit_widths_config,
@@ -83,7 +82,7 @@ def set_bit_widths(quant_config: QuantizationConfig,
83
82
  return graph
84
83
 
85
84
 
86
- def _get_node_qc_by_bit_widths(node: BaseNode,
85
+ def _get_node_qc_by_bit_widths(node: Node,
87
86
  bit_width_cfg: List[int],
88
87
  node_index_in_graph: int) -> Any:
89
88
  """
@@ -100,7 +99,7 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
100
99
  Node quantization configuration if it was found, or None otherwise.
101
100
  """
102
101
 
103
- if node.is_weights_quantization_enabled():
102
+ if node.candidates_weights_quantization_cfg is not None:
104
103
  bit_index_in_cfg = bit_width_cfg[node_index_in_graph]
105
104
  qc = node.candidates_weights_quantization_cfg[bit_index_in_cfg]
106
105
  return qc
@@ -109,7 +108,7 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
109
108
 
110
109
  def _set_node_qc(bit_width_cfg: List[int],
111
110
  fw_info: FrameworkInfo,
112
- node: BaseNode,
111
+ node: Node,
113
112
  node_index_in_graph: int):
114
113
  """
115
114
  Get the node's quantization configuration that
@@ -111,7 +111,7 @@ class MixedPrecisionSearchManager(object):
111
111
  if n.name in mp_nodes:
112
112
  node_idx = mp_nodes.index(n.name)
113
113
  node_nbits = n.candidates_weights_quantization_cfg[mp_model_config[node_idx]].weights_n_bits
114
- elif n.is_weights_quantization_enabled():
114
+ elif n.candidates_weights_quantization_cfg is not None:
115
115
  # The only valid way to get here is if the node is reused (which means that we're not looking
116
116
  # for its configuration), and we ignore it when computing the KPI (as the base node will acount
117
117
  # for it).
@@ -17,7 +17,7 @@
17
17
  import numpy as np
18
18
  from typing import List
19
19
 
20
- from model_compression_toolkit import FrameworkInfo, common
20
+ from model_compression_toolkit import FrameworkInfo
21
21
  from model_compression_toolkit.common.framework_implementation import FrameworkImplementation
22
22
  from model_compression_toolkit.common.graph.base_graph import Graph
23
23
  from model_compression_toolkit.common.logger import Logger
@@ -53,17 +53,15 @@ class ModelCollector(object):
53
53
 
54
54
  for n in self.graph.nodes():
55
55
  out_stats_container = self.graph.get_out_stats_collector(n)
56
+
56
57
  if isinstance(out_stats_container, list): # If layer has multiple outputs
57
- # Append nodes to output and track their statistics only if
58
- # they actually collect statistics.
59
- if len([x for x in out_stats_container if not isinstance(x, common.NoStatsCollector)]) > 0:
60
- mark2fetch = True
61
- for sc in out_stats_container:
62
- # Output only if statistics should be gathered
63
- if sc.require_collection() and mark2fetch:
64
- mark2fetch = False
65
- node2fetch.append(n) # Append node several times (as number of outputs it has)
66
- stats_containers_list.append(out_stats_container)
58
+ mark2fetch = True
59
+ for sc in out_stats_container:
60
+ # Output only if statistics should be gathered
61
+ if sc.require_collection() and mark2fetch:
62
+ mark2fetch = False
63
+ node2fetch.append(n) # Append node several times (as number of outputs it has)
64
+ stats_containers_list.append(out_stats_container)
67
65
 
68
66
  else: # A single output
69
67
  if out_stats_container.require_collection():
@@ -75,8 +73,8 @@ class ModelCollector(object):
75
73
  # Build a float model and output all layers' outputs
76
74
  # (that should be collected) as the model's outputs
77
75
  self.model, _ = self.fw_impl.model_builder(self.graph,
78
- mode=ModelBuilderMode.FLOAT,
79
- append2output=node2fetch,
76
+ mode=ModelBuilderMode.FLOAT,
77
+ append2output=node2fetch,
80
78
  fw_info=self.fw_info)
81
79
 
82
80
  def infer(self, inputs_list: List[np.ndarray]):
@@ -102,4 +100,4 @@ class ModelCollector(object):
102
100
  for tdi, sci in zip(td, sc):
103
101
  sci.update_statistics(self.fw_impl.to_numpy(tdi))
104
102
  else:
105
- sc.update_statistics(self.fw_impl.to_numpy(td))
103
+ sc.update_statistics(self.fw_impl.to_numpy(td))