mct-nightly 2.2.0.20240925.453__py3-none-any.whl → 2.2.0.20240926.452__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.2.0.20240925.453
3
+ Version: 2.2.0.20240926.452
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=eI5QONSxEPDPEpc5TqZF43AfSWl8Om61_pe6SgWyCTk,1573
1
+ model_compression_toolkit/__init__.py,sha256=vX11_K5A8c4_uT3X2dHKRg0nxBh-qKBSqljT0u_1B64,1573
2
2
  model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -47,7 +47,7 @@ model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256
47
47
  model_compression_toolkit/core/common/graph/memory_graph/memory_element.py,sha256=gRmBEFRmyJsNKezQfiwDwQu1cmbGd2wgKCRTH6iw8mw,3961
48
48
  model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=gw4av_rzn_3oEAPpD3B7PHZDqnxHMjIESevl6ppPnkk,7175
49
49
  model_compression_toolkit/core/common/hessian/__init__.py,sha256=6216QgHl7h4DXGn5ForP9Tija-wrBSONNtQ769ikP2s,1025
50
- model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=DHbZqFDuDir1QWN-YkYBzaoGDujgYam1hT2ea6uL3yM,21009
50
+ model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=TfgSIh5pmZcJM9335aAxZriCzMljnk3mYhmKBsK2x5Y,20848
51
51
  model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=1axmN0tjJSo_7hUr2d2KMv4y1pBi19cqWSQpi4BbdsA,1458
52
52
  model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py,sha256=Pe4uKerx-MeDQPJ7Slr8fvFUHfv02q33w3gbQK5kBKs,4186
53
53
  model_compression_toolkit/core/common/hessian/hessian_scores_request.py,sha256=atGJgJBL9uwYRC3t9NnzGgHYxV4XJj4Ai_xPpQH0rhY,3229
@@ -228,7 +228,7 @@ model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,s
228
228
  model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
229
229
  model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
230
230
  model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=D7lU1r9Uq_7fdNuKk2BMF8ho5GrsY-8gyGN6yYoHaVg,15060
231
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=QfMulTLC6X_0Iwvk_VChFIoSdeoiEeJ_rf2IQi5TjBk,19353
231
+ model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=BJeKGMv5VU4Z3jLOIQ-Ifs_2vGELQSmEQmje3ZmaUl4,19948
232
232
  model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
233
233
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
234
234
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
@@ -239,7 +239,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchno
239
239
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=B7aC2TZNrQJ2oQVGBFhKAVqdUU5lYVJSMmwKhjxOHWk,2822
240
240
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=JDWOaNwYrZG0zTwd3HwoZUM3tKu7zPbzLOrqNQsu8xA,2162
241
241
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py,sha256=SBrR24ZAnWPftLinv4FuIqdBGjfYtfXbYQJN5mgy5V4,2861
242
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py,sha256=iTuP1hjuTZTGcE7izfs_UOWBGeEBFRvRIU4QCh-b21M,4627
242
+ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py,sha256=sw3jIOUSvfWUeD8l3rGcUOtC6QuzpMIQm8V3RQAM53Q,4741
243
243
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py,sha256=7GZY7lU3LUUaO5iiccHkUP62PB0QeGAGOZdUSGMkFBY,4450
244
244
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_layer_norm.py,sha256=XhiLVcnCc_gF-6mjxbf9C4bYg5YL_GCvDJmcdLkBNAg,4151
245
245
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_collapsing.py,sha256=CXSMASpc_Zed3BJ2CsER69zKxE6ncFvvKQWDO1JxKYI,5849
@@ -267,7 +267,7 @@ model_compression_toolkit/core/pytorch/quantizer/__init__.py,sha256=Rf1RcYmelmdZ
267
267
  model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py,sha256=D8_CEuFqKAhbUgKaRw7Jlxo0zlqgPTMu6CIIIM4LfS0,7045
268
268
  model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py,sha256=uyeBtNokyDUikk-YkDP_mN_2DX0J5oPm3kSfdSUT2Ck,4420
