mct-nightly 2.3.0.20250421.604__py3-none-any.whl → 2.3.0.20250423.537__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mct-nightly
3
- Version: 2.3.0.20250421.604
3
+ Version: 2.3.0.20250423.537
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: License :: OSI Approved :: Apache Software License
@@ -1,5 +1,5 @@
1
- mct_nightly-2.3.0.20250421.604.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
- model_compression_toolkit/__init__.py,sha256=wnmK1gqXxy1bGqYedwsyNhvn4OUAZel25ytzES08fmk,1557
1
+ mct_nightly-2.3.0.20250423.537.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
+ model_compression_toolkit/__init__.py,sha256=NXqhvuAEHQlzpdJpTtyy1rDJW2gxmMwEGZvHIBQE_f0,1557
3
3
  model_compression_toolkit/constants.py,sha256=iJ6vfTjC2oFIZWt8wvHoxEw5YJi3yl0Hd4q30_8q0Zc,3958
4
4
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
5
5
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -34,7 +34,7 @@ model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5
34
34
  model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=W8qZejLwbm-lkvNF3GepNL3ypO10vFRxOxbq-o_rt_I,15479
35
35
  model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=F0AaAUBpJ9JjHMB5H2LD9pdwTSWJK-Kqm9dQmGHX1Jo,7368
36
36
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
37
- model_compression_toolkit/core/common/graph/base_graph.py,sha256=2aRpL8OP-JWKc2XFdsAQjACthJZmS8zgwIX-wjBRCFQ,41383
37
+ model_compression_toolkit/core/common/graph/base_graph.py,sha256=BSQpKy0BXoGX0G0bySTo72n2isTqvtpkbRYYa8-hPO4,41435
38
38
  model_compression_toolkit/core/common/graph/base_node.py,sha256=AbUadAT581zelVcGcK9_--6CAGiht9qwkeWahwT3RzE,33389
39
39
  model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
40
40
  model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
@@ -75,10 +75,10 @@ model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,s
75
75
  model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=S1ChgxtUjzXJufNWyRbKoNdyNC6fGUjPeComDMx8ZCo,9479
76
76
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
77
77
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
78
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=cjFnpDvxZDE4K2sgt26DhosA2XqhxHDs0eW5Qe7AwAQ,40668
78
+ model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=D2sNbTPMDsDyUE18NUpVJN27AgdwwhpdOJ8UMLmhdPA,40420
79
79
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=QQwtl08DiDxUOQGpYPnek_RlZjWm1Ky7tL2ESHXMK78,4050
80
80
  model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
81
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=32s620FyREMBJYx3AUp6umlRfHxjqhL31PRbVtLdMJ4,6664
81
+ model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=6Z6nQL9UH7B8dbcUR0cuCTEYFOKZAlvOb-SCk_cAZFA,6670
82
82
  model_compression_toolkit/core/common/network_editors/__init__.py,sha256=vZmu55bYqiaOQs3AjfwWDXHmuKZcLHt-wm7uR5fPEqg,1307
83
83
  model_compression_toolkit/core/common/network_editors/actions.py,sha256=nid0_j-Cn10xvmztT8yCKW_6uA7JEnom9SW9syx7wc0,19594
84
84
  model_compression_toolkit/core/common/network_editors/edit_network.py,sha256=dfgawi-nB0ocAJ0xcGn9E-Zv203oUnQLuMiXpX8vTgA,1748
@@ -112,7 +112,7 @@ model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,
112
112
  model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=7eG7dl1TcbdnHwgmvyjarxLs0o6Lw_9VAjXAm4rsiBk,3791
113
113
  model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
114
114
  model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
115
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=IeBy1kh3Rdp_LFEd0K2Jc_XANDPYJQDYP9MYrpTE29k,29550
115
+ model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=_hhRb5eeFwbtPddu2xdLi7qK1RsxoR7UHUfjO0ICM3Q,30586
116
116
  model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
117
117
  model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=_m-XkEMJMHf0gYwVIXAoHVjdRa2NXt_gYdwBlw76ZR8,24031
118
118
  model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=RL-PklAjGyC-26anSt8fU07a6pB_LBQFQy9o4e9giN0,8739
@@ -439,7 +439,7 @@ model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema
439
439
  model_compression_toolkit/target_platform_capabilities/schema/schema_compatability.py,sha256=TtMPbiibV6Hk53nl5Y_ctfpI6mSbd8VVH9fxnv5j9eM,4430
