mct-nightly 1.10.0.20231211.post417__py3-none-any.whl → 1.10.0.20231213.post410__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: mct-nightly
- Version: 1.10.0.20231211.post417
+ Version: 1.10.0.20231213.post410
  Summary: A Model Compression Toolkit for neural networks
  Home-page: UNKNOWN
  License: UNKNOWN
@@ -32,7 +32,7 @@ model_compression_toolkit/core/common/collectors/statistics_collector_generator.
  model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
  model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=tIsWFYc771o59uvq5fxAaBmOCnd_gd-_xMbQI9SupQA,5479
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
- model_compression_toolkit/core/common/graph/base_graph.py,sha256=-YBWWl3ZS7FJvZldGqT4SetlKI8j60f4sS0rYyFQpbI,30059
+ model_compression_toolkit/core/common/graph/base_graph.py,sha256=UHhXCWXh4hK7cg2sZrNLiTYSYsvDc3yhImcdKbQ0VVs,30929
  model_compression_toolkit/core/common/graph/base_node.py,sha256=csIgi5ex7EquQsF34w5waRIHzbg7XitvIqQgCC29azs,21118
  model_compression_toolkit/core/common/graph/edge.py,sha256=K6Wc2hBcIqig5PbbLhbjtTgYtkyZEohfgj4Wn_J5yEA,3733
  model_compression_toolkit/core/common/graph/functional_node.py,sha256=0TpYNa2ODZ0M9lQ2z_GsStqAbrg1Muwdni74LjphAh0,2922
@@ -197,7 +197,7 @@ model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKW
  model_compression_toolkit/core/pytorch/constants.py,sha256=Kt_GDwe3yX9oMS1DI2eXYuUT25_lpjeCkxpstsAiXCI,2472
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=qee8TFcDro2lfyXe_fujjX2OlxELTyKSsLlZ7QkzeXU,4200
  model_compression_toolkit/core/pytorch/kpi_data_facade.py,sha256=J0IDOtFMVFSFyBXDzNGbwJfHu89iRBJFdid1_wFB-xQ,8482
- model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=oTAd6_XYtyvTX2fRXx0BzajvgqbYreXGKD7ij8iL2SY,26482
+ model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=w4yLHjJmOxfkSgApEx9rWAEbv9vkLnZik5JQvaX55FM,26654
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=n_B4a6FMwM9D2w8kzy3oenBWZgXNZuIZgTJC6JEuTy0,3250
  model_compression_toolkit/core/pytorch/utils.py,sha256=rBQMAbWluyIMjVfeghzq6FZv3sR_khszSRpbWvwussw,2959
  model_compression_toolkit/core/pytorch/back2framework/__init__.py,sha256=H_WixgN0elVWf3exgGYsi58imPoYDj5eYPeh6x4yfug,813
@@ -216,6 +216,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchno
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=B7aC2TZNrQJ2oQVGBFhKAVqdUU5lYVJSMmwKhjxOHWk,2822
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=JDWOaNwYrZG0zTwd3HwoZUM3tKu7zPbzLOrqNQsu8xA,2162
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py,sha256=4mnowFmfDQjKlhHqsNto1iL4WbHyh4cM3Lf67Z-Cnzc,4804
+ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py,sha256=bSexq2A7WxLLm13v67SgVbb4T1Y6nrKQDZfk4iSj_ec,3941
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py,sha256=qCNT3L4mnZtIP75c8YwImvsTWdPIdsEvO4pc3SE4y6s,5797
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=gGlVE1xUa1Pv2NbRx0I2y5Okg3kneBWSx9JwULuTWz0,38353
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/permute_call_method.py,sha256=EMCviyFyJFLEKuAUz3rZHLfB9MAU1kywSBL2XQNzLlg,1953
@@ -450,8 +451,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=SbvRlIdE32PEBsINt1bhSqvrKL_zbM9V-aeSkOn-sw4,3083
- mct_nightly-1.10.0.20231211.post417.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
- mct_nightly-1.10.0.20231211.post417.dist-info/METADATA,sha256=krO3OedYSfl2Ck74IQ5o9-Bwi7174jQvnoy_WSEC2JE,16232
- mct_nightly-1.10.0.20231211.post417.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
- mct_nightly-1.10.0.20231211.post417.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
- mct_nightly-1.10.0.20231211.post417.dist-info/RECORD,,
+ mct_nightly-1.10.0.20231213.post410.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+ mct_nightly-1.10.0.20231213.post410.dist-info/METADATA,sha256=udqlXh19QUWj8Kl3urGNDSPuK2qjV7ZEY_8OFS-gtvw,16232
+ mct_nightly-1.10.0.20231213.post410.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+ mct_nightly-1.10.0.20231213.post410.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+ mct_nightly-1.10.0.20231213.post410.dist-info/RECORD,,
@@ -299,19 +299,24 @@ class Graph(nx.MultiDiGraph, GraphSearches):
          return [edges_list.sink_node for edges_list in self.out_edges(node_obj)]
 
      def get_prev_nodes(self,
-                        node_obj: BaseNode) -> List[BaseNode]:
+                        node_obj: BaseNode,
+                        sink_index_sorted: bool = False) -> List[BaseNode]:
          """
          Get previous nodes (in a topological order) of a node.
 
          Args:
              node_obj: Node to get its previous nodes.
+             sink_index_sorted: Whether to sort the returned list by the sink_index of the edges.
 
          Returns:
              List of input nodes objects.
 
          """
-
-         return [edges_list.source_node for edges_list in self.incoming_edges(node_obj)]
+         if sink_index_sorted:
+             sort_attr = 'sink_index'
+         else:
+             sort_attr = None
+         return [edges_list.source_node for edges_list in self.incoming_edges(node_obj, sort_by_attr=sort_attr)]
 
      def reconnect_out_edges(self,
                              current_node: BaseNode,
@@ -705,3 +710,19 @@ class Graph(nx.MultiDiGraph, GraphSearches):
 
          """
          return all([n.is_all_activation_candidates_equal() for n in self.nodes])
+
+     def replace_node(self, node_to_replace: BaseNode, new_node: BaseNode):
+         """
+         Replaces a node in the graph with a new node.
+
+         Args:
+             node_to_replace: The node to replace.
+             new_node: The new node to replace with.
+
+         """
+         self.add_node(new_node)
+         self.reconnect_out_edges(node_to_replace, new_node)
+         self.reconnect_in_edges(node_to_replace, new_node)
+         self.replace_output_node(node_to_replace, new_node)
+         self.replace_input_node(node_to_replace, new_node)
+         self.remove_node(node_to_replace)
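Note: the new Graph.replace_node helper bundles the add/reconnect/replace-IO/remove sequence that substitutions previously had to spell out by hand. Below is a minimal usage sketch, not code from the package: the BaseNode constructor fields mirror the ones used by the functional_batch_norm substitution later in this diff, while the helper name swap_for_bn and the 'num_features' key are purely illustrative assumptions.

    from torch import nn
    from model_compression_toolkit.core.common import BaseNode, Graph

    def swap_for_bn(graph: Graph, matched_node: BaseNode, weights: dict) -> Graph:
        # Build the replacement node (constructor fields follow functional_batch_norm.py).
        new_node = BaseNode(name=matched_node.name + '_replaced',
                            framework_attr={'num_features': matched_node.output_shape[0][1]},
                            input_shape=matched_node.output_shape,
                            output_shape=matched_node.output_shape,
                            weights=weights,
                            layer_class=nn.BatchNorm2d)
        # One call now reconnects in/out edges, updates graph inputs/outputs,
        # and removes the old node from the graph.
        graph.replace_node(matched_node, new_node)
        return graph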
@@ -0,0 +1,94 @@
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from torch import nn
+ import torch.nn.functional as F
+
+ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+ from model_compression_toolkit.core import common
+ from model_compression_toolkit.core.common import BaseNode, Graph
+ from model_compression_toolkit.core.pytorch.constants import *
+ from model_compression_toolkit.logger import Logger
+
+
+ class FunctionalBatchNorm(common.BaseSubstitution):
+     """
+     Replace functional batch_norm with BatchNorm2d.
+     """
+
+     def __init__(self):
+         """
+         Matches: functional batch_norm
+         """
+         bn_node = NodeOperationMatcher(F.batch_norm)
+         super().__init__(matcher_instance=bn_node)
+
+     def get_attributes_from_inputs(self, graph: Graph, node: BaseNode) -> dict:
+         input_nodes = graph.get_prev_nodes(node, sink_index_sorted=True)
+
+         if len(input_nodes) == 5:
+             return {
+                 MOVING_MEAN: list(input_nodes[1].weights.values())[0],
+                 MOVING_VARIANCE: list(input_nodes[2].weights.values())[0],
+                 GAMMA: list(input_nodes[3].weights.values())[0],
+                 BETA: list(input_nodes[4].weights.values())[0]
+             }
+         else:
+             Logger.warning(f'functional batch_norm is only folded in the 5 inputs case (input, mean, var, gamma, beta),'
+                            f'got {len(input_nodes)}')
+             return {}
+
+     def substitute(self,
+                    graph: Graph,
+                    node: BaseNode) -> Graph:
+         """
+         Substitute functional.batch_norm and its inputs with BatchNorm2d.
+         Args:
+             graph: Graph we apply the substitution on.
+             node: node that match the pattern in the substitution init.
+
+         Returns:
+             Graph after applying the substitution.
+         """
+         # if the input is not a 4D tensor, we can't substitute it with BatchNorm2d
+         if len(node.input_shape[0]) != 4:
+             return graph
+         out_channels = node.output_shape[0][1]
+
+         bn_node_weights = self.get_attributes_from_inputs(graph, node)
+         if not bn_node_weights:
+             return graph
+         new_batchnorm2d = BaseNode(name=node.name + '_into_BatchNorm2d',
+                                    framework_attr={NUM_FEATURES: out_channels,
+                                                    EPSILON: EPSILON_VAL,
+                                                    MOMENTUM: MOMENTUM_VAL},
+                                    input_shape=node.output_shape,
+                                    output_shape=node.output_shape,
+                                    weights=bn_node_weights,
+                                    layer_class=nn.BatchNorm2d)
+
+         num_nodes_before_substitution = len(graph.nodes)
+         num_edges_before_substitution = len(graph.edges)
+
+         batch_norm_consts = graph.get_prev_nodes(node)[1:]
+         for const in batch_norm_consts:
+             graph.remove_edge(const, node)
+             graph.remove_node(const)
+
+         graph.replace_node(node, new_batchnorm2d)
+
+         assert num_nodes_before_substitution - len(graph.nodes) == len(batch_norm_consts)
+         assert num_edges_before_substitution - len(graph.edges) == len(batch_norm_consts)
+
+         return graph
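For context, the folding that FunctionalBatchNorm performs rests on a plain-PyTorch identity: a functional batch_norm call over a 4D tensor with explicit running statistics and affine parameters computes the same output as an nn.BatchNorm2d module in eval mode carrying those buffers and parameters. The self-contained check below is illustrative only and not part of the package; shapes and the eps value are arbitrary assumptions.

    import torch
    from torch import nn
    import torch.nn.functional as F

    x = torch.randn(2, 8, 16, 16)                 # 4D input (N, C, H, W)
    mean, var = torch.zeros(8), torch.ones(8)     # running statistics
    gamma, beta = torch.rand(8), torch.rand(8)    # affine parameters

    # Functional form, as matched by the substitution.
    y_func = F.batch_norm(x, mean, var, weight=gamma, bias=beta, training=False, eps=1e-5)

    # Module form that the substitution rewrites the graph to use.
    bn = nn.BatchNorm2d(num_features=8, eps=1e-5)
    with torch.no_grad():
        bn.running_mean.copy_(mean)
        bn.running_var.copy_(var)
        bn.weight.copy_(gamma)
        bn.bias.copy_(beta)
    bn.eval()

    y_mod = bn(x)
    assert torch.allclose(y_func, y_mod)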
@@ -48,6 +48,8 @@ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.ba
      pytorch_batchnorm_reconstruction
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.batchnorm_refusing import \
      pytorch_batchnorm_refusing
+ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.functional_batch_norm import \
+     FunctionalBatchNorm
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.linear_collapsing import \
      pytorch_linear_collapsing
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.multi_head_attention_decomposition \
@@ -243,7 +245,8 @@ class PytorchImplementation(FrameworkImplementation):
          return [ReshapeWithStaticShapes(),
                  MultiHeadAttentionDecomposition(),
                  PermuteCallMethod(),
-                 ConstantHolderConv(fw_info)]
+                 ConstantHolderConv(fw_info),
+                 FunctionalBatchNorm()]
 
      def get_substitutions_pre_statistics_collection(self,
                                                      quant_config: QuantizationConfig