269
269
  model_compression_toolkit/core/pytorch/reader/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
270
- model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=BvBj9uokKTvX-6d39yA4SKwRQAN8_X4T8l-rPibChJQ,16754
270
+ model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=mo1NIYXxiAigbTZvNgQeLi6vzLn0RqU0RxcxZKE27cE,19335
271
271
  model_compression_toolkit/core/pytorch/reader/node_holders.py,sha256=7XNc7-l1MZPJGcOESvtAwfIMxrU6kvt3YjF5B7qOqK4,1048
272
272
  model_compression_toolkit/core/pytorch/reader/reader.py,sha256=GEJE0QX8XJFWbYCkbRBtzttZtmmuoACLx8gw9KyAQCE,6015
273
273
  model_compression_toolkit/core/pytorch/statistics_correction/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -301,7 +301,7 @@ model_compression_toolkit/data_generation/pytorch/image_operations.py,sha256=KUQ
301
301
  model_compression_toolkit/data_generation/pytorch/image_pipeline.py,sha256=dcQr-67u9-ggGuS39YAvR7z-Y0NOdJintcVQ5vy1bM8,7478
302
302
  model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py,sha256=y6vMed6lQQj67-BXZKrAcWUNTkH8YjiUhknOV4wSpRA,9399
303
303
  model_compression_toolkit/data_generation/pytorch/optimization_utils.py,sha256=vRMeUEdInPuJisiO-SKo_9miWZV90sz8GCg5MY0AqiU,18098
304
- model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py,sha256=cUkFg-9LWwRKy11tlASJwp1FbDx6a7sZWpJNMz01hWA,21626
304
+ model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py,sha256=_BFy4RYcLoxpt5KecM5VbPRRNM4QHdFr9WmtL4FODUE,21796
305
305
  model_compression_toolkit/data_generation/pytorch/optimization_functions/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
306
306
  model_compression_toolkit/data_generation/pytorch/optimization_functions/batchnorm_alignment_functions.py,sha256=dMc4zz9XfYfAT4Cxns57VgvGZWPAMfaGlWLFyCyl8TA,1968
307
307
  model_compression_toolkit/data_generation/pytorch/optimization_functions/bn_layer_weighting_functions.py,sha256=We0fVMQ4oU7Y0IWQ8fKy8KpqkIiLyKoQeF9XKAQ6TH0,3317
@@ -472,9 +472,9 @@ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_
472
472
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py,sha256=XM6qBLIvzsmdFf-AZq5WOlORK2GXC_X-gulReNxHb9E,6601
473
473
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py,sha256=nP05jqvh6uaj30a3W7zEkJfKtqfP0Nz5bobwRqbYrdM,5807
474
474
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py,sha256=tHTUvsaerSfbe22pU0kIDauPpFD7Pq5EmZytVIDkHz4,717
475
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py,sha256=_LHQkGB0x12FQBDIkEA-Br8HSUL5ZmMXxI7lDpVWcQU,15422
475
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py,sha256=Ee7M3YVymdv6HYsm7coB8N0dyTOhlAhLdxfSLJXCuoU,15665
476
476
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py,sha256=u8qD1XkHwU4LIoNbmC5mtZd8lZ8gZ4XFihZmoYwAulc,7641
477
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py,sha256=EaQERA7XnZcF7pO4xzBk0li96JnACRE7ppgK535EMXM,6698
477
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py,sha256=GCghKkkZOKNTAzwyoZZPid9alGiufNUBzDj2yE7YUSU,6709
478
478
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
479
479
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py,sha256=is00rNrDmmirYsyMtMkWz0DwOA92-x7hAJwpd6z1n2E,2806
480
480
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py,sha256=CXC-HQolSDu7j8V-Xm-SWGCd74gXB3XnAkEhI_TVbIQ,1516
@@ -550,8 +550,8 @@ tests_pytest/pytorch/gptq/test_annealing_cfg.py,sha256=hGC7L6mp3N1ygcJ3OctgS_Fz2
550
550
  tests_pytest/pytorch/gptq/test_gradual_act_quantization.py,sha256=tI01aFIUaiCILL5Qn--p1E_rLBUelxLdSY3k52lwcx0,4594
