mct-nightly 2.3.0.20250331.610__py3-none-any.whl → 2.3.0.20250402.536__py3-none-any.whl
This diff reflects the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and shows the changes between the two versions as they appear in their public registry.
- {mct_nightly-2.3.0.20250331.610.dist-info → mct_nightly-2.3.0.20250402.536.dist-info}/METADATA +1 -1
- {mct_nightly-2.3.0.20250331.610.dist-info → mct_nightly-2.3.0.20250402.536.dist-info}/RECORD +14 -13
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/base_graph.py +17 -0
- model_compression_toolkit/core/common/graph/base_node.py +13 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +35 -12
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/v2.py +177 -0
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py +1 -0
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +2 -1
- {mct_nightly-2.3.0.20250331.610.dist-info → mct_nightly-2.3.0.20250402.536.dist-info}/WHEEL +0 -0
- {mct_nightly-2.3.0.20250331.610.dist-info → mct_nightly-2.3.0.20250402.536.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.3.0.20250331.610.dist-info → mct_nightly-2.3.0.20250402.536.dist-info}/top_level.txt +0 -0
{mct_nightly-2.3.0.20250331.610.dist-info → mct_nightly-2.3.0.20250402.536.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mct-nightly
-Version: 2.3.0.20250331.610
+Version: 2.3.0.20250402.536
 Summary: A Model Compression Toolkit for neural networks
 Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: Apache Software License
{mct_nightly-2.3.0.20250331.610.dist-info → mct_nightly-2.3.0.20250402.536.dist-info}/RECORD
RENAMED
@@ -1,5 +1,5 @@
-mct_nightly-2.3.0.
-model_compression_toolkit/__init__.py,sha256=
+mct_nightly-2.3.0.20250402.536.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+model_compression_toolkit/__init__.py,sha256=dhPx1u7eKO_zAY9CefOYP31YislX9FLOXxMFWv9PVJo,1557
 model_compression_toolkit/constants.py,sha256=2ltuH-gdaLZoZV4CPUgKjC3S9ojz2z4OTVdenyVEypU,3912
 model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
 model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -34,8 +34,8 @@ model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5
 model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=b41_4rL_Adiza4vpWlmmqgvkpUmWVdfdx0nEIB0p2n8,6195
 model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=-2fnjyC9q2RPw9st6RxROW-gdtT2mSRz0QZ_Gz1KDz4,5579
 model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
-model_compression_toolkit/core/common/graph/base_graph.py,sha256=
-model_compression_toolkit/core/common/graph/base_node.py,sha256=
+model_compression_toolkit/core/common/graph/base_graph.py,sha256=cSwHUqwZEiR1t2DaBfc7_qSJbtX8crpqerN4ol9v3H8,38859
+model_compression_toolkit/core/common/graph/base_node.py,sha256=CJu8_r80MGVnYmlAUGOGKGRsD9xShMyaRNb3VMeRC0s,34523
 model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
 model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
 model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
@@ -75,10 +75,10 @@ model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,s
 model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=fk7PWiZ6Na5O_Z_dymk_UfDCTqW_X_4EROU7DZknQnc,9444
 model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
 model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
-model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=
+model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=MP4Q5lThvEIhfa1iBajQQM3nCUNgK-2yseqQQ8Rgiog,40624
 model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=QQwtl08DiDxUOQGpYPnek_RlZjWm1Ky7tL2ESHXMK78,4050
 model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
-model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=
+model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=TaK5NqVdmygsHw9_x5JsJ-BPvlbKA9cRyTno1R8gbnU,7269
 model_compression_toolkit/core/common/network_editors/__init__.py,sha256=vZmu55bYqiaOQs3AjfwWDXHmuKZcLHt-wm7uR5fPEqg,1307
 model_compression_toolkit/core/common/network_editors/actions.py,sha256=nid0_j-Cn10xvmztT8yCKW_6uA7JEnom9SW9syx7wc0,19594
 model_compression_toolkit/core/common/network_editors/edit_network.py,sha256=dfgawi-nB0ocAJ0xcGn9E-Zv203oUnQLuMiXpX8vTgA,1748
@@ -435,13 +435,14 @@ model_compression_toolkit/target_platform_capabilities/constants.py,sha256=BFSgD
 model_compression_toolkit/target_platform_capabilities/immutable.py,sha256=YhROBiXEIB3TU-bAFrnL3qbAsb1yuWPBAQ_CLOJbYUU,1827
 model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=4ydTWWKv_PEOAFok2JtxFNj8rav-0IlqcXKF6lnhHNE,4157
 model_compression_toolkit/target_platform_capabilities/schema/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
-model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=
+model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=hf539WJ3nBGn0RnALXrKmAPnbhJ-VmWmLIa207x8b4M,541
 model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py,sha256=vBkXxVJagm9JKB9cdm4Pvi7u_luriXUjvNn0-m8Zr0k,4653
 model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=4CGpWENuOyjwaIMaGrFI0Act7jsSeT7m94pjrv91dxE,27516
+model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=vUhCocA0EcjdR741Yv48W4Kr5Pq22Miebhm7F9GKb3Y,6086
 model_compression_toolkit/target_platform_capabilities/targetplatform2framework/__init__.py,sha256=XjNws3zoiJkeH4ixKqrLA5xBvpv5rq31qX7wYQjNpZM,1447
 model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py,sha256=HJ8uc3PFfyxg-WpVXPBg4mGaox8Z9bRqtQNbRfIyAk4,3745
-model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=
-model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256
+model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=AE09QLE_QKwNqUTZbkZP9XLJStG1ECiTWmEGuXZTEsQ,7652
+model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=-zbPmzQJal-1vZiQ6vIBBBnlEOB2DTb09koA0Aj4I_I,6396
 model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attribute_filter.py,sha256=jfhszvuD2Fyy6W2KjlLzXBQKFzTqGAaDZeFVr4-ONQw,8776
 model_compression_toolkit/target_platform_capabilities/targetplatform2framework/current_tpc.py,sha256=_kFG0USYa6yzvLsi82_Vusv_KR8Hi7J1u680pPXECuo,2192
 model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=UKzckLYLdBcFAptyKnVMwpPpfRkmF0SK1Kl0g0eGjQA,9710
@@ -526,7 +527,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
 model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
 model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
 model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
-mct_nightly-2.3.0.
-mct_nightly-2.3.0.
-mct_nightly-2.3.0.
-mct_nightly-2.3.0.
+mct_nightly-2.3.0.20250402.536.dist-info/METADATA,sha256=v7bHr6SXyb9qkbAmhb3xAHDXzinP-56xYnw-3SMtHVc,27098
+mct_nightly-2.3.0.20250402.536.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+mct_nightly-2.3.0.20250402.536.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+mct_nightly-2.3.0.20250402.536.dist-info/RECORD,,
model_compression_toolkit/__init__.py
CHANGED
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
 from model_compression_toolkit import pruning
 from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
 
-__version__ = "2.3.0.
+__version__ = "2.3.0.20250402.000536"
model_compression_toolkit/core/common/graph/base_graph.py
CHANGED
@@ -696,6 +696,23 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         sorted_conf_activation = self.get_sorted_activation_configurable_nodes()
         return [(n, n.final_activation_quantization_cfg.activation_n_bits) for n in sorted_conf_activation]
 
+    def retrieve_preserved_quantization_node(self, node: BaseNode) -> BaseNode:
+        """
+        For a node with quantization_preserving == True, get the previous non-quantization_preserving node
+        to get activation quantization config from. If quantization_preserving is False return node.
+        Args:
+            node: quantization preserving node.
+
+        Returns:
+            The node that the quantization preserving node should get the activation quantization from.
+
+        """
+        while node.is_quantization_preserving():
+            prev_nodes = self.get_prev_nodes(node)
+            assert len(prev_nodes) == 1, "Activation preserving node should have only 1 input."
+            node = prev_nodes[0]
+        return node
+
     def update_fused_nodes(self, fusion: List[Any]):
         """
         Updates the graphs fusions list with a new list of nodes that have been fused.
model_compression_toolkit/core/common/graph/base_node.py
CHANGED
@@ -131,6 +131,19 @@ class BaseNode:
                    qc.activation_quantization_cfg.enable_activation_quantization
         return self.candidates_quantization_cfg[0].activation_quantization_cfg.enable_activation_quantization
 
+    def is_quantization_preserving(self) -> bool:
+        """
+        Returns: Whether node activation quantization information is preserved from its inputs.
+        """
+        if self.final_activation_quantization_cfg:
+            # if we have a final configuration, then we only care to check if it enables activation quantization.
+            return self.final_activation_quantization_cfg.quantization_preserving
+
+        for qc in self.candidates_quantization_cfg:
+            assert self.candidates_quantization_cfg[0].activation_quantization_cfg.quantization_preserving == \
+                   qc.activation_quantization_cfg.quantization_preserving
+        return self.candidates_quantization_cfg[0].activation_quantization_cfg.quantization_preserving
+
     def is_weights_quantization_enabled(self, attr_name: str) -> bool:
         """
         Checks whether a node's weights attribute quantization is enabled.
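The two hunks above introduce quantization-preserving nodes: ops such as reshape or dropout that carry their input's activation quantization rather than defining their own, so bit-width lookups must walk upstream to the node whose quantization they preserve. A minimal, self-contained sketch of that walk, using hypothetical stand-in classes rather than the real MCT Graph/BaseNode API:

# Illustrative stand-ins only; the real traversal lives in Graph.retrieve_preserved_quantization_node.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    name: str
    quantization_preserving: bool = False      # e.g. Reshape/Dropout preserve their input's quantization
    activation_n_bits: Optional[int] = None    # only meaningful for nodes that quantize their activation
    inputs: List["Node"] = field(default_factory=list)

    def is_quantization_preserving(self) -> bool:
        return self.quantization_preserving


def retrieve_preserved_quantization_node(node: Node) -> Node:
    # Walk upstream until a node that actually defines activation quantization is found.
    while node.is_quantization_preserving():
        assert len(node.inputs) == 1, "Activation preserving node should have only 1 input."
        node = node.inputs[0]
    return node


conv = Node("conv", activation_n_bits=8)
reshape = Node("reshape", quantization_preserving=True, inputs=[conv])
dropout = Node("dropout", quantization_preserving=True, inputs=[reshape])

source = retrieve_preserved_quantization_node(dropout)
print(source.name, source.activation_n_bits)   # conv 8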
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py
CHANGED
@@ -335,13 +335,35 @@ class ResourceUtilizationCalculator:
         """
         return self.compute_activation_utilization_by_cut(target_criterion, bitwidth_mode, act_qcs)
 
+    def _extract_qc(self, n: BaseNode, act_qcs: Optional[ActivationQCfgPerNode] = None
+                    ) -> Union[NodeActivationQuantizationConfig, None]:
+        """
+        Extract quantization config the activation configs dictionary is provided. If node is quantization
+        preserving, extract the quantization config from the preceding activation quantized node (i.e.
+        the Quantization the original node preserves).
+
+        Args:
+            n: Node to extract qc for.
+            act_qcs: custom activations quantization configuration. If not provided, the default
+              configuration will be extracted from the node.
+
+        Returns:
+            The relevant quantization config.
+        """
+        if act_qcs:
+            assert not (n.is_quantization_preserving() and act_qcs.get(n.name) is not None), \
+                f"Quantization preserving node {n.name} should not have a qc for this computation."
+            return act_qcs.get(self.graph.retrieve_preserved_quantization_node(n).name)
+        return None
+
     def compute_activation_utilization_by_cut(self,
                                               target_criterion: TargetInclusionCriterion,
                                               bitwidth_mode: BitwidthMode,
                                               act_qcs: Optional[ActivationQCfgPerNode] = None) \
             -> Tuple[float, Dict[Cut, Utilization], Dict[Cut, Dict[BaseNode, Utilization]]]:
         """
-        Compute graph activation cuts utilization.
+        Compute graph activation cuts utilization. If activation quantization configs are provided, then for
+        quantization preserving nodes, get the previous quantized activation node bit-width.
 
         Args:
             target_criterion: criterion to include weights for computation.
@@ -369,7 +391,7 @@ class ResourceUtilizationCalculator:
             if not cut_target_nodes:
                 continue
             for n in cut_target_nodes:
-                qc =
+                qc = self._extract_qc(n, act_qcs)
                 util_per_cut_per_node[cut][n.name] = self.compute_node_activation_tensor_utilization(n, target_criterion,
                                                                                                      bitwidth_mode, qc)
             util_per_cut[cut] = sum(util_per_cut_per_node[cut].values())  # type: ignore
@@ -384,7 +406,8 @@ class ResourceUtilizationCalculator:
                                             include_reused=False) \
             -> Tuple[float, Dict[NodeName, Utilization]]:
         """
-        Compute resource utilization for graph's activations tensors.
+        Compute resource utilization for graph's activations tensors. If activation quantization configs are provided, then for
+        quantization preserving nodes, get the previous quantized activation node bit-width.
 
         Args:
             target_criterion: criterion to include weights for computation.
@@ -405,7 +428,7 @@ class ResourceUtilizationCalculator:
 
         util_per_node: Dict[NodeName, Utilization] = {}
         for n in self._topo_sort(nodes):
-            qc =
+            qc = self._extract_qc(n, act_qcs)
             util = self.compute_node_activation_tensor_utilization(n, None, bitwidth_mode, qc)
             util_per_node[n.name] = util
 
@@ -659,7 +682,7 @@ class ResourceUtilizationCalculator:
         if target_criterion == TargetInclusionCriterion.QConfigurable:
             nodes = [n for n in nodes if n.has_configurable_activation()]
         elif target_criterion == TargetInclusionCriterion.AnyQuantized:
-            nodes = [n for n in nodes if n.is_activation_quantization_enabled()]
+            nodes = [n for n in nodes if n.is_activation_quantization_enabled() or n.is_quantization_preserving()]
        elif target_criterion == TargetInclusionCriterion.QNonConfigurable:
            nodes = [n for n in nodes if n.is_activation_quantization_enabled() and not n.has_configurable_activation()]
        elif target_criterion != TargetInclusionCriterion.Any:  # pragma: no cover
@@ -668,8 +691,7 @@ class ResourceUtilizationCalculator:
         nodes = [n for n in nodes if not n.reuse]
         return nodes
 
-
-    def _get_activation_nbits(cls,
+    def _get_activation_nbits(self,
                               n: BaseNode,
                               bitwidth_mode: BitwidthMode,
                               act_qc: Optional[NodeActivationQuantizationConfig]) -> int:
@@ -690,21 +712,22 @@ class ResourceUtilizationCalculator:
             assert bitwidth_mode == BitwidthMode.QCustom
             return act_qc.activation_n_bits if act_qc.enable_activation_quantization else FLOAT_BITWIDTH
 
-        if bitwidth_mode == BitwidthMode.Float or not n.is_activation_quantization_enabled():
+        if bitwidth_mode == BitwidthMode.Float or not (n.is_activation_quantization_enabled() or
+                                                       n.is_quantization_preserving()):
             return FLOAT_BITWIDTH
 
         if bitwidth_mode == BitwidthMode.Q8Bit:
             return 8
 
-        if bitwidth_mode in
+        if bitwidth_mode in self._bitwidth_mode_fn:
             candidates_nbits = [c.activation_quantization_cfg.activation_n_bits for c in n.candidates_quantization_cfg]
-            return
+            return self._bitwidth_mode_fn[bitwidth_mode](candidates_nbits)
 
         if bitwidth_mode in [BitwidthMode.QCustom, BitwidthMode.QDefaultSP]:
-            qcs = n.get_unique_activation_candidates()
+            qcs = self.graph.retrieve_preserved_quantization_node(n).get_unique_activation_candidates()
             if len(qcs) != 1:
                 raise ValueError(f'Could not retrieve the activation quantization candidate for node {n} '
-                                 f'as it has {len(qcs)}!=1 unique candidates
+                                 f'as it has {len(qcs)}!=1 unique candidates.')
             return qcs[0].activation_quantization_cfg.activation_n_bits
 
         raise ValueError(f'Unknown mode {bitwidth_mode}')  # pragma: no cover
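The calculator changes above route quantization-preserving nodes through the node whose quantization they preserve whenever a custom activation config is supplied: a preserving node must not appear in the custom mapping itself, and its bit-width is looked up under the upstream node's name. A rough, standalone sketch of that resolution rule (hypothetical names, with a plain name-to-bit-width dict standing in for the real config objects):

# Standalone illustration of the qc-resolution rule added in _extract_qc; not the MCT API.
from typing import Dict, Optional


def extract_qc(node_name: str,
               is_preserving: bool,
               preserved_source_name: str,
               act_qcs: Optional[Dict[str, int]] = None) -> Optional[int]:
    """Return the custom activation bit-width to use for a node, if any."""
    if act_qcs:
        # A preserving node should never carry its own custom config...
        assert not (is_preserving and act_qcs.get(node_name) is not None), \
            f"Quantization preserving node {node_name} should not have a qc for this computation."
        # ...its bit-width comes from the node whose quantization it preserves.
        lookup_name = preserved_source_name if is_preserving else node_name
        return act_qcs.get(lookup_name)
    return None


custom_cfg = {"conv": 4}                                  # mixed-precision choice: conv activation at 4 bits
print(extract_qc("reshape", True, "conv", custom_cfg))    # 4 - reshape inherits conv's bit-width
print(extract_qc("conv", False, "conv", custom_cfg))      # 4
print(extract_qc("relu", False, "relu", custom_cfg))      # None - falls back to the node's own candidates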
model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py
CHANGED
@@ -14,7 +14,7 @@
 # ==============================================================================
 import numpy as np
 from pulp import *
-from typing import Dict, Tuple
+from typing import Dict, Tuple, List
 
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget
 
model_compression_toolkit/target_platform_capabilities/schema/v2.py
ADDED
@@ -0,0 +1,177 @@
+# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import pprint
+from enum import Enum
+from typing import Dict, Any, Tuple, Optional
+
+from pydantic import BaseModel, root_validator
+
+from mct_quantizers import QuantizationMethod
+from model_compression_toolkit.constants import FLOAT_BITWIDTH
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.schema.v1 import (
+    Signedness,
+    AttributeQuantizationConfig,
+    OpQuantizationConfig,
+    QuantizationConfigOptions,
+    TargetPlatformModelComponent,
+    OperatorsSetBase,
+    OperatorsSet,
+    OperatorSetGroup,
+    Fusing)
+
+
+class OperatorSetNames(str, Enum):
+    CONV = "Conv"
+    DEPTHWISE_CONV = "DepthwiseConv2D"
+    CONV_TRANSPOSE = "ConvTranspose"
+    FULLY_CONNECTED = "FullyConnected"
+    CONCATENATE = "Concatenate"
+    STACK = "Stack"
+    UNSTACK = "Unstack"
+    GATHER = "Gather"
+    EXPAND = "Expend"
+    BATCH_NORM = "BatchNorm"
+    L2NORM = "L2Norm"
+    RELU = "ReLU"
+    RELU6 = "ReLU6"
+    LEAKY_RELU = "LeakyReLU"
+    ELU = "Elu"
+    HARD_TANH = "HardTanh"
+    ADD = "Add"
+    SUB = "Sub"
+    MUL = "Mul"
+    DIV = "Div"
+    MIN = "Min"
+    MAX = "Max"
+    PRELU = "PReLU"
+    ADD_BIAS = "AddBias"
+    SWISH = "Swish"
+    SIGMOID = "Sigmoid"
+    SOFTMAX = "Softmax"
+    LOG_SOFTMAX = "LogSoftmax"
+    TANH = "Tanh"
+    GELU = "Gelu"
+    HARDSIGMOID = "HardSigmoid"
+    HARDSWISH = "HardSwish"
+    FLATTEN = "Flatten"
+    GET_ITEM = "GetItem"
+    RESHAPE = "Reshape"
+    UNSQUEEZE = "Unsqueeze"
+    SQUEEZE = "Squeeze"
+    PERMUTE = "Permute"
+    TRANSPOSE = "Transpose"
+    DROPOUT = "Dropout"
+    SPLIT_CHUNK = "SplitChunk"
+    MAXPOOL = "MaxPool"
+    AVGPOOL = "AvgPool"
+    SIZE = "Size"
+    SHAPE = "Shape"
+    EQUAL = "Equal"
+    ARGMAX = "ArgMax"
+    TOPK = "TopK"
+    FAKE_QUANT = "FakeQuant"
+    COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression"
+    BOX_DECODE = "BoxDecode"
+    ZERO_PADDING2D = "ZeroPadding2D"
+    CAST = "Cast"
+    RESIZE = "Resize"
+    PAD = "Pad"
+    FOLD = "Fold"
+    STRIDED_SLICE = "StridedSlice"
+    SSD_POST_PROCESS = "SSDPostProcess"
+
+    @classmethod
+    def get_values(cls):
+        return [v.value for v in cls]
+
+
+class TargetPlatformCapabilities(BaseModel):
+    """
+    Represents the hardware configuration used for quantized model inference.
+
+    Attributes:
+        default_qco (QuantizationConfigOptions): Default quantization configuration options for the model.
+        operator_set (Optional[Tuple[OperatorsSet, ...]]): Tuple of operator sets within the model.
+        fusing_patterns (Optional[Tuple[Fusing, ...]]): Tuple of fusing patterns for the model.
+        tpc_minor_version (Optional[int]): Minor version of the Target Platform Configuration.
+        tpc_patch_version (Optional[int]): Patch version of the Target Platform Configuration.
+        tpc_platform_type (Optional[str]): Type of the platform for the Target Platform Configuration.
+        add_metadata (bool): Flag to determine if metadata should be added.
+        name (str): Name of the Target Platform Model.
+        is_simd_padding (bool): Indicates if SIMD padding is applied.
+        SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
+    """
+    default_qco: QuantizationConfigOptions
+    operator_set: Optional[Tuple[OperatorsSet, ...]]
+    fusing_patterns: Optional[Tuple[Fusing, ...]]
+    tpc_minor_version: Optional[int]
+    tpc_patch_version: Optional[int]
+    tpc_platform_type: Optional[str]
+    add_metadata: bool = True
+    name: Optional[str] = "default_tpc"
+    is_simd_padding: bool = False
+
+    SCHEMA_VERSION: int = 2
+
+    class Config:
+        frozen = True
+
+    @root_validator(allow_reuse=True)
+    def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Perform validation after the model has been instantiated.
+
+        Args:
+            values (Dict[str, Any]): The instantiated target platform model.
+
+        Returns:
+            Dict[str, Any]: The validated values.
+        """
+        # Validate `default_qco`
+        default_qco = values.get('default_qco')
+        if len(default_qco.quantization_configurations) != 1:
+            Logger.critical("Default QuantizationConfigOptions must contain exactly one option.")  # pragma: no cover
+
+        # Validate `operator_set` uniqueness
+        operator_set = values.get('operator_set')
+        if operator_set is not None:
+            opsets_names = [
+                op.name.value if isinstance(op.name, OperatorSetNames) else op.name
+                for op in operator_set
+            ]
+            if len(set(opsets_names)) != len(opsets_names):
+                Logger.critical("Operator Sets must have unique names.")  # pragma: no cover
+
+        return values
+
+    def get_info(self) -> Dict[str, Any]:
+        """
+        Get a dictionary summarizing the TargetPlatformCapabilities properties.
+
+        Returns:
+            Dict[str, Any]: Summary of the TargetPlatformCapabilities properties.
+        """
+        return {
+            "Model name": self.name,
+            "Operators sets": [o.get_info() for o in self.operator_set] if self.operator_set else [],
+            "Fusing patterns": [f.get_info() for f in self.fusing_patterns] if self.fusing_patterns else [],
+        }
+
+    def show(self):
+        """
+        Display the TargetPlatformCapabilities.
+        """
+        pprint.pprint(self.get_info(), sort_dicts=False)
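OperatorSetNames above subclasses both str and Enum, so each member is a real string: it compares equal to plain string values, serializes cleanly, and get_values() lists every name. A small illustrative re-creation of the pattern (not an import of the real class):

# Sketch of the (str, Enum) pattern used by OperatorSetNames in schema v2.
from enum import Enum


class OpNames(str, Enum):
    CONV = "Conv"
    RESHAPE = "Reshape"
    BOX_DECODE = "BoxDecode"

    @classmethod
    def get_values(cls):
        return [v.value for v in cls]


print(OpNames.CONV == "Conv")                 # True - a str subclass compares by value
print(OpNames.get_values())                   # ['Conv', 'Reshape', 'BoxDecode']
print("BoxDecode" in OpNames.get_values())    # True - handy for validating user-supplied names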
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py
CHANGED
@@ -93,6 +93,7 @@ class AttachTpcToKeras(AttachTpcToFramework):
             OperatorSetNames.TOPK: [tf.nn.top_k],
             OperatorSetNames.FAKE_QUANT: [tf.quantization.fake_quant_with_min_max_vars],
             OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [tf.image.combined_non_max_suppression],
+            OperatorSetNames.BOX_DECODE: [],  # no such operator in keras
             OperatorSetNames.ZERO_PADDING2D: [ZeroPadding2D],
             OperatorSetNames.CAST: [tf.cast],
             OperatorSetNames.STRIDED_SLICE: [tf.strided_slice],
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py
CHANGED
@@ -97,7 +97,8 @@ class AttachTpcToPytorch(AttachTpcToFramework):
             OperatorSetNames.L2NORM: [LayerFilterParams(torch.nn.functional.normalize,
                                                         Eq('p', 2) | Eq('p', None))],
             OperatorSetNames.SSD_POST_PROCESS: [],  # no such operator in pytorch
-            OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: []  # no such operator in pytorch
+            OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [],  # no such operator in pytorch
+            OperatorSetNames.BOX_DECODE: []  # no such operator in pytorch
         }
 
         pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
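Both attach2* changes add an OperatorSetNames.BOX_DECODE entry mapped to an empty list, the convention these mappings use for operator sets that have no counterpart in the framework. A hedged sketch of how such a mapping is typically consumed (hypothetical dict and filter, not the MCT attach API):

# Illustrative only: operator-set names map to framework layer identifiers;
# an empty list marks an operator set without a matching layer in this framework.
from typing import Dict, List

opset_to_layers: Dict[str, List[str]] = {
    "Conv": ["Conv2D", "DepthwiseConv2D"],
    "TopK": ["tf.nn.top_k"],
    "BoxDecode": [],          # no such operator in this framework
}

supported = [name for name, layers in opset_to_layers.items() if layers]
unsupported = [name for name, layers in opset_to_layers.items() if not layers]
print(supported)     # ['Conv', 'TopK']
print(unsupported)   # ['BoxDecode']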
{mct_nightly-2.3.0.20250331.610.dist-info → mct_nightly-2.3.0.20250402.536.dist-info}/WHEEL
RENAMED
File without changes
{mct_nightly-2.3.0.20250331.610.dist-info → mct_nightly-2.3.0.20250402.536.dist-info}/licenses/LICENSE.md
RENAMED
File without changes
{mct_nightly-2.3.0.20250331.610.dist-info → mct_nightly-2.3.0.20250402.536.dist-info}/top_level.txt
RENAMED
File without changes