mct-nightly 2.3.0.20250513.611__py3-none-any.whl → 2.3.0.20250515.544__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mct-nightly
- Version: 2.3.0.20250513.611
+ Version: 2.3.0.20250515.544
  Summary: A Model Compression Toolkit for neural networks
  Author-email: ssi-dnn-dev@sony.com
  Classifier: Programming Language :: Python :: 3
@@ -1,6 +1,6 @@
- mct_nightly-2.3.0.20250513.611.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
- model_compression_toolkit/__init__.py,sha256=p_G6GkwHl_GiPtc0E2qL6iUBG-UpYcgFx1HDi073s0Q,1557
- model_compression_toolkit/constants.py,sha256=iJ6vfTjC2oFIZWt8wvHoxEw5YJi3yl0Hd4q30_8q0Zc,3958
+ mct_nightly-2.3.0.20250515.544.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+ model_compression_toolkit/__init__.py,sha256=ZuiC7LBUZRbxQhR-vJI5NKeCIc9cX-tIpkHCw_Ynb0o,1557
+ model_compression_toolkit/constants.py,sha256=KNgiNLpsMgSYyXMNEbHXd4bFNerQc1D6HH3vpbUq_Gs,4086
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
  model_compression_toolkit/metadata.py,sha256=x_Bk4VpzILdsFax6--CZ3X18qUTP28sbF_AhoQW8dNc,4003
@@ -31,7 +31,7 @@ model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.p
  model_compression_toolkit/core/common/collectors/statistics_collector.py,sha256=psijsQZefwjMDH8SU5E18n65HiGtQilPhKr1hhzZX-I,8268
  model_compression_toolkit/core/common/collectors/weighted_histogram_collector.py,sha256=zp3dE7YTqWmkD5QWdRhsl9zD8W6Lr96G1Wjw1g2D3T0,4894
  model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
- model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=W8qZejLwbm-lkvNF3GepNL3ypO10vFRxOxbq-o_rt_I,15479
+ model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=S7hBbUJxL52Z8uJ9_upLdFyoSEJvgmVX0OmneqDIj-c,18656
  model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=F0AaAUBpJ9JjHMB5H2LD9pdwTSWJK-Kqm9dQmGHX1Jo,7368
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
  model_compression_toolkit/core/common/graph/base_graph.py,sha256=BSQpKy0BXoGX0G0bySTo72n2isTqvtpkbRYYa8-hPO4,41435
@@ -446,7 +446,7 @@ model_compression_toolkit/target_platform_capabilities/targetplatform2framework/
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=NCwuvnByeexLL987h67XhU8vQvCgq63bt0hFSiSSxvE,6400
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attribute_filter.py,sha256=jfhszvuD2Fyy6W2KjlLzXBQKFzTqGAaDZeFVr4-ONQw,8776
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/current_tpc.py,sha256=_kFG0USYa6yzvLsi82_Vusv_KR8Hi7J1u680pPXECuo,2192
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=UKzckLYLdBcFAptyKnVMwpPpfRkmF0SK1Kl0g0eGjQA,9710
+ model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=1jkj0ZO3t9M0SRpe9ZcSucraSoB4raezIbpcO_lZcP4,10084
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities_component.py,sha256=9Hg6AMCzTdDsKKgivRd61UjxGT5SWvKsc3mIUPPsYDQ,1021
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/layer_filter_params.py,sha256=dIu6k1xvGKLtk_47wq1eKYvrS4lYAknAXTeJfFstW0Y,3878
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/operations_to_layers.py,sha256=vZ7I2XDr_YDgU8oQt8gKkcuUOJf28DCzCPunPK2h_Xw,6563
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
- mct_nightly-2.3.0.20250513.611.dist-info/METADATA,sha256=dx0fsYTzsB_Y1IVuSNMaJPgPO4lhotb3TlDZ-dq2JF8,25136
- mct_nightly-2.3.0.20250513.611.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
- mct_nightly-2.3.0.20250513.611.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
- mct_nightly-2.3.0.20250513.611.dist-info/RECORD,,
+ mct_nightly-2.3.0.20250515.544.dist-info/METADATA,sha256=dV9aRBw1JVkuZDXyGl4aFtA91lLC_NtYTDquO5yA8rY,25136
+ mct_nightly-2.3.0.20250515.544.dist-info/WHEEL,sha256=QZxptf4Y1BKFRCEDxD4h2V0mBFQOVFLFEpvxHmIs52A,91
+ mct_nightly-2.3.0.20250515.544.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+ mct_nightly-2.3.0.20250515.544.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.4.0)
+ Generator: setuptools (80.6.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
  from model_compression_toolkit import pruning
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
 
- __version__ = "2.3.0.20250513.000611"
+ __version__ = "2.3.0.20250515.000544"
@@ -138,3 +138,8 @@ SHAPE = 'shape'
  NODE_NAME = 'node_name'
  TOTAL_SIZE = 'total_size'
  NODE_OUTPUT_INDEX = 'node_output_index'
+
+
+ # Fusing Patterns constants
+ FUSED_LAYER_PATTERN = 'fused_layer_pattern'
+ FUSED_OP_QUANT_CONFIG = 'fused_op_quantization_config'
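The two constants added above are the dictionary keys the rest of this diff uses to describe a fusing pattern: one holds the ordered layer pattern, the other an optional quantization config for the fused operation (see the fusing_info.py and framework_quantization_capabilities.py hunks below). A minimal sketch of one such entry; the layer classes and the config object are placeholders, not values taken from this diff:

    from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG

    # Hypothetical pattern entry; Conv2D, ReLU and fused_op_cfg are illustrative stand-ins.
    pattern_entry = {
        FUSED_LAYER_PATTERN: [Conv2D, ReLU],    # ordered layers (or LayerFilterParams) to fuse
        FUSED_OP_QUANT_CONFIG: fused_op_cfg,    # an OpQuantizationConfig, or None when the pattern carries none
    }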
@@ -14,6 +14,8 @@
  # ==============================================================================
 
  from model_compression_toolkit.target_platform_capabilities import LayerFilterParams
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig
+ from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
  from dataclasses import dataclass, field
 
  from typing import Optional, List, Dict, Any, Tuple
@@ -41,6 +43,7 @@ class FusingInfo:
      fusing_patterns: any = None
      fusing_data: Dict[str, Tuple['BaseNode']] = field(default_factory=dict)
      node_to_fused_node_map: Dict[str, str] = field(init=False, default_factory=dict)
+     fused_op_id_to_quant_config: Dict[str, OpQuantizationConfig] = field(default_factory=dict)
 
      def __post_init__(self):
          """Validates and initializes mappings after dataclass instantiation."""
@@ -49,6 +52,7 @@ class FusingInfo:
              assert isinstance(op_nodes, tuple) and len(op_nodes) > 1, f"Found invalid fused op nodes: {op_nodes}"
 
          self._init_node_mapping()
+         self._init_quantization_config_map()
 
      def _init_node_mapping(self) -> None:
          """
@@ -59,6 +63,15 @@ class FusingInfo:
              for node in nodes:
                  self.node_to_fused_node_map[node.name] = op_id
 
+     def _init_quantization_config_map(self) -> None:
+         """
+         Init the mapping between fused operation IDs and their quantization configurations.
+         """
+         self.fused_op_id_to_quant_config.clear()
+         if self.fusing_patterns is not None:
+             for op_id, nodes in self.fusing_data.items():
+                 self.set_fused_op_quantization_config(op_id, nodes)
+
      def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
          """
          Add a new fused operation with the given ID and set of nodes.
@@ -78,6 +91,22 @@ class FusingInfo:
          for node in nodes:
              self.node_to_fused_node_map[node.name] = op_id
 
+         # Update the quantization config mapping for this operation
+         if self.fusing_patterns is not None:
+             self.set_fused_op_quantization_config(op_id, nodes)
+
+     def set_fused_op_quantization_config(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
+         """
+         Set the quantization configuration for a given fused operation ID.
+
+         Args:
+             op_id (str): The identifier for the fused operation.
+             nodes (Tuple[BaseNode]): The tuple of nodes that form the fused operation.
+         """
+         fusing_pattern = next((fp for fp in self.fusing_patterns if is_valid_fusion([fp.get(FUSED_LAYER_PATTERN)], nodes)), None)
+         if fusing_pattern is not None:
+             self.fused_op_id_to_quant_config[op_id] = fusing_pattern.get(FUSED_OP_QUANT_CONFIG)
+
      def remove_fused_operation(self, op_id: str) -> None:
          """
          Remove a fused operation by its ID.
@@ -95,6 +124,7 @@ class FusingInfo:
          for node in nodes:
              self.node_to_fused_node_map.pop(node.name, None)
          del self.fusing_data[op_id]
+         self.fused_op_id_to_quant_config.pop(op_id, None)
 
      def get_fused_node_name(self, node_name: str) -> Optional[str]:
          """
@@ -117,6 +147,15 @@ class FusingInfo:
          """
          return self.node_to_fused_node_map.copy()
 
+     def get_fusing_quantization_config_map(self) -> Dict[str, OpQuantizationConfig]:
+         """
+         Retrieve a copy of the mapping from fused operation IDs to their quantization configurations.
+
+         Returns:
+             A dictionary mapping each fused operation ID to its quantization configuration.
+         """
+         return self.fused_op_id_to_quant_config.copy()
+
      def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]:
          """
          Retrieve the list of nodes for a given fused operation ID.
@@ -129,6 +168,18 @@ class FusingInfo:
          """
          return self.fusing_data.get(op_id)
 
+     def get_fused_op_quantization_config(self, op_id: str) -> OpQuantizationConfig:
+         """
+         Retrieve the quantization configuration for a given fused operation ID.
+
+         Args:
+             op_id (str): The identifier for the fused operation.
+
+         Returns:
+             OpQuantizationConfig: The quantization configuration for the operation, or None if not found.
+         """
+         return self.fused_op_id_to_quant_config.get(op_id)
+
      def is_node_in_fused_op(self, node: 'BaseNode') -> bool:
          """
          Check if a node is part of any fused operation.
@@ -216,10 +267,11 @@ class FusingInfo:
              all_fused_nodes.update(node_set)
 
              # Check 4: Ensure the sequence matches a valid fusing pattern
-             if not is_valid_fusion(self.fusing_patterns, nodes):
+             valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns)
+             if not is_valid_fusion(valid_fusing_patterns, nodes):
                  raise ValueError(
                      f"Fused operation {op_id} does not match any valid fusing pattern "
-                     f"from {self.fusing_patterns}."
+                     f"from {valid_fusing_patterns}."
                  )
 
      def is_nodes_eligible_to_be_fused(self, nodes: List['BaseNode']) -> bool:
@@ -240,7 +292,8 @@ class FusingInfo:
              return False
 
          # Check if the provided nodes match a valid fusion pattern
-         return is_valid_fusion(fusing_patterns=self.fusing_patterns, nodes=nodes)
+         valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns)
+         return is_valid_fusion(fusing_patterns=valid_fusing_patterns, nodes=nodes)
 
      def __repr__(self) -> str:
          """
@@ -287,8 +340,11 @@ class FusingInfoGenerator:
          if not self._fusing_patterns:
              return FusingInfo(fusing_patterns=self._fusing_patterns)
 
+         # Extract fusing layer patterns
+         fusing_layer_patterns = _get_fusing_layer_patterns(self._fusing_patterns)
+
          # Find max fusion
-         max_layers_fusing = max([len(fusing_pattern) for fusing_pattern in self._fusing_patterns])
+         max_layer_patterns = max([len(fusing_layer_pattern) for fusing_layer_pattern in fusing_layer_patterns])
 
          # Travel along the graph to find layers for fusing
          nodes = graph.get_topo_sorted_nodes()
@@ -302,9 +358,9 @@ class FusingInfoGenerator:
                  continue
              # Start fusing search
              fusing_nodes = []  # nodes that are candidates for participating in fusing
-             patterns = copy.deepcopy(self._fusing_patterns)
+             patterns = copy.deepcopy(fusing_layer_patterns)
              next_nodes = [node]
-             for i in range(max_layers_fusing):
+             for i in range(max_layer_patterns):
                  patterns = get_valid_fusing_patterns_for_node(patterns, next_nodes[0], i)
                  if len(patterns) == 0:  # Give up if no more fusion pattern
                      break
@@ -314,7 +370,7 @@ class FusingInfoGenerator:
                      break
 
              # New fusion
-             if is_valid_fusion(self._fusing_patterns, fusing_nodes):
+             if is_valid_fusion(fusing_layer_patterns, fusing_nodes):
                  fused_op_id = FusingInfo.generate_fused_op_id(fusing_nodes)
                  assert fused_op_id not in fusing_info, f"{fused_op_id} is already in fusing info: {fusing_info}"
                  fusing_info[fused_op_id] = tuple(fusing_nodes)
@@ -371,3 +427,15 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode']) -
      if counter == fusion_depth:
          return True
  return False
+
+
+ def _get_fusing_layer_patterns(fusing_patterns: List[Dict[Any, OpQuantizationConfig]]) -> List[List[Any]]:
+     """
+     Extracts the fusing layer patterns from the provided fusing patterns.
+     Args:
+         fusing_patterns: List of patterns of layers/LayerFilterParams to fuse and their mapping quantization config.
+
+     Returns:
+         supported fusing layer patterns
+     """
+     return [f.get(FUSED_LAYER_PATTERN) for f in fusing_patterns]
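Taken together, the fusing_info.py hunks mean a FusingInfo object now records, per fused operation, the OpQuantizationConfig of the pattern that produced it, and keeps that map in sync as operations are added or removed. A rough usage sketch of the new accessors; fi stands for an already-built FusingInfo instance (how it is constructed is not shown in this diff):

    # Sketch only: fi is assumed to be an existing FusingInfo.
    for op_id, quant_config in fi.get_fusing_quantization_config_map().items():
        nodes = fi.get_fused_nodes(op_id)      # the nodes making up the fused op
        if quant_config is None:
            continue                           # the matching pattern carried no quantization config
        print(op_id, [n.name for n in nodes], quant_config)

    # Per-operation lookup is also available:
    cfg = fi.get_fused_op_quantization_config(op_id)   # OpQuantizationConfig or None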
@@ -31,6 +31,9 @@ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_s
      OpQuantizationConfig, QuantizationConfigOptions
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import _current_tpc
 
+ from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
+
+
  class FrameworkQuantizationCapabilities(ImmutableClass):
      """
      Attach framework information to a modeled hardware.
@@ -94,20 +97,26 @@ class FrameworkQuantizationCapabilities(ImmutableClass):
          """
          return self.op_sets_to_layers.get_layers_by_op(op)
 
-     def get_fusing_patterns(self) -> List[List[Any]]:
+     def get_fusing_patterns(self) -> List[Dict[List[Any], OpQuantizationConfig]]:
          """
 
-         Returns: List of patterns of layers/LayerFilterParams to fuse.
+         Returns: List of patterns of layers/LayerFilterParams to fuse and their mapping quantization config.
 
          """
-         res = []
+
+         patterns = []
          if self.tpc.fusing_patterns is None:
-             return res
+             return patterns
+
          for p in self.tpc.fusing_patterns:
+             res = []
              ops = [self.get_layers_by_opset(x) for x in p.operator_groups]
              res.extend(itertools.product(*ops))
-         return [list(x) for x in res]
 
+             fused_op_quant_config = getattr(p, FUSED_OP_QUANT_CONFIG, None)
+             patterns.extend({FUSED_LAYER_PATTERN: list(x), FUSED_OP_QUANT_CONFIG: fused_op_quant_config} for x in res)
+
+         return patterns
 
      def get_info(self) -> Dict[str, Any]:
          """