551
551
  tests_pytest/pytorch/trainable_infrastructure/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
552
552
  tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py,sha256=eNOpSp0GoLxtEdiRypBp8jaujXfdNxBwKh5Rd-P7WLs,1786
553
- mct_nightly-2.2.0.20240925.453.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
554
- mct_nightly-2.2.0.20240925.453.dist-info/METADATA,sha256=-FYOCuanQ2MY1g5nj-LWf1WuTJEAqIpR5ymgLPAGe2I,20813
555
- mct_nightly-2.2.0.20240925.453.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
556
- mct_nightly-2.2.0.20240925.453.dist-info/top_level.txt,sha256=csdfSXhtRnpWYRzjZ-dRLIhOmM2TEdVXUxG05A5fgb8,39
557
- mct_nightly-2.2.0.20240925.453.dist-info/RECORD,,
553
+ mct_nightly-2.2.0.20240926.452.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
554
+ mct_nightly-2.2.0.20240926.452.dist-info/METADATA,sha256=AWRxoKCjgmTLCf726oR4aaHZBpcxel0TrFkAGP-5guM,20813
555
+ mct_nightly-2.2.0.20240926.452.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
556
+ mct_nightly-2.2.0.20240926.452.dist-info/top_level.txt,sha256=csdfSXhtRnpWYRzjZ-dRLIhOmM2TEdVXUxG05A5fgb8,39
557
+ mct_nightly-2.2.0.20240926.452.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.2.0.20240925.000453"
30
+ __version__ = "2.2.0.20240926.000452"
@@ -258,12 +258,10 @@ class HessianInfoService:
258
258
  f"{hessian_scores_request.target_nodes}.")
259
259
 
260
260
  # Replace node in reused target nodes with a representing node from the 'reuse group'.
261
- for n in hessian_scores_request.target_nodes:
262
- if n.reuse_group:
263
- rep_node = self._get_representing_of_reuse_group(n)
264
- hessian_scores_request.target_nodes.remove(n)
265
- if rep_node not in hessian_scores_request.target_nodes:
266
- hessian_scores_request.target_nodes.append(rep_node)
261
+ hessian_scores_request.target_nodes = [
262
+ self._get_representing_of_reuse_group(node) if node.reuse else node
263
+ for node in hessian_scores_request.target_nodes
264
+ ]
267
265
 
268
266
  # Ensure the saved info has the required number of approximations
269
267
  self._populate_saved_info_to_size(hessian_scores_request, required_size, batch_size)
@@ -231,6 +231,7 @@ class PytorchModel(torch.nn.Module):
231
231
  self.return_float_outputs = return_float_outputs
232
232
  self.wrapper = wrapper
233
233
  self.get_activation_quantizer_holder = get_activation_quantizer_holder_fn
234
+ self.reuse_groups = {}
234
235
  self._add_modules()
235
236
 
236
237
  # todo: Move to parent class BaseModelBuilder
@@ -288,7 +289,19 @@ class PytorchModel(torch.nn.Module):
288
289
  Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel
289
290
  """
290
291
  for node in self.node_sort:
291
- node_op = self.wrap(node)
292
+ if node.reuse:
293
+ # If the node is reused, retrieve the original module
294
+ if node.reuse_group not in self.reuse_groups:
295
+ Logger.critical(f"Reuse group {node.reuse_group} not found for node {node.name}")
296
+
297
+ node_op = self.reuse_groups[node.reuse_group]
298
+ else:
299
+ # If it's not reused, create a new module
300
+ node_op = self.wrap(node)
301
+ if node.reuse_group:
302
+ # Store the module for future reuse
303
+ self.reuse_groups[node.reuse_group] = node_op
304
+
292
305
  if isinstance(node, FunctionalNode):
293
306
  # for functional layers
294
307
  setattr(self, node.name, node_op)
@@ -80,7 +80,9 @@ class FunctionalConvSubstitution(common.BaseSubstitution):
80
80
  output_shape=func_node.output_shape,
81
81
  weights={KERNEL: weight} if bias is None else {KERNEL: weight, BIAS: bias},
82
82
  layer_class=new_layer,
83
- has_activation=func_node.has_activation)
83
+ has_activation=func_node.has_activation,
84
+ reuse=func_node.reuse,
85
+ reuse_group=func_node.reuse_group)
84
86
  graph.add_node(new_node)
85
87
  graph.reconnect_out_edges(current_node=func_node, new_node=new_node)
86
88
  graph.reconnect_in_edges(current_node=func_node, new_node=new_node)
@@ -30,8 +30,7 @@ from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlac
30
30
  from model_compression_toolkit.logger import Logger
31
31
 
32
32
 
33
- def _extract_parameters_and_buffers(module: Union[torch.nn.Module, GraphModule],
34
- to_numpy: Callable) -> Dict[str, np.ndarray]:
33
+ def _extract_parameters_and_buffers(module: Union[torch.nn.Module, GraphModule]) -> Dict[str, np.ndarray]:
35
34
  """
36
35
  Extract parameters & buffers from input module to a dictionary.
37
36
  Args:
