mct-nightly 2.3.0.20250127.521__py3-none-any.whl → 2.3.0.20250129.508__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.3.0.20250127.521.dist-info → mct_nightly-2.3.0.20250129.508.dist-info}/METADATA +1 -2
- {mct_nightly-2.3.0.20250127.521.dist-info → mct_nightly-2.3.0.20250129.508.dist-info}/RECORD +12 -12
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/base_node.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +109 -92
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +12 -1
- model_compression_toolkit/core/runner.py +4 -7
- model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py +3 -5
- model_compression_toolkit/logger.py +0 -2
- {mct_nightly-2.3.0.20250127.521.dist-info → mct_nightly-2.3.0.20250129.508.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.3.0.20250127.521.dist-info → mct_nightly-2.3.0.20250129.508.dist-info}/WHEEL +0 -0
- {mct_nightly-2.3.0.20250127.521.dist-info → mct_nightly-2.3.0.20250129.508.dist-info}/top_level.txt +0 -0
{mct_nightly-2.3.0.20250127.521.dist-info → mct_nightly-2.3.0.20250129.508.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: mct-nightly
|
3
|
-
Version: 2.3.0.
|
3
|
+
Version: 2.3.0.20250129.508
|
4
4
|
Summary: A Model Compression Toolkit for neural networks
|
5
5
|
Classifier: Programming Language :: Python :: 3
|
6
6
|
Classifier: License :: OSI Approved :: Apache Software License
|
@@ -13,7 +13,6 @@ Requires-Dist: networkx!=2.8.1
|
|
13
13
|
Requires-Dist: tqdm
|
14
14
|
Requires-Dist: Pillow
|
15
15
|
Requires-Dist: numpy<2.0
|
16
|
-
Requires-Dist: opencv-python
|
17
16
|
Requires-Dist: scikit-image
|
18
17
|
Requires-Dist: scikit-learn
|
19
18
|
Requires-Dist: tensorboard
|
{mct_nightly-2.3.0.20250127.521.dist-info → mct_nightly-2.3.0.20250129.508.dist-info}/RECORD
RENAMED
@@ -1,14 +1,14 @@
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
model_compression_toolkit/__init__.py,sha256=5yoBYX4rzT_uWukrdVc6UJm2mlcKu_6gh6agaWOa4-s,1557
|
2
2
|
model_compression_toolkit/constants.py,sha256=i_R6uXBfO1ph_X6DNJych2x59SUojfJbn7dNjs_mZnc,3846
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
4
|
-
model_compression_toolkit/logger.py,sha256=
|
4
|
+
model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
|
5
5
|
model_compression_toolkit/metadata.py,sha256=x_Bk4VpzILdsFax6--CZ3X18qUTP28sbF_AhoQW8dNc,4003
|
6
6
|
model_compression_toolkit/verify_packages.py,sha256=TlS-K1EP-QsghqWUW7SDPkAJiUf7ryw4tvhFDe6rCUk,1405
|
7
7
|
model_compression_toolkit/core/__init__.py,sha256=8a0wUNBKwTdJGDk_Ho6WQAXjGuCqQZG1FUxxJlAV8L8,2096
|
8
8
|
model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
|
9
9
|
model_compression_toolkit/core/graph_prep_runner.py,sha256=CVTjBaci8F6EP3IKDnRMfxkP-Sv8qY8GpkGt6FyII2U,11376
|
10
10
|
model_compression_toolkit/core/quantization_prep_runner.py,sha256=OtL6g2rTC5mfdKrkzm47EPPW-voGGVYMYxpy2_sfu1U,6547
|
11
|
-
model_compression_toolkit/core/runner.py,sha256=
|
11
|
+
model_compression_toolkit/core/runner.py,sha256=iJpDasfs7wtdAelIRaBPxDbN64phPern1O86QDM2HeY,13706
|
12
12
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
13
13
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
14
14
|
model_compression_toolkit/core/common/framework_implementation.py,sha256=IkMydCj6voau7dwkYLYA_Ka_EFUKP3GKQdpYN6b1fgc,22163
|
@@ -33,7 +33,7 @@ model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=b41_4rL_Adiza
|
|
33
33
|
model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=-2fnjyC9q2RPw9st6RxROW-gdtT2mSRz0QZ_Gz1KDz4,5579
|
34
34
|
model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
|
35
35
|
model_compression_toolkit/core/common/graph/base_graph.py,sha256=WDyN45Y_wdBR3d5nb-3AX2tsrPxeUtc6GE98xZA-0mY,37818
|
36
|
-
model_compression_toolkit/core/common/graph/base_node.py,sha256=
|
36
|
+
model_compression_toolkit/core/common/graph/base_node.py,sha256=_SJBlDIwq5Kt2HLYWIT6POJFnUfrtcOFlOLxTbadJ1w,33058
|
37
37
|
model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
|
38
38
|
model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
|
39
39
|
model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
|
@@ -73,7 +73,7 @@ model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,s
|
|
73
73
|
model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=8oAFJc_KC3z5ClI-zo4KC40kKGscyixUc5oYP4j4cMo,8019
|
74
74
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
75
75
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=T5yVr7lay-6QLuTDBZNI1Ufj02EMBWuY_yHjC8eHx5I,3998
|
76
|
-
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=
|
76
|
+
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=DyiE84ECgwtaCATWcisv-7ndmBUbj_TaddZ7GeIjlrU,35307
|
77
77
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=J7gqUGs4ITo4ufl84A5vACxm670LG6RhQyXkejfpbn8,8834
|
78
78
|
model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
79
79
|
model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=uhC0az5OVSfeYexcasoy0cT8ZOonFKIedk_1U-ZPLhA,17171
|
@@ -104,7 +104,7 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
|
|
104
104
|
model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
|
105
105
|
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=zJP2W9apUPX9RstpPWWK71wr9xJsg7j-s7lGV4_bQdc,1510
|
106
106
|
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=IHVX-Gdekru4xLuDTgcsp_JCnRtuVWnbYsDBQuSXTKc,7079
|
107
|
-
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=
|
107
|
+
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=teDclY8WmuVqqa9Fgr6WY-7ILDep0QKzKxoZCKzBG2k,26960
|
108
108
|
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=UkSVW7d1OF_Px9gAjsqqK65aYhIBFWaBO-_IH6_AFfg,4403
|
109
109
|
model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=HfBkSiRTOf9mNF-TNQHTCCs3xSg66F20no0O6vl5v1Y,2154
|
110
110
|
model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=7eG7dl1TcbdnHwgmvyjarxLs0o6Lw_9VAjXAm4rsiBk,3791
|
@@ -312,7 +312,7 @@ model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py,sha
|
|
312
312
|
model_compression_toolkit/data_generation/pytorch/optimization_functions/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
313
313
|
model_compression_toolkit/data_generation/pytorch/optimization_functions/batchnorm_alignment_functions.py,sha256=dMc4zz9XfYfAT4Cxns57VgvGZWPAMfaGlWLFyCyl8TA,1968
|
314
314
|
model_compression_toolkit/data_generation/pytorch/optimization_functions/bn_layer_weighting_functions.py,sha256=We0fVMQ4oU7Y0IWQ8fKy8KpqkIiLyKoQeF9XKAQ6TH0,3317
|
315
|
-
model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py,sha256=
|
315
|
+
model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py,sha256=0mV2BuegNvL9MnDBu2NiJo--4KCcdDDzbWUMU4uld5w,4678
|
316
316
|
model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py,sha256=NydGxFIclmrfU3HWYUrRbprg4hPt470QP6MTOMLEhRs,9172
|
317
317
|
model_compression_toolkit/data_generation/pytorch/optimization_functions/output_loss_functions.py,sha256=PRVmn8o2hTdwTdbd2ezf__LNbFvcgiVO0c25dsyg3Tg,6549
|
318
318
|
model_compression_toolkit/data_generation/pytorch/optimization_functions/scheduler_step_functions.py,sha256=zMjY2y4FSHonuY5hddbMTb8qAQtLtohYF7q1wuruDDs,3267
|
@@ -523,8 +523,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
523
523
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
|
524
524
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
525
525
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
526
|
-
mct_nightly-2.3.0.
|
527
|
-
mct_nightly-2.3.0.
|
528
|
-
mct_nightly-2.3.0.
|
529
|
-
mct_nightly-2.3.0.
|
530
|
-
mct_nightly-2.3.0.
|
526
|
+
mct_nightly-2.3.0.20250129.508.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
527
|
+
mct_nightly-2.3.0.20250129.508.dist-info/METADATA,sha256=V1ZMks36vbn2kcBLkb88KKI-viLM_xXXWWIPogCNTnI,26572
|
528
|
+
mct_nightly-2.3.0.20250129.508.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
529
|
+
mct_nightly-2.3.0.20250129.508.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
530
|
+
mct_nightly-2.3.0.20250129.508.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.3.0.
|
30
|
+
__version__ = "2.3.0.20250129.000508"
|
@@ -30,6 +30,9 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
|
|
30
30
|
FrameworkQuantizationCapabilities
|
31
31
|
|
32
32
|
|
33
|
+
WeightAttrT = Union[str, int]
|
34
|
+
|
35
|
+
|
33
36
|
class BaseNode:
|
34
37
|
"""
|
35
38
|
Class to represent a node in a graph that represents the model.
|
@@ -40,7 +43,7 @@ class BaseNode:
|
|
40
43
|
framework_attr: Dict[str, Any],
|
41
44
|
input_shape: Tuple[Any],
|
42
45
|
output_shape: Tuple[Any],
|
43
|
-
weights: Dict[
|
46
|
+
weights: Dict[WeightAttrT, np.ndarray],
|
44
47
|
layer_class: type,
|
45
48
|
reuse: bool = False,
|
46
49
|
reuse_group: str = None,
|
@@ -189,7 +192,7 @@ class BaseNode:
|
|
189
192
|
"""
|
190
193
|
return self.reuse or self.reuse_group is not None
|
191
194
|
|
192
|
-
def _get_weight_name(self, name:
|
195
|
+
def _get_weight_name(self, name: WeightAttrT) -> List[WeightAttrT]:
|
193
196
|
"""
|
194
197
|
Get weight names that match argument name (either string weights or integer for
|
195
198
|
positional weights).
|
@@ -203,7 +206,7 @@ class BaseNode:
|
|
203
206
|
return [k for k in self.weights.keys()
|
204
207
|
if (isinstance(k, int) and name == k) or (isinstance(k, str) and name in k)]
|
205
208
|
|
206
|
-
def get_weights_by_keys(self, name:
|
209
|
+
def get_weights_by_keys(self, name: WeightAttrT) -> np.ndarray:
|
207
210
|
"""
|
208
211
|
Get a node's weight by its name.
|
209
212
|
Args:
|
@@ -15,13 +15,14 @@
|
|
15
15
|
from collections import defaultdict
|
16
16
|
from copy import deepcopy
|
17
17
|
from enum import Enum, auto
|
18
|
-
from typing import Dict, NamedTuple, Optional, Tuple, List, Iterable, Union, Literal, Sequence
|
18
|
+
from typing import Dict, NamedTuple, Optional, Tuple, List, Iterable, Union, Literal, Sequence
|
19
19
|
|
20
20
|
from model_compression_toolkit.logger import Logger
|
21
21
|
from model_compression_toolkit.constants import FLOAT_BITWIDTH
|
22
22
|
from model_compression_toolkit.core import FrameworkInfo
|
23
23
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
24
24
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
25
|
+
from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
|
25
26
|
from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
|
26
27
|
from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import compute_graph_max_cut
|
27
28
|
from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut
|
@@ -79,24 +80,25 @@ class Utilization(NamedTuple):
|
|
79
80
|
bytes: memory utilization.
|
80
81
|
"""
|
81
82
|
size: int
|
82
|
-
bytes:
|
83
|
+
bytes: float
|
83
84
|
|
84
85
|
def __add__(self, other: 'Utilization') -> 'Utilization':
|
86
|
+
""" Add another Utilization object. """
|
85
87
|
return Utilization(self.size + other.size, self.bytes + other.bytes)
|
86
88
|
|
87
|
-
def __radd__(self, other:
|
88
|
-
|
89
|
-
if other
|
90
|
-
|
91
|
-
return self
|
89
|
+
def __radd__(self, other: Literal[0]):
|
90
|
+
""" Right add is only supported with 0 to allow the sum operator (with the default start_value=0) """
|
91
|
+
if other != 0:
|
92
|
+
raise ValueError('radd is only supported with 0')
|
93
|
+
return self
|
92
94
|
|
93
95
|
def __gt__(self, other: 'Utilization'):
|
94
|
-
|
96
|
+
""" Greater than operator by bytes. Needed for max. """
|
95
97
|
return self.bytes > other.bytes
|
96
98
|
|
97
99
|
def __lt__(self, other: 'Utilization'):
|
98
|
-
|
99
|
-
return self.bytes < other.bytes
|
100
|
+
""" Less than operator by bytes. Needed for min. """
|
101
|
+
return self.bytes < other.bytes
|
100
102
|
|
101
103
|
|
102
104
|
class ResourceUtilizationCalculator:
|
@@ -107,6 +109,8 @@ class ResourceUtilizationCalculator:
|
|
107
109
|
BitwidthMode.QMinBit: min,
|
108
110
|
}
|
109
111
|
|
112
|
+
unexpected_qc_error = 'Custom quantization configuration is not expected for non-custom bit mode.'
|
113
|
+
|
110
114
|
def __init__(self, graph: Graph, fw_impl: FrameworkImplementation, fw_info: FrameworkInfo):
|
111
115
|
self.graph = graph
|
112
116
|
self.fw_impl = fw_impl
|
@@ -118,17 +122,17 @@ class ResourceUtilizationCalculator:
|
|
118
122
|
self._params_cnt = {}
|
119
123
|
for n in graph.nodes:
|
120
124
|
self._act_tensors_size[n] = n.get_total_output_params()
|
121
|
-
|
125
|
+
if n.weights:
|
126
|
+
self._params_cnt[n] = {k: v.size for k, v in n.weights.items()}
|
122
127
|
self._cuts: Optional[Dict[Cut, List[BaseNode]]] = None
|
123
128
|
|
124
129
|
@property
|
125
130
|
def cuts(self) -> Dict[Cut, List[BaseNode]]:
|
126
131
|
""" Compute if needed and return graph cuts and their memory element nodes. """
|
127
132
|
if self._cuts is None:
|
128
|
-
|
129
|
-
_, _, cuts = compute_graph_max_cut(memory_graph)
|
133
|
+
cuts = self._compute_cuts()
|
130
134
|
if cuts is None: # pragma: no cover
|
131
|
-
raise RuntimeError("Failed to calculate activation memory cuts for graph.")
|
135
|
+
raise RuntimeError("Failed to calculate activation memory cuts for graph.")
|
132
136
|
cuts = [cut for cut in cuts if cut.mem_elements.elements]
|
133
137
|
# cache cuts nodes for future use, so do not filter by target
|
134
138
|
self._cuts = {cut: [self.graph.find_node_by_name(m.node_name)[0] for m in cut.mem_elements.elements]
|
@@ -140,7 +144,8 @@ class ResourceUtilizationCalculator:
|
|
140
144
|
bitwidth_mode: BitwidthMode,
|
141
145
|
act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None,
|
142
146
|
w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]] = None,
|
143
|
-
ru_targets: Iterable[RUTarget] = None
|
147
|
+
ru_targets: Iterable[RUTarget] = None,
|
148
|
+
allow_unused_qcs: bool = False) -> ResourceUtilization:
|
144
149
|
"""
|
145
150
|
Compute network's resource utilization.
|
146
151
|
|
@@ -154,16 +159,26 @@ class ResourceUtilizationCalculator:
|
|
154
159
|
In custom mode, must provide configuration for all configurable weights. For non-configurable
|
155
160
|
weights, if not provided, the default configuration will be extracted from the node.
|
156
161
|
ru_targets: metrics to include for computation. If None, all metrics are calculated.
|
162
|
+
allow_unused_qcs: by default, if custom quantization configs are passed, but are not going to be used for
|
163
|
+
any of the requested targets, an error is raised. To disable the validation, pass True.
|
157
164
|
|
158
165
|
Returns:
|
159
166
|
Resource utilization object.
|
160
167
|
"""
|
161
168
|
ru_targets = set(ru_targets) if ru_targets else set(RUTarget)
|
162
169
|
|
163
|
-
if w_qcs
|
164
|
-
raise ValueError(
|
165
|
-
|
166
|
-
|
170
|
+
if (w_qcs or act_qcs) and bitwidth_mode != BitwidthMode.QCustom:
|
171
|
+
raise ValueError(self.unexpected_qc_error)
|
172
|
+
|
173
|
+
if w_qcs and not {RUTarget.WEIGHTS, RUTarget.TOTAL, RUTarget.BOPS}.intersection(ru_targets):
|
174
|
+
if not allow_unused_qcs:
|
175
|
+
raise ValueError('Weight configuration passed but no relevant ru_targets requested.')
|
176
|
+
w_qcs = None
|
177
|
+
|
178
|
+
if act_qcs and not {RUTarget.ACTIVATION, RUTarget.TOTAL, RUTarget.BOPS}.intersection(ru_targets):
|
179
|
+
if not allow_unused_qcs:
|
180
|
+
raise ValueError('Activation configuration passed but no relevant ru_targets requested.')
|
181
|
+
act_qcs = None
|
167
182
|
|
168
183
|
w_total, a_total = None, None
|
169
184
|
if {RUTarget.WEIGHTS, RUTarget.TOTAL}.intersection(ru_targets):
|
@@ -180,8 +195,7 @@ class ResourceUtilizationCalculator:
|
|
180
195
|
if RUTarget.TOTAL in ru_targets:
|
181
196
|
ru.total_memory = w_total + a_total
|
182
197
|
if RUTarget.BOPS in ru_targets:
|
183
|
-
ru.bops, _ = self.compute_bops(target_criterion=
|
184
|
-
bitwidth_mode=bitwidth_mode, act_qcs=act_qcs, w_qcs=w_qcs)
|
198
|
+
ru.bops, _ = self.compute_bops(target_criterion, bitwidth_mode, act_qcs=act_qcs, w_qcs=w_qcs)
|
185
199
|
|
186
200
|
assert ru.get_restricted_targets() == set(ru_targets), 'Mismatch between the number of requested and computed metrics'
|
187
201
|
return ru
|
@@ -206,35 +220,35 @@ class ResourceUtilizationCalculator:
|
|
206
220
|
- Per node total weights utilization. Dict keys are nodes in a topological order.
|
207
221
|
- Detailed per node per weight attribute utilization. Dict keys are nodes in a topological order.
|
208
222
|
"""
|
209
|
-
|
210
|
-
|
211
|
-
|
223
|
+
if w_qcs and bitwidth_mode != BitwidthMode.QCustom:
|
224
|
+
raise ValueError(self.unexpected_qc_error)
|
225
|
+
|
226
|
+
node_attrs = self._collect_target_nodes_w_attrs(target_criterion, include_reused=False)
|
212
227
|
|
213
228
|
util_per_node: Dict[BaseNode, Utilization] = {}
|
214
229
|
util_per_node_per_weight = {}
|
215
|
-
|
216
|
-
for n in self._topo_sort(nodes):
|
230
|
+
for n in self._topo_sort(list(node_attrs.keys())):
|
217
231
|
w_qc = w_qcs.get(n) if w_qcs else None
|
218
|
-
node_weights_util, per_weight_util = self.compute_node_weights_utilization(n,
|
232
|
+
node_weights_util, per_weight_util = self.compute_node_weights_utilization(n, node_attrs[n],
|
219
233
|
bitwidth_mode, w_qc)
|
220
234
|
util_per_node[n] = node_weights_util
|
221
235
|
util_per_node_per_weight[n] = per_weight_util
|
222
236
|
|
223
|
-
total_util = sum(util_per_node.values())
|
237
|
+
total_util = sum(util_per_node.values()) if util_per_node else Utilization(0, 0)
|
224
238
|
return total_util.bytes, util_per_node, util_per_node_per_weight
|
225
239
|
|
226
240
|
def compute_node_weights_utilization(self,
|
227
241
|
n: BaseNode,
|
228
|
-
target_criterion: TargetInclusionCriterion,
|
242
|
+
target_criterion: Union[TargetInclusionCriterion, List[str]],
|
229
243
|
bitwidth_mode: BitwidthMode,
|
230
|
-
qc: NodeWeightsQuantizationConfig)\
|
244
|
+
qc: Optional[NodeWeightsQuantizationConfig] = None)\
|
231
245
|
-> Tuple[Utilization, Dict[str, Utilization]]:
|
232
246
|
"""
|
233
247
|
Compute resource utilization for weights of a node.
|
234
248
|
|
235
249
|
Args:
|
236
250
|
n: node.
|
237
|
-
target_criterion: criterion to include weights for computation.
|
251
|
+
target_criterion: criterion to include weights for computation, or explicit attributes list (full names).
|
238
252
|
bitwidth_mode: bit-width mode for the computation.
|
239
253
|
qc: custom weights quantization configuration. Should be provided for custom bit mode only.
|
240
254
|
In custom mode, must provide configuration for all configurable weights. For non-configurable
|
@@ -244,9 +258,21 @@ class ResourceUtilizationCalculator:
|
|
244
258
|
- Node's total weights utilization.
|
245
259
|
- Detailed per weight attribute utilization.
|
246
260
|
"""
|
247
|
-
|
248
|
-
|
249
|
-
|
261
|
+
if qc:
|
262
|
+
if bitwidth_mode != BitwidthMode.QCustom:
|
263
|
+
raise ValueError(self.unexpected_qc_error)
|
264
|
+
if set(qc.all_weight_attrs) - set(n.get_node_weights_attributes()):
|
265
|
+
raise ValueError(f'Custom configuration contains unexpected weight attrs {qc.all_weight_attrs} for '
|
266
|
+
f'node {n} containing weight attrs {n.get_node_weights_attributes()}.')
|
267
|
+
|
268
|
+
# If target criterion is passed, weights_attrs may return empty, that's fine.
|
269
|
+
# However, if an explicit list is passed, it must be non-empty.
|
270
|
+
if isinstance(target_criterion, TargetInclusionCriterion):
|
271
|
+
weight_attrs = self._get_target_weight_attrs(n, target_criterion)
|
272
|
+
else:
|
273
|
+
weight_attrs = target_criterion
|
274
|
+
if not weight_attrs:
|
275
|
+
raise ValueError('Explicit list of attributes to compute cannot be empty.')
|
250
276
|
|
251
277
|
attr_util = {}
|
252
278
|
for attr in weight_attrs:
|
@@ -255,7 +281,7 @@ class ResourceUtilizationCalculator:
|
|
255
281
|
bytes_ = size * nbits / 8
|
256
282
|
attr_util[attr] = Utilization(size, bytes_)
|
257
283
|
|
258
|
-
total_weights: Utilization = sum(attr_util.values())
|
284
|
+
total_weights: Utilization = sum(attr_util.values()) if attr_util else Utilization(0, 0)
|
259
285
|
return total_weights, attr_util
|
260
286
|
|
261
287
|
def compute_activations_utilization(self,
|
@@ -280,7 +306,7 @@ class ResourceUtilizationCalculator:
|
|
280
306
|
def compute_activation_utilization_by_cut(self,
|
281
307
|
target_criterion: TargetInclusionCriterion,
|
282
308
|
bitwidth_mode: BitwidthMode,
|
283
|
-
act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]]) \
|
309
|
+
act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None) \
|
284
310
|
-> Tuple[float, Dict[Cut, Utilization], Dict[Cut, Dict[BaseNode, Utilization]]]:
|
285
311
|
"""
|
286
312
|
Compute graph activation cuts utilization.
|
@@ -297,15 +323,15 @@ class ResourceUtilizationCalculator:
|
|
297
323
|
- Total activation utilization per cut.
|
298
324
|
- Detailed activation utilization per cut per node.
|
299
325
|
"""
|
300
|
-
if
|
301
|
-
raise
|
326
|
+
if act_qcs and not bitwidth_mode == BitwidthMode.QCustom:
|
327
|
+
raise ValueError(self.unexpected_qc_error)
|
302
328
|
|
303
329
|
graph_target_nodes = self._get_target_activation_nodes(target_criterion, include_reused=True)
|
304
330
|
# if there are no target activations in the graph, don't waste time looking for cuts
|
305
331
|
if not graph_target_nodes:
|
306
332
|
return 0, {}, {}
|
307
333
|
|
308
|
-
util_per_cut: Dict[Cut, Utilization] = {}
|
334
|
+
util_per_cut: Dict[Cut, Utilization] = {}
|
309
335
|
util_per_cut_per_node = defaultdict(dict)
|
310
336
|
for cut in self.cuts:
|
311
337
|
cut_target_nodes = self._get_cut_target_nodes(cut, target_criterion)
|
@@ -325,7 +351,7 @@ class ResourceUtilizationCalculator:
|
|
325
351
|
bitwidth_mode: BitwidthMode,
|
326
352
|
act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None,
|
327
353
|
include_reused=False) \
|
328
|
-
-> Tuple[float, Dict[BaseNode, Utilization]]:
|
354
|
+
-> Tuple[float, Dict[BaseNode, Utilization]]:
|
329
355
|
"""
|
330
356
|
Compute resource utilization for graph's activations tensors.
|
331
357
|
|
@@ -341,9 +367,10 @@ class ResourceUtilizationCalculator:
|
|
341
367
|
- Detailed utilization per node. Dict keys are nodes in a topological order.
|
342
368
|
|
343
369
|
"""
|
370
|
+
if act_qcs and bitwidth_mode != BitwidthMode.QCustom:
|
371
|
+
raise ValueError(self.unexpected_qc_error)
|
372
|
+
|
344
373
|
nodes = self._get_target_activation_nodes(target_criterion, include_reused=include_reused)
|
345
|
-
if not nodes:
|
346
|
-
return 0, {}
|
347
374
|
|
348
375
|
util_per_node: Dict[BaseNode, Utilization] = {}
|
349
376
|
for n in self._topo_sort(nodes):
|
@@ -351,14 +378,14 @@ class ResourceUtilizationCalculator:
|
|
351
378
|
util = self.compute_node_activation_tensor_utilization(n, None, bitwidth_mode, qc)
|
352
379
|
util_per_node[n] = util
|
353
380
|
|
354
|
-
total_util = max(util_per_node.values())
|
355
|
-
return total_util
|
381
|
+
total_util = max(util_per_node.values()).bytes if util_per_node else 0
|
382
|
+
return total_util, util_per_node
|
356
383
|
|
357
384
|
def compute_node_activation_tensor_utilization(self,
|
358
385
|
n: BaseNode,
|
359
386
|
target_criterion: Optional[TargetInclusionCriterion],
|
360
387
|
bitwidth_mode: BitwidthMode,
|
361
|
-
qc: Optional[NodeActivationQuantizationConfig]) -> Utilization:
|
388
|
+
qc: Optional[NodeActivationQuantizationConfig] = None) -> Utilization:
|
362
389
|
"""
|
363
390
|
Compute activation resource utilization for a node.
|
364
391
|
|
@@ -372,9 +399,13 @@ class ResourceUtilizationCalculator:
|
|
372
399
|
Returns:
|
373
400
|
Node's activation utilization.
|
374
401
|
"""
|
402
|
+
if qc and bitwidth_mode != BitwidthMode.QCustom:
|
403
|
+
raise ValueError(self.unexpected_qc_error)
|
404
|
+
|
375
405
|
if target_criterion:
|
406
|
+
# only check whether the node meets the criterion
|
376
407
|
nodes = self._get_target_activation_nodes(target_criterion=target_criterion, include_reused=True, nodes=[n])
|
377
|
-
if not nodes:
|
408
|
+
if not nodes:
|
378
409
|
return Utilization(0, 0)
|
379
410
|
|
380
411
|
size = self._act_tensors_size[n]
|
@@ -410,7 +441,7 @@ class ResourceUtilizationCalculator:
|
|
410
441
|
if target_criterion != TargetInclusionCriterion.AnyQuantized: # pragma: no cover
|
411
442
|
raise NotImplementedError('BOPS computation is currently only supported for quantized targets.')
|
412
443
|
|
413
|
-
nodes = self.
|
444
|
+
nodes = self._collect_target_nodes_w_attrs(target_criterion, include_reused=True)
|
414
445
|
# filter out nodes with only positional weights # TODO add as arg to get target nodes
|
415
446
|
nodes = [n for n in nodes if n.has_kernel_weight_to_quantize(self.fw_info)]
|
416
447
|
|
@@ -448,7 +479,7 @@ class ResourceUtilizationCalculator:
|
|
448
479
|
|
449
480
|
incoming_edges = self.graph.incoming_edges(n, sort_by_attr=EDGE_SINK_INDEX)
|
450
481
|
# TODO temporary adding this for const_representation test in torch which has Linear with const input
|
451
|
-
if not incoming_edges:
|
482
|
+
if not incoming_edges: # pragma: no cover
|
452
483
|
return 0
|
453
484
|
assert len(incoming_edges) == 1, \
|
454
485
|
f'Unexpected number of inputs {len(incoming_edges)} for BOPS calculation. Expected 1.'
|
@@ -465,13 +496,11 @@ class ResourceUtilizationCalculator:
|
|
465
496
|
node_bops = a_nbits * w_nbits * node_mac
|
466
497
|
return node_bops
|
467
498
|
|
468
|
-
def
|
469
|
-
"""
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
""" Whether custom configuration for activations is compatible with the requested targets."""
|
474
|
-
return bool({RUTarget.ACTIVATION, RUTarget.TOTAL, RUTarget.BOPS}.intersection(ru_targets))
|
499
|
+
def _compute_cuts(self):
|
500
|
+
""" Compute activation cuts of the graph. """
|
501
|
+
memory_graph = MemoryGraph(deepcopy(self.graph))
|
502
|
+
_, _, cuts = compute_graph_max_cut(memory_graph)
|
503
|
+
return cuts
|
475
504
|
|
476
505
|
def _get_cut_target_nodes(self, cut: Cut, target_criterion: TargetInclusionCriterion) -> List[BaseNode]:
|
477
506
|
"""
|
@@ -487,37 +516,23 @@ class ResourceUtilizationCalculator:
|
|
487
516
|
cut_nodes = self.cuts[cut]
|
488
517
|
return self._get_target_activation_nodes(target_criterion, include_reused=True, nodes=cut_nodes)
|
489
518
|
|
490
|
-
def
|
491
|
-
|
492
|
-
|
519
|
+
def _collect_target_nodes_w_attrs(self,
|
520
|
+
target_criterion: TargetInclusionCriterion,
|
521
|
+
include_reused: bool) -> Dict[BaseNode, List[WeightAttrT]]:
|
493
522
|
"""
|
494
|
-
Collect nodes to include in weights utilization computation.
|
523
|
+
Collect nodes and their weight attributes to include in weights utilization computation.
|
495
524
|
|
496
525
|
Args:
|
497
526
|
target_criterion: criterion to include weights for computation.
|
498
527
|
include_reused: whether to include reused nodes.
|
499
528
|
|
500
529
|
Returns:
|
501
|
-
|
530
|
+
A mapping from nodes to their weights attributes.
|
502
531
|
"""
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
elif target_criterion == TargetInclusionCriterion.QNonConfigurable:
|
508
|
-
# TODO this is wrong. Need to look at specific weights and not the whole node (if w1 is configurable and w2
|
509
|
-
# is non-configurable we want to discover the node both as configurable and non-configurable)
|
510
|
-
quantized = [n for n in self.graph if n.has_any_weight_attr_to_quantize()]
|
511
|
-
configurable = self.graph.get_weights_configurable_nodes(self.fw_info, include_reused_nodes=include_reused)
|
512
|
-
nodes = [n for n in quantized if n not in configurable]
|
513
|
-
elif target_criterion == TargetInclusionCriterion.Any: # pragma: no cover
|
514
|
-
nodes = list(self.graph.nodes)
|
515
|
-
else: # pragma: no cover
|
516
|
-
raise ValueError(f'Unknown {target_criterion}.')
|
517
|
-
|
518
|
-
if not include_reused:
|
519
|
-
nodes = [n for n in nodes if not n.reuse]
|
520
|
-
return nodes
|
532
|
+
nodes_attrs = {n: attrs for n in self.graph.nodes
|
533
|
+
if (attrs := self._get_target_weight_attrs(n, target_criterion))
|
534
|
+
and (include_reused or not n.reuse)}
|
535
|
+
return nodes_attrs
|
521
536
|
|
522
537
|
def _get_target_weight_attrs(self, n: BaseNode, target_criterion: TargetInclusionCriterion) -> List[str]:
|
523
538
|
"""
|
@@ -530,6 +545,7 @@ class ResourceUtilizationCalculator:
|
|
530
545
|
Returns:
|
531
546
|
Selected weight attributes names.
|
532
547
|
"""
|
548
|
+
# weight_attrs are the full names in the layer, e.g. 'conv2d_1/kernel:0' (or an integer for positional attrs)
|
533
549
|
weight_attrs = n.get_node_weights_attributes()
|
534
550
|
if target_criterion == TargetInclusionCriterion.QConfigurable:
|
535
551
|
weight_attrs = [attr for attr in weight_attrs if n.is_configurable_weight(attr)]
|
@@ -548,14 +564,17 @@ class ResourceUtilizationCalculator:
|
|
548
564
|
Sort nodes in a topological order (based on graph's nodes).
|
549
565
|
|
550
566
|
Args:
|
551
|
-
nodes: nodes to sort.
|
567
|
+
nodes: nodes to sort. Allowed to be empty.
|
552
568
|
|
553
569
|
Returns:
|
554
570
|
Nodes in topological order.
|
555
571
|
"""
|
572
|
+
if not nodes:
|
573
|
+
return list(nodes)
|
574
|
+
|
556
575
|
graph_topo_nodes = self.graph.get_topo_sorted_nodes()
|
557
576
|
topo_nodes = [n for n in graph_topo_nodes if n in nodes]
|
558
|
-
if len(topo_nodes) != len(nodes):
|
577
|
+
if len(topo_nodes) != len(nodes):
|
559
578
|
missing_nodes = [n for n in nodes if n not in topo_nodes]
|
560
579
|
raise ValueError(f'Could not topo-sort, nodes {missing_nodes} do not match the graph nodes.')
|
561
580
|
return topo_nodes
|
@@ -576,15 +595,15 @@ class ResourceUtilizationCalculator:
|
|
576
595
|
Selected nodes.
|
577
596
|
"""
|
578
597
|
nodes = nodes or self.graph.nodes
|
579
|
-
if target_criterion == TargetInclusionCriterion.QConfigurable:
|
598
|
+
if target_criterion == TargetInclusionCriterion.QConfigurable:
|
580
599
|
nodes = [n for n in nodes if n.has_configurable_activation()]
|
581
600
|
elif target_criterion == TargetInclusionCriterion.AnyQuantized:
|
582
601
|
nodes = [n for n in nodes if n.is_activation_quantization_enabled()]
|
583
|
-
elif target_criterion == TargetInclusionCriterion.QNonConfigurable:
|
602
|
+
elif target_criterion == TargetInclusionCriterion.QNonConfigurable:
|
584
603
|
nodes = [n for n in nodes if n.is_activation_quantization_enabled() and not n.has_configurable_activation()]
|
585
604
|
elif target_criterion != TargetInclusionCriterion.Any: # pragma: no cover
|
586
605
|
raise ValueError(f'Unknown {target_criterion}.')
|
587
|
-
if not include_reused:
|
606
|
+
if not include_reused:
|
588
607
|
nodes = [n for n in nodes if not n.reuse]
|
589
608
|
return nodes
|
590
609
|
|
@@ -607,8 +626,7 @@ class ResourceUtilizationCalculator:
|
|
607
626
|
Activation bit-width.
|
608
627
|
"""
|
609
628
|
if act_qc:
|
610
|
-
|
611
|
-
raise ValueError(f'Activation config is not expected for non-custom bit mode {bitwidth_mode}')
|
629
|
+
assert bitwidth_mode == BitwidthMode.QCustom
|
612
630
|
return act_qc.activation_n_bits if act_qc.enable_activation_quantization else FLOAT_BITWIDTH
|
613
631
|
|
614
632
|
if bitwidth_mode == BitwidthMode.Float or not n.is_activation_quantization_enabled():
|
@@ -623,8 +641,8 @@ class ResourceUtilizationCalculator:
|
|
623
641
|
|
624
642
|
if bitwidth_mode in [BitwidthMode.QCustom, BitwidthMode.QDefaultSP]:
|
625
643
|
qcs = n.get_unique_activation_candidates()
|
626
|
-
if len(qcs) != 1:
|
627
|
-
raise ValueError(f'Could not retrieve the activation quantization candidate for node {n
|
644
|
+
if len(qcs) != 1:
|
645
|
+
raise ValueError(f'Could not retrieve the activation quantization candidate for node {n} '
|
628
646
|
f'as it has {len(qcs)}!=1 unique candidates .')
|
629
647
|
return qcs[0].activation_quantization_cfg.activation_n_bits
|
630
648
|
|
@@ -650,9 +668,8 @@ class ResourceUtilizationCalculator:
|
|
650
668
|
Returns:
|
651
669
|
Weight bit-width.
|
652
670
|
"""
|
671
|
+
assert not (w_qc and bitwidth_mode != BitwidthMode.QCustom)
|
653
672
|
if w_qc and w_qc.has_attribute_config(w_attr):
|
654
|
-
if bitwidth_mode != BitwidthMode.QCustom: # pragma: no cover
|
655
|
-
raise ValueError('Weight config is not expected for non-custom bit mode {bitwidth_mode}')
|
656
673
|
attr_cfg = w_qc.get_attr_config(w_attr)
|
657
674
|
return attr_cfg.weights_n_bits if attr_cfg.enable_weights_quantization else FLOAT_BITWIDTH
|
658
675
|
|
@@ -669,9 +686,9 @@ class ResourceUtilizationCalculator:
|
|
669
686
|
|
670
687
|
if bitwidth_mode in [BitwidthMode.QCustom, BitwidthMode.QDefaultSP]:
|
671
688
|
# if configuration was not passed and the weight has only one candidate, use it
|
672
|
-
if len(w_qcs) != 1:
|
673
|
-
raise ValueError(f'Could not retrieve the quantization candidate for attr {w_attr} of node {n
|
674
|
-
f'as it {len(w_qcs)}!=1 unique candidates.')
|
689
|
+
if len(w_qcs) != 1:
|
690
|
+
raise ValueError(f'Could not retrieve the quantization candidate for attr {w_attr} of node {n} '
|
691
|
+
f'as it has {len(w_qcs)}!=1 unique candidates.')
|
675
692
|
return w_qcs[0].weights_n_bits
|
676
693
|
|
677
694
|
raise ValueError(f'Unknown mode {bitwidth_mode.name}') # pragma: no cover
|
@@ -14,7 +14,7 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
|
17
|
-
from typing import Callable, Any, List, Tuple, Union, Dict
|
17
|
+
from typing import Callable, Any, List, Tuple, Union, Dict, TYPE_CHECKING
|
18
18
|
|
19
19
|
import numpy as np
|
20
20
|
|
@@ -28,6 +28,8 @@ from model_compression_toolkit.core.common.quantization.quantization_config impo
|
|
28
28
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig, \
|
29
29
|
OpQuantizationConfig
|
30
30
|
|
31
|
+
if TYPE_CHECKING:
|
32
|
+
from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
|
31
33
|
|
32
34
|
##########################################
|
33
35
|
# Every node holds a quantization configuration
|
@@ -482,6 +484,15 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
482
484
|
|
483
485
|
return False
|
484
486
|
|
487
|
+
@property
|
488
|
+
def all_weight_attrs(self) -> List['WeightAttrT']:
|
489
|
+
""" Fetch all weight attributes keys (positional and named).
|
490
|
+
|
491
|
+
Returns:
|
492
|
+
List of attributes.
|
493
|
+
"""
|
494
|
+
return list(self.pos_attributes_config_mapping.keys()) + list(self.attributes_config_mapping.keys())
|
495
|
+
|
485
496
|
def _extract_config_for_attributes_with_name(self, attr_name) -> Dict[str, WeightsAttrQuantizationConfig]:
|
486
497
|
"""
|
487
498
|
Extract the saved attributes that contain the given attribute name.
|
@@ -229,14 +229,11 @@ def _set_final_resource_utilization(graph: Graph,
|
|
229
229
|
final_ru = None
|
230
230
|
if ru_targets:
|
231
231
|
ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
|
232
|
-
w_qcs
|
233
|
-
|
234
|
-
w_qcs = {n: n.final_weights_quantization_cfg for n in graph.nodes}
|
235
|
-
if ru_calculator.is_custom_activation_config_applicable(ru_targets):
|
236
|
-
a_qcs = {n: n.final_activation_quantization_cfg for n in graph.nodes}
|
232
|
+
w_qcs = {n: n.final_weights_quantization_cfg for n in graph.nodes}
|
233
|
+
a_qcs = {n: n.final_activation_quantization_cfg for n in graph.nodes}
|
237
234
|
final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
|
238
|
-
BitwidthMode.QCustom,
|
239
|
-
|
235
|
+
BitwidthMode.QCustom, act_qcs=a_qcs, w_qcs=w_qcs,
|
236
|
+
ru_targets=ru_targets, allow_unused_qcs=True)
|
240
237
|
summary = final_ru.get_summary_str(restricted=True)
|
241
238
|
Logger.info(f'Resource utilization for quantized mixed-precision targets:\n {summary}.')
|
242
239
|
graph.user_info.final_resource_utilization = final_ru
|
model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py
CHANGED
@@ -15,11 +15,10 @@
|
|
15
15
|
from functools import partial
|
16
16
|
from typing import Tuple, Union, List, Callable, Dict
|
17
17
|
|
18
|
-
import cv2
|
19
18
|
from torch import Tensor
|
20
19
|
from torchvision.transforms.transforms import _setup_size
|
21
20
|
import torch
|
22
|
-
import
|
21
|
+
import torch.nn.functional as F
|
23
22
|
from torch.utils.data import Dataset, DataLoader
|
24
23
|
|
25
24
|
from model_compression_toolkit.data_generation.common.enums import DataInitType
|
@@ -97,9 +96,8 @@ def diverse_sample(size: Tuple[int, ...]) -> Tensor:
|
|
97
96
|
sample = random_std * torch.randn(size) + random_mean
|
98
97
|
|
99
98
|
# filtering to make the image a bit smoother
|
100
|
-
kernel =
|
101
|
-
|
102
|
-
sample = torch.from_numpy(cv2.filter2D(sample.float().detach().cpu().numpy(), -1, kernel))
|
99
|
+
kernel = torch.ones(NUM_INPUT_CHANNELS, NUM_INPUT_CHANNELS, 5, 5) / 16
|
100
|
+
sample = F.conv2d(sample, kernel, padding=1)
|
103
101
|
return sample.float()
|
104
102
|
|
105
103
|
def default_data_init_fn(
|
@@ -136,7 +136,6 @@ class Logger:
|
|
136
136
|
msg: Message to log.
|
137
137
|
|
138
138
|
"""
|
139
|
-
print(msg)
|
140
139
|
Logger.get_logger().info(msg)
|
141
140
|
|
142
141
|
@staticmethod
|
@@ -148,7 +147,6 @@ class Logger:
|
|
148
147
|
msg: Message to log.
|
149
148
|
|
150
149
|
"""
|
151
|
-
print(msg)
|
152
150
|
Logger.get_logger().warning(msg)
|
153
151
|
|
154
152
|
@staticmethod
|
{mct_nightly-2.3.0.20250127.521.dist-info → mct_nightly-2.3.0.20250129.508.dist-info}/LICENSE.md
RENAMED
File without changes
|
File without changes
|
{mct_nightly-2.3.0.20250127.521.dist-info → mct_nightly-2.3.0.20250129.508.dist-info}/top_level.txt
RENAMED
File without changes
|