mct-nightly 2.3.0.20250513.611__py3-none-any.whl → 2.3.0.20250515.544__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.3.0.20250513.611.dist-info → mct_nightly-2.3.0.20250515.544.dist-info}/METADATA +1 -1
- {mct_nightly-2.3.0.20250513.611.dist-info → mct_nightly-2.3.0.20250515.544.dist-info}/RECORD +9 -9
- {mct_nightly-2.3.0.20250513.611.dist-info → mct_nightly-2.3.0.20250515.544.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/constants.py +5 -0
- model_compression_toolkit/core/common/fusion/fusing_info.py +75 -7
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py +14 -5
- {mct_nightly-2.3.0.20250513.611.dist-info → mct_nightly-2.3.0.20250515.544.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.3.0.20250513.611.dist-info → mct_nightly-2.3.0.20250515.544.dist-info}/top_level.txt +0 -0
{mct_nightly-2.3.0.20250513.611.dist-info → mct_nightly-2.3.0.20250515.544.dist-info}/RECORD
RENAMED
@@ -1,6 +1,6 @@
|
|
1
|
-
mct_nightly-2.3.0.
|
2
|
-
model_compression_toolkit/__init__.py,sha256=
|
3
|
-
model_compression_toolkit/constants.py,sha256=
|
1
|
+
mct_nightly-2.3.0.20250515.544.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
2
|
+
model_compression_toolkit/__init__.py,sha256=ZuiC7LBUZRbxQhR-vJI5NKeCIc9cX-tIpkHCw_Ynb0o,1557
|
3
|
+
model_compression_toolkit/constants.py,sha256=KNgiNLpsMgSYyXMNEbHXd4bFNerQc1D6HH3vpbUq_Gs,4086
|
4
4
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
5
5
|
model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
|
6
6
|
model_compression_toolkit/metadata.py,sha256=x_Bk4VpzILdsFax6--CZ3X18qUTP28sbF_AhoQW8dNc,4003
|
@@ -31,7 +31,7 @@ model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.p
|
|
31
31
|
model_compression_toolkit/core/common/collectors/statistics_collector.py,sha256=psijsQZefwjMDH8SU5E18n65HiGtQilPhKr1hhzZX-I,8268
|
32
32
|
model_compression_toolkit/core/common/collectors/weighted_histogram_collector.py,sha256=zp3dE7YTqWmkD5QWdRhsl9zD8W6Lr96G1Wjw1g2D3T0,4894
|
33
33
|
model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
34
|
-
model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=
|
34
|
+
model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=S7hBbUJxL52Z8uJ9_upLdFyoSEJvgmVX0OmneqDIj-c,18656
|
35
35
|
model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=F0AaAUBpJ9JjHMB5H2LD9pdwTSWJK-Kqm9dQmGHX1Jo,7368
|
36
36
|
model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
|
37
37
|
model_compression_toolkit/core/common/graph/base_graph.py,sha256=BSQpKy0BXoGX0G0bySTo72n2isTqvtpkbRYYa8-hPO4,41435
|
@@ -446,7 +446,7 @@ model_compression_toolkit/target_platform_capabilities/targetplatform2framework/
|
|
446
446
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=NCwuvnByeexLL987h67XhU8vQvCgq63bt0hFSiSSxvE,6400
|
447
447
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attribute_filter.py,sha256=jfhszvuD2Fyy6W2KjlLzXBQKFzTqGAaDZeFVr4-ONQw,8776
|
448
448
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/current_tpc.py,sha256=_kFG0USYa6yzvLsi82_Vusv_KR8Hi7J1u680pPXECuo,2192
|
449
|
-
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=
|
449
|
+
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=1jkj0ZO3t9M0SRpe9ZcSucraSoB4raezIbpcO_lZcP4,10084
|
450
450
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities_component.py,sha256=9Hg6AMCzTdDsKKgivRd61UjxGT5SWvKsc3mIUPPsYDQ,1021
|
451
451
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/layer_filter_params.py,sha256=dIu6k1xvGKLtk_47wq1eKYvrS4lYAknAXTeJfFstW0Y,3878
|
452
452
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/operations_to_layers.py,sha256=vZ7I2XDr_YDgU8oQt8gKkcuUOJf28DCzCPunPK2h_Xw,6563
|
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
528
528
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
|
529
529
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
530
530
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
531
|
-
mct_nightly-2.3.0.
|
532
|
-
mct_nightly-2.3.0.
|
533
|
-
mct_nightly-2.3.0.
|
534
|
-
mct_nightly-2.3.0.
|
531
|
+
mct_nightly-2.3.0.20250515.544.dist-info/METADATA,sha256=dV9aRBw1JVkuZDXyGl4aFtA91lLC_NtYTDquO5yA8rY,25136
|
532
|
+
mct_nightly-2.3.0.20250515.544.dist-info/WHEEL,sha256=QZxptf4Y1BKFRCEDxD4h2V0mBFQOVFLFEpvxHmIs52A,91
|
533
|
+
mct_nightly-2.3.0.20250515.544.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
534
|
+
mct_nightly-2.3.0.20250515.544.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.3.0.
|
30
|
+
__version__ = "2.3.0.20250515.000544"
|
@@ -138,3 +138,8 @@ SHAPE = 'shape'
|
|
138
138
|
NODE_NAME = 'node_name'
|
139
139
|
TOTAL_SIZE = 'total_size'
|
140
140
|
NODE_OUTPUT_INDEX = 'node_output_index'
|
141
|
+
|
142
|
+
|
143
|
+
# Fusing Patterns constants
|
144
|
+
FUSED_LAYER_PATTERN = 'fused_layer_pattern'
|
145
|
+
FUSED_OP_QUANT_CONFIG = 'fused_op_quantization_config'
|
@@ -14,6 +14,8 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
from model_compression_toolkit.target_platform_capabilities import LayerFilterParams
|
17
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig
|
18
|
+
from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
|
17
19
|
from dataclasses import dataclass, field
|
18
20
|
|
19
21
|
from typing import Optional, List, Dict, Any, Tuple
|
@@ -41,6 +43,7 @@ class FusingInfo:
|
|
41
43
|
fusing_patterns: any = None
|
42
44
|
fusing_data: Dict[str, Tuple['BaseNode']] = field(default_factory=dict)
|
43
45
|
node_to_fused_node_map: Dict[str, str] = field(init=False, default_factory=dict)
|
46
|
+
fused_op_id_to_quant_config: Dict[str, OpQuantizationConfig] = field(default_factory=dict)
|
44
47
|
|
45
48
|
def __post_init__(self):
|
46
49
|
"""Validates and initializes mappings after dataclass instantiation."""
|
@@ -49,6 +52,7 @@ class FusingInfo:
|
|
49
52
|
assert isinstance(op_nodes, tuple) and len(op_nodes) > 1, f"Found invalid fused op nodes: {op_nodes}"
|
50
53
|
|
51
54
|
self._init_node_mapping()
|
55
|
+
self._init_quantization_config_map()
|
52
56
|
|
53
57
|
def _init_node_mapping(self) -> None:
|
54
58
|
"""
|
@@ -59,6 +63,15 @@ class FusingInfo:
|
|
59
63
|
for node in nodes:
|
60
64
|
self.node_to_fused_node_map[node.name] = op_id
|
61
65
|
|
66
|
+
def _init_quantization_config_map(self) -> None:
    """
    Rebuild the mapping between fused operation IDs and their quantization configurations.

    Previously stored entries are always discarded first. When fusing patterns
    are set, every fused operation held in ``fusing_data`` is handed to
    ``set_fused_op_quantization_config``; when ``fusing_patterns`` is None the
    mapping is left empty.
    """
    self.fused_op_id_to_quant_config.clear()
    if self.fusing_patterns is None:
        # No patterns to look configurations up in.
        return
    for fused_op_id, fused_nodes in self.fusing_data.items():
        self.set_fused_op_quantization_config(fused_op_id, fused_nodes)
|
74
|
+
|
62
75
|
def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
|
63
76
|
"""
|
64
77
|
Add a new fused operation with the given ID and set of nodes.
|
@@ -78,6 +91,22 @@ class FusingInfo:
|
|
78
91
|
for node in nodes:
|
79
92
|
self.node_to_fused_node_map[node.name] = op_id
|
80
93
|
|
94
|
+
# Update the quantization config mapping for this operation
|
95
|
+
if self.fusing_patterns is not None:
|
96
|
+
self.set_fused_op_quantization_config(op_id, nodes)
|
97
|
+
|
98
|
+
def set_fused_op_quantization_config(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
    """
    Set the quantization configuration for a given fused operation ID.

    The fusing patterns are scanned in order; the first pattern whose layer
    sequence is accepted by ``is_valid_fusion`` for ``nodes`` supplies the
    configuration stored for ``op_id`` (the stored value may be None if the
    pattern carries no configuration). If no pattern matches, or no fusing
    patterns are set, the mapping is left untouched.

    Args:
        op_id (str): The identifier for the fused operation.
        nodes (Tuple[BaseNode]): The tuple of nodes that form the fused operation.
    """
    # `fusing_patterns` defaults to None; treat that as "nothing to match"
    # instead of failing with a TypeError while iterating it.
    if self.fusing_patterns is None:
        return

    for fusing_pattern in self.fusing_patterns:
        # is_valid_fusion expects a list of layer patterns, hence the one-element list.
        if is_valid_fusion([fusing_pattern.get(FUSED_LAYER_PATTERN)], nodes):
            self.fused_op_id_to_quant_config[op_id] = fusing_pattern.get(FUSED_OP_QUANT_CONFIG)
            return
|
109
|
+
|
81
110
|
def remove_fused_operation(self, op_id: str) -> None:
|
82
111
|
"""
|
83
112
|
Remove a fused operation by its ID.
|
@@ -95,6 +124,7 @@ class FusingInfo:
|
|
95
124
|
for node in nodes:
|
96
125
|
self.node_to_fused_node_map.pop(node.name, None)
|
97
126
|
del self.fusing_data[op_id]
|
127
|
+
self.fused_op_id_to_quant_config.pop(op_id, None)
|
98
128
|
|
99
129
|
def get_fused_node_name(self, node_name: str) -> Optional[str]:
|
100
130
|
"""
|
@@ -117,6 +147,15 @@ class FusingInfo:
|
|
117
147
|
"""
|
118
148
|
return self.node_to_fused_node_map.copy()
|
119
149
|
|
150
|
+
def get_fusing_quantization_config_map(self) -> Dict[str, OpQuantizationConfig]:
    """
    Return a copy of the fused-operation-ID to quantization-configuration mapping.

    The copy is shallow: adding or removing keys on the result does not affect
    the mapping held by this object.

    Returns:
        A dictionary mapping each fused operation ID to its quantization configuration.
    """
    config_map_copy = self.fused_op_id_to_quant_config.copy()
    return config_map_copy
|
158
|
+
|
120
159
|
def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]:
|
121
160
|
"""
|
122
161
|
Retrieve the list of nodes for a given fused operation ID.
|
@@ -129,6 +168,18 @@ class FusingInfo:
|
|
129
168
|
"""
|
130
169
|
return self.fusing_data.get(op_id)
|
131
170
|
|
171
|
+
def get_fused_op_quantization_config(self, op_id: str) -> Optional[OpQuantizationConfig]:
    """
    Retrieve the quantization configuration for a given fused operation ID.

    Args:
        op_id (str): The identifier for the fused operation.

    Returns:
        Optional[OpQuantizationConfig]: The quantization configuration stored for the
        operation, or None if no entry exists for ``op_id``.
    """
    # dict.get returns None for unknown IDs, so the return type is Optional.
    return self.fused_op_id_to_quant_config.get(op_id)
|
182
|
+
|
132
183
|
def is_node_in_fused_op(self, node: 'BaseNode') -> bool:
|
133
184
|
"""
|
134
185
|
Check if a node is part of any fused operation.
|
@@ -216,10 +267,11 @@ class FusingInfo:
|
|
216
267
|
all_fused_nodes.update(node_set)
|
217
268
|
|
218
269
|
# Check 4: Ensure the sequence matches a valid fusing pattern
|
219
|
-
|
270
|
+
valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns)
|
271
|
+
if not is_valid_fusion(valid_fusing_patterns, nodes):
|
220
272
|
raise ValueError(
|
221
273
|
f"Fused operation {op_id} does not match any valid fusing pattern "
|
222
|
-
f"from {
|
274
|
+
f"from {valid_fusing_patterns}."
|
223
275
|
)
|
224
276
|
|
225
277
|
def is_nodes_eligible_to_be_fused(self, nodes: List['BaseNode']) -> bool:
|
@@ -240,7 +292,8 @@ class FusingInfo:
|
|
240
292
|
return False
|
241
293
|
|
242
294
|
# Check if the provided nodes match a valid fusion pattern
|
243
|
-
|
295
|
+
valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns)
|
296
|
+
return is_valid_fusion(fusing_patterns=valid_fusing_patterns, nodes=nodes)
|
244
297
|
|
245
298
|
def __repr__(self) -> str:
|
246
299
|
"""
|
@@ -287,8 +340,11 @@ class FusingInfoGenerator:
|
|
287
340
|
if not self._fusing_patterns:
|
288
341
|
return FusingInfo(fusing_patterns=self._fusing_patterns)
|
289
342
|
|
343
|
+
# Extract fusing layer patterns
|
344
|
+
fusing_layer_patterns = _get_fusing_layer_patterns(self._fusing_patterns)
|
345
|
+
|
290
346
|
# Find max fusion
|
291
|
-
|
347
|
+
max_layer_patterns = max([len(fusing_layer_pattern) for fusing_layer_pattern in fusing_layer_patterns])
|
292
348
|
|
293
349
|
# Travel along the graph to find layers for fusing
|
294
350
|
nodes = graph.get_topo_sorted_nodes()
|
@@ -302,9 +358,9 @@ class FusingInfoGenerator:
|
|
302
358
|
continue
|
303
359
|
# Start fusing search
|
304
360
|
fusing_nodes = [] # nodes that are candidates for participating in fusing
|
305
|
-
patterns = copy.deepcopy(
|
361
|
+
patterns = copy.deepcopy(fusing_layer_patterns)
|
306
362
|
next_nodes = [node]
|
307
|
-
for i in range(
|
363
|
+
for i in range(max_layer_patterns):
|
308
364
|
patterns = get_valid_fusing_patterns_for_node(patterns, next_nodes[0], i)
|
309
365
|
if len(patterns) == 0: # Give up if no more fusion pattern
|
310
366
|
break
|
@@ -314,7 +370,7 @@ class FusingInfoGenerator:
|
|
314
370
|
break
|
315
371
|
|
316
372
|
# New fusion
|
317
|
-
if is_valid_fusion(
|
373
|
+
if is_valid_fusion(fusing_layer_patterns, fusing_nodes):
|
318
374
|
fused_op_id = FusingInfo.generate_fused_op_id(fusing_nodes)
|
319
375
|
assert fused_op_id not in fusing_info, f"{fused_op_id} is already in fusing info: {fusing_info}"
|
320
376
|
fusing_info[fused_op_id] = tuple(fusing_nodes)
|
@@ -371,3 +427,15 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode']) -
|
|
371
427
|
if counter == fusion_depth:
|
372
428
|
return True
|
373
429
|
return False
|
430
|
+
|
431
|
+
|
432
|
+
def _get_fusing_layer_patterns(fusing_patterns: Optional[List[Dict[str, Any]]]) -> List[List[Any]]:
    """
    Extracts the fusing layer patterns from the provided fusing patterns.

    Args:
        fusing_patterns: List of dictionaries, each holding a pattern of
            layers/LayerFilterParams to fuse under the FUSED_LAYER_PATTERN key
            alongside its mapped quantization config. May be None, which is
            treated as "no patterns".

    Returns:
        The supported fusing layer patterns, in the same order as the input
        (an empty list when `fusing_patterns` is None or empty).
    """
    # Callers pass FusingInfo.fusing_patterns straight through, and that field
    # defaults to None; return an empty list rather than fail on iteration.
    if fusing_patterns is None:
        return []
    return [pattern.get(FUSED_LAYER_PATTERN) for pattern in fusing_patterns]
|
@@ -31,6 +31,9 @@ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_s
|
|
31
31
|
OpQuantizationConfig, QuantizationConfigOptions
|
32
32
|
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import _current_tpc
|
33
33
|
|
34
|
+
from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
|
35
|
+
|
36
|
+
|
34
37
|
class FrameworkQuantizationCapabilities(ImmutableClass):
|
35
38
|
"""
|
36
39
|
Attach framework information to a modeled hardware.
|
@@ -94,20 +97,26 @@ class FrameworkQuantizationCapabilities(ImmutableClass):
|
|
94
97
|
"""
|
95
98
|
return self.op_sets_to_layers.get_layers_by_op(op)
|
96
99
|
|
97
|
-
def get_fusing_patterns(self) -> List[List[Any]]:
|
100
|
+
def get_fusing_patterns(self) -> List[Dict[str, Any]]:
    """
    Build the fusing patterns of the attached TPC in terms of framework layers.

    Returns: List of patterns of layers/LayerFilterParams to fuse and their mapping quantization config.
        Each entry is a dictionary holding one layer sequence under FUSED_LAYER_PATTERN and the
        pattern's quantization config under FUSED_OP_QUANT_CONFIG (None when the pattern does not
        define one). An empty list is returned when the TPC has no fusing patterns.

    """
    patterns = []
    if self.tpc.fusing_patterns is None:
        return patterns

    for p in self.tpc.fusing_patterns:
        # One collection of candidate layers per operator group; every combination
        # across the groups (Cartesian product) is a concrete layer sequence.
        ops = [self.get_layers_by_opset(x) for x in p.operator_groups]
        fused_op_quant_config = getattr(p, FUSED_OP_QUANT_CONFIG, None)
        patterns.extend(
            {FUSED_LAYER_PATTERN: list(layers), FUSED_OP_QUANT_CONFIG: fused_op_quant_config}
            for layers in itertools.product(*ops)
        )

    return patterns
|
111
120
|
|
112
121
|
def get_info(self) -> Dict[str, Any]:
|
113
122
|
"""
|
File without changes
|
{mct_nightly-2.3.0.20250513.611.dist-info → mct_nightly-2.3.0.20250515.544.dist-info}/top_level.txt
RENAMED
File without changes
|