@@ -41,8 +40,8 @@ def _extract_parameters_and_buffers(module: Union[torch.nn.Module, GraphModule],
41
40
  Dictionary containing module parameters and buffers by name.
42
41
  """
43
42
 
44
- named_parameters = {name: to_numpy(parameter) for name, parameter in module.named_parameters()}
45
- named_buffers = {name: to_numpy(buffer) for name, buffer in module.named_buffers()}
43
+ named_parameters = {name: parameter for name, parameter in module.named_parameters()}
44
+ named_buffers = {name: buffer for name, buffer in module.named_buffers()}
46
45
 
47
46
  return {**named_parameters, **named_buffers}
48
47
 
@@ -97,14 +96,12 @@ def _build_input_alloc_and_call_args(n: Node, input_tensors_in_node_kwargs: Dict
97
96
  return op_call_args, tensor_input_alloc
98
97
 
99
98
 
100
- def _extract_torch_layer_data(node_module: torch.nn.Module,
101
- to_numpy: Callable) -> Tuple[Any, Dict[str, np.ndarray], Dict]:
99
+ def _extract_torch_layer_data(node_module: torch.nn.Module) -> Tuple[Any, Dict[str, np.ndarray], Dict]:
102
100
  """
103
101
  Extract required data from a non-functional node to rebuild the PyTorch layer.
104
102
 
105
103
  Args:
106
104
  node_module: Torch layer, such as nn.Conv2d, nn.Linear, etc.
107
- to_numpy: Function to convert framework's tensor to a Numpy array.
108
105
 
109
106
  Returns:
110
107
  Node layer class.
@@ -124,7 +121,7 @@ def _extract_torch_layer_data(node_module: torch.nn.Module,
124
121
  framework_attr[BIAS] = False if node_module.bias is None else True
125
122
 
126
123
  # Extract layer weights and named buffers.
127
- weights = {n: w for n, w in _extract_parameters_and_buffers(node_module, to_numpy).items() if len(w.shape) > 0}
124
+ weights = {n: w for n, w in _extract_parameters_and_buffers(node_module).items() if len(w.shape) > 0}
128
125
  return node_type, weights, framework_attr
129
126
 
130
127
 
@@ -181,8 +178,11 @@ def nodes_builder(model: GraphModule,
181
178
  consts_dict = {}
182
179
  used_consts = set()
183
180
 
181
+ # Dictionary to track seen targets and their corresponding nodes to mark reused nodes
182
+ seen_targets = {}
183
+
184
184
  # Init parameters & buffers dictionary of the entire model. We later extract the constants values from this dictionary.
185
- model_parameters_and_buffers = _extract_parameters_and_buffers(model, to_numpy)
185
+ model_parameters_and_buffers = _extract_parameters_and_buffers(model)
186
186
 
187
187
  for node in model.graph.nodes:
188
188
 
@@ -195,7 +195,7 @@ def nodes_builder(model: GraphModule,
195
195
 
196
196
  if node.target in module_dict.keys():
197
197
  # PyTorch module node, such as nn.Conv2d or nn.Linear.
198
- node_type, weights, framework_attr = _extract_torch_layer_data(module_dict[node.target], to_numpy)
198
+ node_type, weights, framework_attr = _extract_torch_layer_data(module_dict[node.target])
199
199
 
200
200
  elif node.op == CALL_FUNCTION:
201
201
  # Node is a function that handle a parameter\buffer in the model.
@@ -249,6 +249,31 @@ def nodes_builder(model: GraphModule,
249
249
  # Extract input and output shapes of the node.
250
250
  input_shape, output_shape = _extract_input_and_output_shapes(node)
251
251
 
252
+ # Check if this node's target has been seen before
253
+ reuse = False
254
+ reuse_group = None
255
+ node_group_key = create_reuse_group(node.target, weights)
256
+ # We mark nodes as reused only if there are multiple nodes in the graph with same
257
+ # 'target' and it has some weights.
258
+ if node_group_key in seen_targets and len(weights) > 0:
259
+ reuse = True
260
+ reuse_group = node_group_key
261
+ # Update the 'base/main' node with the reuse group as all other nodes in its group.
262
+ fx_node_2_graph_node[seen_targets[node_group_key]].reuse_group = reuse_group
263
+ else:
264
+ seen_targets[node_group_key] = node
265
+
266
+ # Convert weights to numpy arrays after reuse marking
267
+ # We delay this conversion to preserve the original tensor instances during the reuse identification process.
268
+ # This is crucial for correctly identifying identical weight instances in reused functional layers.
269
+ # By keeping the original PyTorch tensors until this point, we ensure that:
270
+ # 1. Reused layers with the same weight instances are correctly marked as reused.
271
+ # 2. The instance-based weight signature generation works as intended, using the memory
272
+ # addresses of the original tensors.
273
+ # Only after all reuse marking is complete do we convert to numpy arrays.
274
+ for weight_name, weight_value in weights.items():
275
+ weights[weight_name] = to_numpy(weight_value)
276
+
252
277
  # Initiate graph nodes.
253
278
  if node.op in [CALL_METHOD, CALL_FUNCTION]:
254
279
  graph_node_type = FunctionalNode
@@ -300,6 +325,8 @@ def nodes_builder(model: GraphModule,
300
325
  weights=weights,
301
326
  layer_class=node_type,
302
327
  has_activation=node_has_activation,
328
+ reuse=reuse,
329
+ reuse_group=reuse_group,
303
330
  **kwargs)
304
331
 
305
332
  # Generate graph inputs list.
@@ -365,3 +392,28 @@ def edges_builder(model: GraphModule,
365
392
  Edge(fx_node_2_graph_node[node], fx_node_2_graph_node[out_node], src_index, dst_index))
366
393
 
367
394
  return edges
395
+
396
+
397
+ def create_reuse_group(target: Any, weights: Dict[str, Any]) -> str:
398
+ """
399
+ Combine target and weights to create a unique reuse group identifier.
400
+ We consider the weights as part of the group identifier because they are not part of
401
+ the module in functional layers, but if a functional layer is using the same weights multiple
402
+ times it is considered to be reused.
403
+
404
+ This function creates a unique string identifier for a reuse group by combining
405
+ the target (typically a layer or operation name) with the weights IDs.
406
+
407
+ Args:
408
+ target (Any): The target of the node, typically a string or callable representing
409
+ a layer or operation.
410
+ weights (Dict[str, Any]): A dictionary of weight names to weight values.
411
+ The values can be any type (typically tensors or arrays).
412
+
413
+ Returns:
414
+ str: A unique string identifier for the reuse group.
415
+ """
416
+ if not weights:
417
+ return str(target)
418
+ weight_ids = tuple(sorted(id(weight) for weight in weights.values()))
419
+ return f"{target}_{weight_ids}"
@@ -12,6 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import copy
16
+
15
17
  import time
16
18
  from typing import Callable, Any, Tuple, List, Union
17
19
 
@@ -179,8 +181,11 @@ if FOUND_TORCH and FOUND_TORCHVISION:
179
181
  # get the model device
180
182
  device = get_working_device()
181
183
 
184
+ # copy model for data generation
185
+ model_for_data_gen = copy.deepcopy(model)
186
+
182
187
  # get a static graph representation of the model using torch.fx
183
- fx_model = symbolic_trace(model)
188
+ fx_model = symbolic_trace(model_for_data_gen)
184
189
 
185
190
  # Get Data Generation functions and classes
186
191
  image_pipeline, normalization, bn_layer_weighting_fn, bn_alignment_loss_fn, output_loss_fn, \
@@ -208,23 +213,23 @@ if FOUND_TORCH and FOUND_TORCHVISION:
208
213
  scheduler = scheduler_get_fn(data_generation_config.n_iter)
209
214
 
210
215
  # Set the current model
211
- set_model(model)
216
+ set_model(model_for_data_gen)
212
217
 
213
218
  # Create an activation extractor object to extract activations from the model
214
219
  activation_extractor = PytorchActivationExtractor(
215
- model,
220
+ model_for_data_gen,
216
221
  fx_model,
217
222
  data_generation_config.bn_layer_types,
218
223
  data_generation_config.last_layer_types)
219
224
 
220
225
  # Create an orig_bn_stats_holder object to hold original BatchNorm statistics
221
- orig_bn_stats_holder = PytorchOriginalBNStatsHolder(model, data_generation_config.bn_layer_types)
226
+ orig_bn_stats_holder = PytorchOriginalBNStatsHolder(model_for_data_gen, data_generation_config.bn_layer_types)
222
227
  if orig_bn_stats_holder.get_num_bn_layers() == 0:
223
228
  Logger.critical(
224
229
  f'Data generation requires a model with at least one BatchNorm layer.') # pragma: no cover
225
230
 
226
231
  # Create an ImagesOptimizationHandler object for handling optimization
227
- all_imgs_opt_handler = PytorchImagesOptimizationHandler(model=model,
232
+ all_imgs_opt_handler = PytorchImagesOptimizationHandler(model=model_for_data_gen,
228
233
  data_gen_batch_size=data_generation_config.data_gen_batch_size,
229
234
  init_dataset=init_dataset,
230
235
  optimizer=data_generation_config.optimizer,
@@ -207,7 +207,9 @@ def generate_tp_model(default_config: OpQuantizationConfig,
207
207
  base_config=const_config_input16_per_tensor)
208
208
 
209
209
  qpreserving_const_config = const_config.clone_and_edit(enable_activation_quantization=False,
210
- quantization_preserving=True)
210
+ quantization_preserving=True,
211
+ default_weight_attr_config=const_config.default_weight_attr_config.clone_and_edit(
212
+ weights_per_channel_threshold=False))
211
213
  qpreserving_const_config_options = tp.QuantizationConfigOptions([qpreserving_const_config])
212
214
 
213
215
  # Create a TargetPlatformModel and set its default quantization config.
@@ -19,7 +19,7 @@ import torch
19
19
  from torch import add, sub, mul, div, divide, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, \
20
20
  chunk, unbind, topk, gather, equal, transpose, permute, argmax, squeeze, multiply, subtract
21
21
  from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d
22
- from torch.nn import Dropout, Flatten, Hardtanh, Identity
22
+ from torch.nn import Dropout, Flatten, Hardtanh
23
23
  from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, LeakyReLU
24
24
  from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, leaky_relu
25
25
 
@@ -87,7 +87,7 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
87
87
  squeeze,
88
88
  permute,
89
89
  transpose])
90
- tp.OperationsSetToLayers(OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, [gather])
90
+ tp.OperationsSetToLayers(OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, [gather, torch.Tensor.expand])
91
91
  tp.OperationsSetToLayers(OPSET_MERGE_OPS,
92
92
  [torch.stack, torch.cat, torch.concat, torch.concatenate])
93
93