440
440
  model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py,sha256=vBkXxVJagm9JKB9cdm4Pvi7u_luriXUjvNn0-m8Zr0k,4653
441
441
  model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=oWKNQnnz04kmijmdWtRyXgVXbJ6BG_V_bUBz_MfUM94,27116
442
- model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=FiSkRUSuEPnJxvyDuRTwv2gwY4xveSp1hLtWKEFa8zc,6110
442
+ model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=ncKPHVNyq_Yy_F7HOVDHT68EDfRnaB9yVnEP3C89GJk,10627
443
443
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/__init__.py,sha256=XjNws3zoiJkeH4ixKqrLA5xBvpv5rq31qX7wYQjNpZM,1447
444
444
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py,sha256=HJ8uc3PFfyxg-WpVXPBg4mGaox8Z9bRqtQNbRfIyAk4,3745
445
445
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=Ehwpd_sL6zxmJFpJugOdN9uNxNX05nijvOCilNfHnFs,7162
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
528
528
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
529
529
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
530
530
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
531
- mct_nightly-2.3.0.20250421.604.dist-info/METADATA,sha256=fRMmNKrtVjZdLZAaNnma2VUWCi47C4GsOq2HbD4Dyoc,25413
532
- mct_nightly-2.3.0.20250421.604.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
533
- mct_nightly-2.3.0.20250421.604.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
534
- mct_nightly-2.3.0.20250421.604.dist-info/RECORD,,
531
+ mct_nightly-2.3.0.20250423.537.dist-info/METADATA,sha256=PeCeasbP-z1tGgn1g1tpCnmFWMlSQ7HtIfby4JXgX68,25413
532
+ mct_nightly-2.3.0.20250423.537.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
533
+ mct_nightly-2.3.0.20250423.537.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
534
+ mct_nightly-2.3.0.20250423.537.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.3.0.20250421.000604"
30
+ __version__ = "2.3.0.20250423.000537"
@@ -754,7 +754,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
754
754
  """
755
755
  while node.is_quantization_preserving():
756
756
  prev_nodes = self.get_prev_nodes(node)
757
- assert len(prev_nodes) == 1, "Activation preserving node should have only 1 input."
757
+ assert len(prev_nodes) == 1, f"Activation preserving node should have only 1 input, but node {node.name} has {len(prev_nodes)} inputs."
758
758
  node = prev_nodes[0]
759
759
  return node
760
760
 
@@ -51,7 +51,6 @@ class BitwidthMode(Enum):
51
51
  single-precision nodes. To compute custom single precision configuration, use QCustom.
52
52
  """
53
53
  Float = auto()
54
- Q8Bit = auto()
55
54
  QMaxBit = auto()
56
55
  QMinBit = auto()
57
56
  QCustom = auto()
@@ -573,7 +572,7 @@ class ResourceUtilizationCalculator:
573
572
  not (a_node.is_activation_quantization_enabled() or n.is_weights_quantization_enabled(kernel_attr))):
574
573
  return 0
575
574
 
576
- act_qc = act_qcs.get(a_node.name) if act_qcs else None
575
+ act_qc = self._extract_qc(a_node, act_qcs)
577
576
  a_nbits = self._get_activation_nbits(a_node, bitwidth_mode, act_qc)
578
577
  w_nbits = self._get_weight_nbits(n, kernel_attr, bitwidth_mode, w_qc)
579
578
  node_bops = a_nbits * w_nbits * node_mac
@@ -708,23 +707,20 @@ class ResourceUtilizationCalculator:
708
707
  Returns:
709
708
  Activation bit-width.
710
709
  """
710
+ n = self.graph.retrieve_preserved_quantization_node(n)
711
711
  if act_qc:
712
712
  assert bitwidth_mode == BitwidthMode.QCustom
713
713
  return act_qc.activation_n_bits if act_qc.quant_mode == ActivationQuantizationMode.QUANT else FLOAT_BITWIDTH
714
714
 
715
- if bitwidth_mode == BitwidthMode.Float or not (n.is_activation_quantization_enabled() or
716
- n.is_quantization_preserving()):
715
+ if bitwidth_mode == BitwidthMode.Float or not n.is_activation_quantization_enabled():
717
716
  return FLOAT_BITWIDTH
718
717
 
719
- if bitwidth_mode == BitwidthMode.Q8Bit:
720
- return 8
721
-
722
718
  if bitwidth_mode in self._bitwidth_mode_fn:
723
719
  candidates_nbits = [c.activation_quantization_cfg.activation_n_bits for c in n.candidates_quantization_cfg]
724
720
  return self._bitwidth_mode_fn[bitwidth_mode](candidates_nbits)
725
721
 
726
722
  if bitwidth_mode in [BitwidthMode.QCustom, BitwidthMode.QDefaultSP]:
727
- qcs = self.graph.retrieve_preserved_quantization_node(n).get_unique_activation_candidates()
723
+ qcs = n.get_unique_activation_candidates()
728
724
  if len(qcs) != 1:
729
725
  raise ValueError(f'Could not retrieve the activation quantization candidate for node {n} '
730
726
  f'as it has {len(qcs)}!=1 unique candidates.')
@@ -760,9 +756,6 @@ class ResourceUtilizationCalculator:
760
756
  if bitwidth_mode == BitwidthMode.Float or not n.is_weights_quantization_enabled(w_attr):
761
757
  return FLOAT_BITWIDTH
762
758
 
763
- if bitwidth_mode == BitwidthMode.Q8Bit:
764
- return 8
765
-
766
759
  node_qcs = n.get_unique_weights_candidates(w_attr)
767
760
  w_qcs = [qc.weights_quantization_cfg.get_attr_config(w_attr) for qc in node_qcs]
768
761
  if bitwidth_mode in cls._bitwidth_mode_fn:
@@ -16,7 +16,7 @@ from collections import defaultdict
16
16
 
17
17
  import numpy as np
18
18
  from pulp import *
19
- from typing import Dict, Tuple, Any
19
+ from typing import Dict, Tuple, Any, List
20
20
 
21
21
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget
22
22
 
@@ -67,7 +67,7 @@ def set_quantization_configuration_to_graph(graph: Graph,
67
67
  nodes_to_manipulate_activation_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph)
68
68
  nodes_to_manipulate_weights_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_weights_bit_widths(graph)
69
69
 
70
- for n in graph.nodes:
70
+ for n in graph.get_topo_sorted_nodes():
71
71
  manual_bit_width_override = {ACTIVATION: nodes_to_manipulate_activation_bit_widths.get(n),
72
72
  WEIGHTS: nodes_to_manipulate_weights_bit_widths.get(n)}
73
73
  set_quantization_configs_to_node(node=n,
@@ -199,6 +199,16 @@ def set_quantization_configs_to_node(node: BaseNode,
199
199
  if candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.QUANT and \
200
200
  not node.get_has_activation():
201
201
  candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
202
+ elif candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.PRESERVE_QUANT:
203
+ prev_nodes = graph.get_prev_nodes(node)
204
+ if len(prev_nodes) != 1:
205
+ # Preserving the quantization of more than 1 previous node is ambiguous, so disable it.
206
+ Logger.info(f"Disabling Quantization-Preserving for node {node.name} because it has more than 1 input activations.")
207
+ candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
208
+ elif not prev_nodes[0].is_quantization_preserving() or not prev_nodes[0].is_activation_quantization_enabled():
209
+ # Preserving the quantization of an unquantized node isn't possible, so disable it.
210
+ Logger.info(f"Disabling Quantization-Preserving for node {node.name} because previous node activation quantization is disabled.")
211
+ candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
202
212
 
203
213
 
204
214
  def create_node_activation_qc(qc: QuantizationConfig,
@@ -14,9 +14,9 @@
14
14
  # ==============================================================================
15
15
  import pprint
16
16
  from enum import Enum
17
- from typing import Dict, Any, Tuple, Optional
17
+ from typing import Dict, Any, Union, Tuple, Optional, Annotated
18
18
 
19
- from pydantic import BaseModel, root_validator, model_validator, ConfigDict
19
+ from pydantic import BaseModel, Field, root_validator, model_validator, ConfigDict
20
20
 
21
21
  from mct_quantizers import QuantizationMethod
22
22
  from model_compression_toolkit.constants import FLOAT_BITWIDTH
@@ -29,8 +29,7 @@ from model_compression_toolkit.target_platform_capabilities.schema.v1 import (
29
29
  TargetPlatformModelComponent,
30
30
  OperatorsSetBase,
31
31
  OperatorsSet,
32
- OperatorSetGroup,
33
- Fusing)
32
+ OperatorSetGroup)
34
33
 
35
34
 
36
35
  class OperatorSetNames(str, Enum):
@@ -98,6 +97,109 @@ class OperatorSetNames(str, Enum):
98
97
  return [v.value for v in cls]
99
98
 
100
99
 
100
+ class Fusing(TargetPlatformModelComponent):
101
+ """
102
+ Fusing defines a tuple of operators that should be combined and treated as a single operator,
103
+ hence no quantization is applied between them.
104
+
105
+ Attributes:
106
+ operator_groups (Tuple[Union[OperatorsSet, OperatorSetGroup], ...]): A tuple of operator groups,
107
+ each being either an OperatorSetGroup or an OperatorsSet.
108
+ fuse_op_quantization_config (Optional[OpQuantizationConfig]): The quantization configuration for the fused operator.
109
+ name (Optional[str]): The name for the Fusing instance. If not provided, it is generated from the operator groups' names.
110
+ """
111
+ operator_groups: Tuple[Annotated[Union[OperatorsSet, OperatorSetGroup], Field(discriminator='type')], ...]
112
+ fuse_op_quantization_config: Optional[OpQuantizationConfig] = None
113
+ name: Optional[str] = None # Will be set in the validator if not given.
114
+
115
+ model_config = ConfigDict(frozen=True)
116
+
117
+ @model_validator(mode="before")
118
+ def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
119
+ """
120
+ Validate the operator_groups and set the name by concatenating operator group names.
121
+
122
+ Args:
123
+ values (Dict[str, Any]): Input data.
124
+
125
+ Returns:
126
+ Dict[str, Any]: Modified input data with 'name' set.
127
+ """
128
+ operator_groups = values.get('operator_groups')
129
+
130
+ # When loading from JSON, lists are returned. If the value is a list, convert it to a tuple.
131
+ if isinstance(operator_groups, list):
132
+ values['operator_groups'] = tuple(operator_groups)
133
+
134
+ if values.get('name') is None:
135
+ # Generate the concatenated name from the operator groups
136
+ concatenated_name = "_".join([
137
+ op.name.value if isinstance(op.name, OperatorSetNames) else op.name
138
+ for op in values['operator_groups']
139
+ ])
140
+ values['name'] = concatenated_name
141
+
142
+ return values
143
+
144
+ @model_validator(mode="after")
145
+ def validate_after_initialization(cls, model: 'Fusing') -> Any:
146
+ """
147
+ Perform validation after the model has been instantiated.
148
+ Ensures that there are at least two operator groups.
149
+ """
150
+ if len(model.operator_groups) < 2:
151
+ Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
152
+
153
+ return model
154
+
155
+ def contains(self, other: Any) -> bool:
156
+ """
157
+ Determines if the current Fusing instance contains another Fusing instance.
158
+
159
+ Args:
160
+ other (Any): The other Fusing instance to check against.
161
+
162
+ Returns:
163
+ bool: True if the other Fusing instance is contained within this one, False otherwise.
164
+ """
165
+ if not isinstance(other, Fusing):
166
+ return False
167
+
168
+ # Check for containment by comparing operator groups
169
+ for i in range(len(self.operator_groups) - len(other.operator_groups) + 1):
170
+ for j in range(len(other.operator_groups)):
171
+ if self.operator_groups[i + j] != other.operator_groups[j] and not (
172
+ isinstance(self.operator_groups[i + j], OperatorSetGroup) and (
173
+ other.operator_groups[j] in self.operator_groups[i + j].operators_set)):
174
+ break
175
+ else:
176
+ # If all checks pass, the other Fusing instance is contained
177
+ return True
178
+ # Other Fusing instance is not contained
179
+ return False
180
+
181
+ def get_info(self) -> Union[Dict[str, str], str]:
182
+ """
183
+ Retrieves information about the Fusing instance, including its name and the sequence of operator groups.
184
+
185
+ Returns:
186
+ Union[Dict[str, str], str]: A dictionary with the Fusing instance's name as the key
187
+ and the sequence of operator groups as the value,
188
+ or just the sequence of operator groups if no name is set.
189
+ """
190
+ if self.name is not None:
191
+ return {
192
+ self.name: ' -> '.join([
193
+ x.name.value if isinstance(x.name, OperatorSetNames) else x.name
194
+ for x in self.operator_groups
195
+ ])
196
+ }
197
+ return ' -> '.join([
198
+ x.name.value if isinstance(x.name, OperatorSetNames) else x.name
199
+ for x in self.operator_groups
200
+ ])
201
+
202
+
101
203
  class TargetPlatformCapabilities(BaseModel):
102
204
  """
103
205
  Represents the hardware configuration used for quantized model inference.