mct-nightly 2.3.0.20250525.629__py3-none-any.whl → 2.3.0.20250527.555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.3.0.20250525.629.dist-info → mct_nightly-2.3.0.20250527.555.dist-info}/METADATA +1 -1
- {mct_nightly-2.3.0.20250525.629.dist-info → mct_nightly-2.3.0.20250527.555.dist-info}/RECORD +18 -16
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/__init__.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/__init__.py +1 -0
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +2 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +6 -11
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/__init__.py +14 -0
- model_compression_toolkit/core/common/mixed_precision/{sensitivity_evaluation.py → sensitivity_eval/metric_calculators.py} +149 -244
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +168 -0
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +2 -6
- {mct_nightly-2.3.0.20250525.629.dist-info → mct_nightly-2.3.0.20250527.555.dist-info}/WHEEL +0 -0
- {mct_nightly-2.3.0.20250525.629.dist-info → mct_nightly-2.3.0.20250527.555.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.3.0.20250525.629.dist-info → mct_nightly-2.3.0.20250527.555.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/common/mixed_precision/{distance_weighting.py → sensitivity_eval/distance_weighting.py} +0 -0
- /model_compression_toolkit/core/common/mixed_precision/{set_layer_to_bitwidth.py → sensitivity_eval/set_layer_to_bitwidth.py} +0 -0
{mct_nightly-2.3.0.20250525.629.dist-info → mct_nightly-2.3.0.20250527.555.dist-info}/RECORD
RENAMED
@@ -1,11 +1,11 @@
|
|
1
|
-
mct_nightly-2.3.0.
|
2
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
mct_nightly-2.3.0.20250527.555.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
2
|
+
model_compression_toolkit/__init__.py,sha256=ac_6iGXJR83ii1qVJhusgDNmQ7il3U3QYpAm-wdLf14,1557
|
3
3
|
model_compression_toolkit/constants.py,sha256=KNgiNLpsMgSYyXMNEbHXd4bFNerQc1D6HH3vpbUq_Gs,4086
|
4
4
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
5
5
|
model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
|
6
6
|
model_compression_toolkit/metadata.py,sha256=x_Bk4VpzILdsFax6--CZ3X18qUTP28sbF_AhoQW8dNc,4003
|
7
7
|
model_compression_toolkit/verify_packages.py,sha256=l0neIRr8q_QwxmuiTI4vyCMDISDedK0EihjEQUe66tE,1319
|
8
|
-
model_compression_toolkit/core/__init__.py,sha256=
|
8
|
+
model_compression_toolkit/core/__init__.py,sha256=phfdtc09uruSyOpWRaUMUeMNRSwYB5q9NBus3cqcjIM,2113
|
9
9
|
model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
|
10
10
|
model_compression_toolkit/core/graph_prep_runner.py,sha256=C6eUTd-fcgxk0LUbt51gFZwmyDDDEB8-9Q4kr9ujYvI,11555
|
11
11
|
model_compression_toolkit/core/quantization_prep_runner.py,sha256=DPevqQ8brkdut8K5f5v9g5lbT3r1GSmhLAk3NkL40Fg,6593
|
@@ -60,18 +60,15 @@ model_compression_toolkit/core/common/matchers/edge_matcher.py,sha256=bS9KIBhB6Y
|
|
60
60
|
model_compression_toolkit/core/common/matchers/function.py,sha256=kMwcinxn_PInvetNh_L_lqGXT1hoi3f97PqBpjqfXoA,1773
|
61
61
|
model_compression_toolkit/core/common/matchers/node_matcher.py,sha256=63cMwa5YbQ5LKZy8-KFmdchVc3N7mpDJ6fNDt_uAQsk,2745
|
62
62
|
model_compression_toolkit/core/common/matchers/walk_matcher.py,sha256=xqfLKk6xZt72hSnND_HoX5ESOooNMypb5VOZkVsJ_nw,1111
|
63
|
-
model_compression_toolkit/core/common/mixed_precision/__init__.py,sha256=
|
64
|
-
model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py,sha256=
|
63
|
+
model_compression_toolkit/core/common/mixed_precision/__init__.py,sha256=Jm6pls3QUCMQ9d86KOYxOq05br_k130ByGHLCojIZ_M,766
|
64
|
+
model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py,sha256=npqLPyk5xXR11M_zdImtSALc5vJv9N4fEapaludKLBw,7139
|
65
65
|
model_compression_toolkit/core/common/mixed_precision/configurable_quant_id.py,sha256=LLDguK7afsbN742ucLpmJr5TUfTyFpK1vbf2bpVr1v0,882
|
66
66
|
model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py,sha256=7dKMi5S0zQZ16m8NWn1XIuoXsKuZUg64G4-uK8-j1PQ,5177
|
67
|
-
model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=-x8edUyudu1EAEM66AuXPtgayLpzbxoLNubfEbFM5kU,2867
|
68
67
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=6pLUEEIqRTVIlCYQC4JIvY55KAvuBHEX8uTOQ-1Ac4Q,3859
|
69
|
-
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=
|
68
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=rdtxPmRhjrC160O3fqAjDzGxpMeM49hYhmlnf_Kwqds,5416
|
70
69
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=axgAypzsiCOw04ZOtOEjK4riuNsaEU2qU6KkWnEXtMo,4951
|
71
|
-
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=
|
72
|
-
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=
|
73
|
-
model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=4uhUXKgwyMrJqEVK5uJzVr67GI5YzDTHLveV4maB7z0,28079
|
74
|
-
model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=Zn6SgzGLWWKmuYGHd1YtKxZdYnQWRDeXEkKlBiTbHcs,2929
|
70
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=1877xOUdgpWrXWyhdX1pJOePuopq43L71WqBFMqzyR4,6418
|
71
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=TAxA9BKxINwUQfJpmf2Qghz-5DTbesuf1Pe1L0Tc-j4,28157
|
75
72
|
model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=MY8df-c_kITEr_7hOctaxhdiq29hSTA0La9Qo0oTJJY,9678
|
76
73
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
77
74
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
|
@@ -79,6 +76,11 @@ model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools
|
|
79
76
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=ZY5yFIDzbaqIk0UzakDBObfsVevn4fydqAfAm4RCikY,4058
|
80
77
|
model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
81
78
|
model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=6Z6nQL9UH7B8dbcUR0cuCTEYFOKZAlvOb-SCk_cAZFA,6670
|
79
|
+
model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/__init__.py,sha256=5yxITHNJcCfeGKdIpAYbNbKDoXUSvENuRQm3OQu8Qf4,697
|
80
|
+
model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/distance_weighting.py,sha256=-x8edUyudu1EAEM66AuXPtgayLpzbxoLNubfEbFM5kU,2867
|
81
|
+
model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py,sha256=W4CySFtN874npcM9j9wu1PVrv7IZHLyKdLOPrTsCNQg,22209
|
82
|
+
model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py,sha256=5l0qP0mZ061xh3rjqTJZcLD2mMKC-hfSnNAN0OmSusk,8938
|
83
|
+
model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/set_layer_to_bitwidth.py,sha256=Zn6SgzGLWWKmuYGHd1YtKxZdYnQWRDeXEkKlBiTbHcs,2929
|
82
84
|
model_compression_toolkit/core/common/network_editors/__init__.py,sha256=vZmu55bYqiaOQs3AjfwWDXHmuKZcLHt-wm7uR5fPEqg,1307
|
83
85
|
model_compression_toolkit/core/common/network_editors/actions.py,sha256=nid0_j-Cn10xvmztT8yCKW_6uA7JEnom9SW9syx7wc0,19594
|
84
86
|
model_compression_toolkit/core/common/network_editors/edit_network.py,sha256=dfgawi-nB0ocAJ0xcGn9E-Zv203oUnQLuMiXpX8vTgA,1748
|
@@ -132,7 +134,7 @@ model_compression_toolkit/core/common/statistics_correction/__init__.py,sha256=s
|
|
132
134
|
model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py,sha256=b05ZwQ2CwG0Q-yqs9A1uHfP8o17aGEZFCeJNP1p4IWk,4450
|
133
135
|
model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py,sha256=b5clhUWGoDaQLn2pDCeYkV0FomVebcKS8pMXtQTTzIg,4679
|
134
136
|
model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py,sha256=C_nwhhitTd1pCto0nHZPn3fjIMOeDD7VIciumTR3s6k,5641
|
135
|
-
model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py,sha256=
|
137
|
+
model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py,sha256=zIkhOPF6K5aIgMExpD7HFT9UZSDpvXh51F6V-qZ7H-4,9048
|
136
138
|
model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py,sha256=LaGhYES7HgIDf9Bi2KAG_mBzAWuum0J6AGmAFPC8wwo,10478
|
137
139
|
model_compression_toolkit/core/common/statistics_correction/statistics_correction.py,sha256=E0ZA4edimJwpHh9twI5gafcoJ9fX5F1JX2QUOkUOKEw,6250
|
138
140
|
model_compression_toolkit/core/common/substitutions/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
@@ -528,7 +530,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
528
530
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
|
529
531
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
530
532
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
531
|
-
mct_nightly-2.3.0.
|
532
|
-
mct_nightly-2.3.0.
|
533
|
-
mct_nightly-2.3.0.
|
534
|
-
mct_nightly-2.3.0.
|
533
|
+
mct_nightly-2.3.0.20250527.555.dist-info/METADATA,sha256=m5m0MizrO50qbrB0RkMCLt9s317qhSe3TcCD9otx0lQ,25135
|
534
|
+
mct_nightly-2.3.0.20250527.555.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
|
535
|
+
mct_nightly-2.3.0.20250527.555.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
536
|
+
mct_nightly-2.3.0.20250527.555.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.3.0.
|
30
|
+
__version__ = "2.3.0.20250527.000555"
|
@@ -25,5 +25,5 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_
|
|
25
25
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
|
26
26
|
from model_compression_toolkit.core.keras.resource_utilization_data_facade import keras_resource_utilization_data
|
27
27
|
from model_compression_toolkit.core.pytorch.resource_utilization_data_facade import pytorch_resource_utilization_data
|
28
|
-
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
|
28
|
+
from model_compression_toolkit.core.common.mixed_precision.sensitivity_eval.distance_weighting import MpDistanceWeighting
|
29
29
|
|
@@ -76,7 +76,8 @@ def set_bit_widths(mixed_precision_enable: bool,
|
|
76
76
|
for n in graph.nodes:
|
77
77
|
assert len(n.candidates_quantization_cfg) == 1
|
78
78
|
n.final_weights_quantization_cfg = copy.deepcopy(n.candidates_quantization_cfg[0].weights_quantization_cfg)
|
79
|
-
|
79
|
+
if not n.is_quantization_preserving():
|
80
|
+
n.final_activation_quantization_cfg = copy.deepcopy(n.candidates_quantization_cfg[0].activation_quantization_cfg)
|
80
81
|
|
81
82
|
return graph
|
82
83
|
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py
CHANGED
@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
|
|
17
17
|
from enum import Enum
|
18
18
|
from typing import List, Callable, Optional
|
19
19
|
from model_compression_toolkit.constants import MP_DEFAULT_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
|
20
|
-
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
|
20
|
+
from model_compression_toolkit.core.common.mixed_precision.sensitivity_eval.distance_weighting import MpDistanceWeighting
|
21
21
|
|
22
22
|
|
23
23
|
class MpMetricNormalization(Enum):
|
@@ -14,10 +14,10 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
from enum import Enum
|
17
|
-
from typing import List, Callable
|
17
|
+
from typing import List, Callable
|
18
18
|
|
19
19
|
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
|
20
|
-
from model_compression_toolkit.core.common import Graph
|
20
|
+
from model_compression_toolkit.core.common import Graph
|
21
21
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
22
22
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
23
23
|
from model_compression_toolkit.core.common.hessian import HessianInfoService
|
@@ -25,7 +25,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_searc
|
|
25
25
|
MixedPrecisionSearchManager
|
26
26
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
|
27
27
|
ResourceUtilization
|
28
|
-
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
|
28
|
+
from model_compression_toolkit.core.common.mixed_precision.sensitivity_eval.sensitivity_evaluation import SensitivityEvaluation
|
29
29
|
from model_compression_toolkit.core.common.mixed_precision.solution_refinement_procedure import \
|
30
30
|
greedy_solution_refinement_procedure
|
31
31
|
|
@@ -79,14 +79,9 @@ def search_bit_width(graph: Graph,
|
|
79
79
|
|
80
80
|
# Set Sensitivity Evaluator for MP search. It should always work with the original MP graph,
|
81
81
|
# even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
|
82
|
-
se = SensitivityEvaluation(
|
83
|
-
|
84
|
-
|
85
|
-
representative_data_gen=representative_data_gen,
|
86
|
-
fw_info=fw_info,
|
87
|
-
fw_impl=fw_impl,
|
88
|
-
disable_activation_for_metric=disable_activation_for_metric,
|
89
|
-
hessian_info_service=hessian_info_service)
|
82
|
+
se = SensitivityEvaluation(graph, mp_config, representative_data_gen=representative_data_gen, fw_info=fw_info,
|
83
|
+
fw_impl=fw_impl, disable_activation_for_metric=disable_activation_for_metric,
|
84
|
+
hessian_info_service=hessian_info_service)
|
90
85
|
|
91
86
|
if search_method != BitWidthSearchMethod.INTEGER_PROGRAMMING:
|
92
87
|
raise NotImplementedError()
|
@@ -21,7 +21,7 @@ from collections import defaultdict
|
|
21
21
|
|
22
22
|
from tqdm import tqdm
|
23
23
|
|
24
|
-
from typing import Dict, List, Tuple, Optional
|
24
|
+
from typing import Dict, List, Tuple, Optional
|
25
25
|
|
26
26
|
import numpy as np
|
27
27
|
|
@@ -39,7 +39,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_ru_he
|
|
39
39
|
MixedPrecisionRUHelper
|
40
40
|
from model_compression_toolkit.core.common.mixed_precision.search_methods.linear_programming import \
|
41
41
|
MixedPrecisionIntegerLPSolver
|
42
|
-
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
|
42
|
+
from model_compression_toolkit.core.common.mixed_precision.sensitivity_eval.sensitivity_evaluation import SensitivityEvaluation
|
43
43
|
from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
|
44
44
|
from model_compression_toolkit.logger import Logger
|
45
45
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,40 +12,59 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
import contextlib
|
16
|
-
import copy
|
17
|
-
import itertools
|
18
|
-
|
19
15
|
import numpy as np
|
20
|
-
from typing import Callable, Any, List, Tuple
|
16
|
+
from typing import runtime_checkable, Protocol, Callable, Any, List, Tuple
|
21
17
|
|
22
|
-
from model_compression_toolkit.core import
|
18
|
+
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, FrameworkInfo
|
23
19
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
24
|
-
from model_compression_toolkit.core.common.
|
25
|
-
|
26
|
-
from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
|
27
|
-
from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence
|
20
|
+
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
|
21
|
+
HessianScoresGranularity
|
28
22
|
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
|
23
|
+
from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence
|
29
24
|
from model_compression_toolkit.logger import Logger
|
30
|
-
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode, \
|
31
|
-
HessianScoresGranularity, HessianInfoService
|
32
25
|
|
33
26
|
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
27
|
+
@runtime_checkable
|
28
|
+
class MetricCalculator(Protocol):
|
29
|
+
""" Abstract class for metric calculators. """
|
30
|
+
# all interest points (including graph outputs)
|
31
|
+
all_interest_points: list
|
32
|
+
|
33
|
+
def compute(self, mp_model) -> float:
|
34
|
+
""" Compute the metric for the given model. """
|
35
|
+
...
|
36
|
+
|
39
37
|
|
38
|
+
class CustomMetricCalculator(MetricCalculator):
|
39
|
+
""" Calculate metric with custom function applied on graph outputs. """
|
40
|
+
|
41
|
+
def __init__(self, graph: Graph, custom_metric_fn: Callable):
|
42
|
+
"""
|
43
|
+
Args:
|
44
|
+
graph: input graph.
|
45
|
+
custom_metric_fn: custom metric function, that accepts the model as input and return float scalar metric.
|
46
|
+
"""
|
47
|
+
self.all_interest_points = [n.node for n in graph.get_outputs()]
|
48
|
+
self.metric_fn = custom_metric_fn
|
49
|
+
|
50
|
+
def compute(self, mp_model: Any) -> float:
|
51
|
+
""" Compute the metric for the given model. """
|
52
|
+
sensitivity_metric = self.metric_fn(mp_model)
|
53
|
+
if not isinstance(sensitivity_metric, (float, np.floating)):
|
54
|
+
raise TypeError(
|
55
|
+
f'The custom_metric_fn is expected to return float or numpy float, got {type(sensitivity_metric).__name__}')
|
56
|
+
return sensitivity_metric
|
57
|
+
|
58
|
+
|
59
|
+
class DistanceMetricCalculator(MetricCalculator):
|
60
|
+
""" Calculator for distance-based metrics. """
|
40
61
|
def __init__(self,
|
41
62
|
graph: Graph,
|
42
|
-
|
63
|
+
mp_config: MixedPrecisionQuantizationConfig,
|
43
64
|
representative_data_gen: Callable,
|
44
65
|
fw_info: FrameworkInfo,
|
45
66
|
fw_impl: Any,
|
46
|
-
|
47
|
-
hessian_info_service: HessianInfoService = None
|
48
|
-
):
|
67
|
+
hessian_info_service: HessianInfoService = None):
|
49
68
|
"""
|
50
69
|
Initiates all relevant objects to manage a sensitivity evaluation for MP search.
|
51
70
|
Create an object that allows to compute the sensitivity metric of an MP model (the sensitivity
|
@@ -59,23 +78,21 @@ class SensitivityEvaluation:
|
|
59
78
|
|
60
79
|
Args:
|
61
80
|
graph: Graph to search for its MP configuration.
|
81
|
+
mp_config: MP Quantization configuration for how the graph should be quantized.
|
62
82
|
fw_info: FrameworkInfo object about the specific framework
|
63
83
|
(e.g., attributes of different layers' weights to quantize).
|
64
|
-
quant_config: MP Quantization configuration for how the graph should be quantized.
|
65
|
-
representative_data_gen: Dataset used for getting batches for inference.
|
66
84
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
67
|
-
|
85
|
+
representative_data_gen: Dataset used for getting batches for inference.
|
68
86
|
hessian_info_service: HessianInfoService to fetch Hessian approximation information.
|
69
|
-
|
70
87
|
"""
|
71
88
|
self.graph = graph
|
72
|
-
self.
|
89
|
+
self.mp_config = mp_config
|
73
90
|
self.representative_data_gen = representative_data_gen
|
74
91
|
self.fw_info = fw_info
|
75
92
|
self.fw_impl = fw_impl
|
76
|
-
|
77
|
-
if self.
|
78
|
-
if not isinstance(hessian_info_service, HessianInfoService):
|
93
|
+
|
94
|
+
if self.mp_config.use_hessian_based_scores:
|
95
|
+
if not isinstance(hessian_info_service, HessianInfoService): # pragma: no cover
|
79
96
|
Logger.critical(
|
80
97
|
f"When using Hessian-based approximations for sensitivity evaluation, a valid HessianInfoService object is required; found {type(hessian_info_service)}.")
|
81
98
|
self.hessian_info_service = hessian_info_service
|
@@ -84,42 +101,35 @@ class SensitivityEvaluation:
|
|
84
101
|
|
85
102
|
# Get interest points and output points set for distance measurement and set other helper datasets
|
86
103
|
# We define a separate set of output nodes of the model for the purpose of sensitivity computation.
|
87
|
-
self.interest_points = get_mp_interest_points(graph,
|
88
|
-
|
89
|
-
|
90
|
-
# If using a custom metric - return only model outputs
|
91
|
-
if self.quant_config.custom_metric_fn is not None:
|
92
|
-
self.interest_points = []
|
104
|
+
self.interest_points = self.get_mp_interest_points(graph,
|
105
|
+
fw_impl.count_node_for_mixed_precision_interest_points,
|
106
|
+
mp_config.num_interest_points_factor)
|
93
107
|
|
94
108
|
# We use normalized MSE when not running hessian-based. For Hessian-based normalized MSE is not needed
|
95
109
|
# because hessian weights already do normalization.
|
96
|
-
use_normalized_mse = self.
|
97
|
-
self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points,
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(self.output_points,
|
110
|
+
use_normalized_mse = self.mp_config.use_hessian_based_scores is False
|
111
|
+
self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points,
|
112
|
+
use_normalized_mse)
|
113
|
+
|
114
|
+
output_points = self.get_output_nodes_for_metric(graph)
|
115
|
+
self.all_interest_points = self.interest_points + output_points
|
116
|
+
self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(output_points,
|
104
117
|
use_normalized_mse)
|
105
118
|
|
119
|
+
self.ref_model, _ = fw_impl.model_builder(graph, mode=ModelBuilderMode.FLOAT,
|
120
|
+
append2output=self.all_interest_points)
|
121
|
+
|
106
122
|
# Setting lists with relative position of the interest points
|
107
123
|
# and output points in the list of all mp model activation tensors
|
108
124
|
graph_sorted_nodes = self.graph.get_topo_sorted_nodes()
|
109
|
-
all_out_tensors_indices = [graph_sorted_nodes.index(n) for n in self.
|
125
|
+
all_out_tensors_indices = [graph_sorted_nodes.index(n) for n in self.all_interest_points]
|
110
126
|
global_ipts_indices = [graph_sorted_nodes.index(n) for n in self.interest_points]
|
111
|
-
global_out_pts_indices = [graph_sorted_nodes.index(n) for n in
|
127
|
+
global_out_pts_indices = [graph_sorted_nodes.index(n) for n in output_points]
|
112
128
|
self.ips_act_indices = [all_out_tensors_indices.index(i) for i in global_ipts_indices]
|
113
129
|
self.out_ps_act_indices = [all_out_tensors_indices.index(i) for i in global_out_pts_indices]
|
114
130
|
|
115
|
-
# Build a mixed-precision model which can be configured to use different bitwidth in different layers.
|
116
|
-
# And a baseline model.
|
117
|
-
# Also, returns a mapping between a configurable graph's node and its matching layer(s)
|
118
|
-
# in the new built MP model.
|
119
|
-
self.baseline_model, self.model_mp, self.conf_node2layers = self._build_models()
|
120
|
-
|
121
131
|
# Build images batches for inference comparison and cat to framework type
|
122
|
-
images_batches = self._get_images_batches(
|
132
|
+
images_batches = self._get_images_batches(mp_config.num_of_images)
|
123
133
|
self.images_batches = [self.fw_impl.to_tensor(img) for img in images_batches]
|
124
134
|
|
125
135
|
# Initiating baseline_tensors_list since it is not initiated in SensitivityEvaluationManager init.
|
@@ -128,12 +138,28 @@ class SensitivityEvaluation:
|
|
128
138
|
# Computing Hessian-based scores for weighted average distance metric computation (only if requested),
|
129
139
|
# and assigning distance_weighting method accordingly.
|
130
140
|
self.interest_points_hessians = None
|
131
|
-
if self.
|
141
|
+
if self.mp_config.use_hessian_based_scores is True:
|
132
142
|
self.interest_points_hessians = self._compute_hessian_based_scores()
|
133
|
-
self.
|
143
|
+
self.mp_config.distance_weighting_method = lambda d: self.interest_points_hessians
|
134
144
|
|
135
|
-
def
|
136
|
-
|
145
|
+
def compute(self, mp_model) -> float:
|
146
|
+
"""
|
147
|
+
Compute the metric for the given model.
|
148
|
+
|
149
|
+
Args:
|
150
|
+
mp_model: MP configured model.
|
151
|
+
|
152
|
+
Returns:
|
153
|
+
Computed metric.
|
154
|
+
"""
|
155
|
+
ipts_distances, out_pts_distances = self._compute_distance(mp_model)
|
156
|
+
sensitivity_metric = self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
|
157
|
+
self.mp_config.distance_weighting_method)
|
158
|
+
return sensitivity_metric
|
159
|
+
|
160
|
+
def _init_metric_points_lists(self,
|
161
|
+
points: List[BaseNode],
|
162
|
+
norm_mse: bool = False) -> Tuple[List[Callable], List[int]]:
|
137
163
|
"""
|
138
164
|
Initiates required lists for future use when computing the sensitivity metric.
|
139
165
|
Each point on which the metric is computed uses a dedicated distance function based on its type.
|
@@ -150,101 +176,19 @@ class SensitivityEvaluation:
|
|
150
176
|
axis_list = []
|
151
177
|
for n in points:
|
152
178
|
distance_fn, axis = self.fw_impl.get_mp_node_distance_fn(n,
|
153
|
-
compute_distance_fn=self.
|
179
|
+
compute_distance_fn=self.mp_config.compute_distance_fn,
|
154
180
|
norm_mse=norm_mse)
|
155
181
|
distance_fns_list.append(distance_fn)
|
156
182
|
# Axis is needed only for KL Divergence calculation, otherwise we use per-tensor computation
|
157
183
|
axis_list.append(axis if distance_fn == compute_kl_divergence else None)
|
158
184
|
return distance_fns_list, axis_list
|
159
185
|
|
160
|
-
def compute_metric(self, mp_a_cfg: Dict[str, Optional[int]], mp_w_cfg: Dict[str, Optional[int]]) -> float:
|
161
|
-
"""
|
162
|
-
Compute the sensitivity metric of the MP model for a given configuration (the sensitivity
|
163
|
-
is computed based on the similarity of the interest points' outputs between the MP model
|
164
|
-
and the float model or a custom metric if given).
|
165
|
-
Quantization for any configurable activation / weight that were not passed is disabled.
|
166
|
-
|
167
|
-
Args:
|
168
|
-
mp_a_cfg: Bitwidth activations configuration for the MP model.
|
169
|
-
mp_w_cfg: Bitwidth weights configuration for the MP model.
|
170
|
-
|
171
|
-
Returns:
|
172
|
-
The sensitivity metric of the MP model for a given configuration.
|
173
|
-
"""
|
174
|
-
|
175
|
-
with self._configured_mp_model(mp_a_cfg, mp_w_cfg):
|
176
|
-
sensitivity_metric = self._compute_metric()
|
177
|
-
|
178
|
-
return sensitivity_metric
|
179
|
-
|
180
|
-
def _compute_metric(self) -> float:
|
181
|
-
"""
|
182
|
-
Compute sensitivity metric on a configured mp model.
|
183
|
-
|
184
|
-
Returns:
|
185
|
-
Sensitivity metric.
|
186
|
-
"""
|
187
|
-
if self.quant_config.custom_metric_fn:
|
188
|
-
sensitivity_metric = self.quant_config.custom_metric_fn(self.model_mp)
|
189
|
-
if not isinstance(sensitivity_metric, (float, np.floating)):
|
190
|
-
raise TypeError(
|
191
|
-
f'The custom_metric_fn is expected to return float or numpy float, got {type(sensitivity_metric).__name__}')
|
192
|
-
return sensitivity_metric
|
193
|
-
|
194
|
-
# compute default metric
|
195
|
-
ipts_distances, out_pts_distances = self._compute_distance()
|
196
|
-
sensitivity_metric = self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
|
197
|
-
self.quant_config.distance_weighting_method)
|
198
|
-
return sensitivity_metric
|
199
|
-
|
200
186
|
def _init_baseline_tensors_list(self):
|
201
187
|
"""
|
202
188
|
Evaluates the baseline model on all images and returns the obtained lists of tensors in a list for later use.
|
203
189
|
"""
|
204
|
-
return [self.fw_impl.to_numpy(self.fw_impl.sensitivity_eval_inference(self.
|
205
|
-
|
206
|
-
|
207
|
-
def _build_models(self) -> Any:
|
208
|
-
"""
|
209
|
-
Builds two models - an MP model with configurable layers and a baseline, float model.
|
210
|
-
|
211
|
-
Returns: A tuple with two models built from the given graph: a baseline model (with baseline configuration) and
|
212
|
-
an MP model (which can be configured for a specific bitwidth configuration).
|
213
|
-
Note that the type of the returned models is dependent on the used framework (TF/Pytorch).
|
214
|
-
"""
|
215
|
-
|
216
|
-
evaluation_graph = copy.deepcopy(self.graph)
|
217
|
-
|
218
|
-
# Disable quantization for non-configurable nodes, and, if requested, for all activations (quantizers won't
|
219
|
-
# be added to the model).
|
220
|
-
for n in evaluation_graph.get_topo_sorted_nodes():
|
221
|
-
if self.disable_activation_for_metric or not n.has_configurable_activation():
|
222
|
-
for c in n.candidates_quantization_cfg:
|
223
|
-
c.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
|
224
|
-
if not n.has_any_configurable_weight():
|
225
|
-
for c in n.candidates_quantization_cfg:
|
226
|
-
c.weights_quantization_cfg.disable_all_weights_quantization()
|
227
|
-
|
228
|
-
model_mp, _, conf_node2layers = self.fw_impl.model_builder(evaluation_graph,
|
229
|
-
mode=ModelBuilderMode.MIXEDPRECISION,
|
230
|
-
append2output=self.interest_points + self.output_points,
|
231
|
-
fw_info=self.fw_info)
|
232
|
-
|
233
|
-
# Disable all configurable quantizers. They will be activated one at a time during sensitivity evaluation.
|
234
|
-
# Note: from this point mp_model is not in sync with graph quantization configuration for configurable nodes.
|
235
|
-
for layer in itertools.chain(*conf_node2layers.values()):
|
236
|
-
if isinstance(layer, self.fw_impl.activation_quant_layer_cls):
|
237
|
-
set_activation_quant_layer_to_bitwidth(layer, None, self.fw_impl)
|
238
|
-
else:
|
239
|
-
assert isinstance(layer, self.fw_impl.weights_quant_layer_cls)
|
240
|
-
set_weights_quant_layer_to_bitwidth(layer, None, self.fw_impl)
|
241
|
-
|
242
|
-
# Build a baseline model (to compute distances from).
|
243
|
-
baseline_model, _ = self.fw_impl.model_builder(evaluation_graph,
|
244
|
-
mode=ModelBuilderMode.FLOAT,
|
245
|
-
append2output=self.interest_points + self.output_points)
|
246
|
-
|
247
|
-
return baseline_model, model_mp, conf_node2layers
|
190
|
+
return [self.fw_impl.to_numpy(self.fw_impl.sensitivity_eval_inference(self.ref_model, images))
|
191
|
+
for images in self.images_batches]
|
248
192
|
|
249
193
|
def _compute_hessian_based_scores(self) -> np.ndarray:
|
250
194
|
"""
|
@@ -257,61 +201,21 @@ class SensitivityEvaluation:
|
|
257
201
|
# Create a request for Hessian approximation scores with specific configurations
|
258
202
|
# (here we use per-tensor approximation of the Hessian's trace w.r.t the node's activations)
|
259
203
|
fw_dataloader = self.fw_impl.convert_data_gen_to_dataloader(self.representative_data_gen,
|
260
|
-
batch_size=self.
|
204
|
+
batch_size=self.mp_config.hessian_batch_size)
|
261
205
|
hessian_info_request = HessianScoresRequest(mode=HessianMode.ACTIVATION,
|
262
206
|
granularity=HessianScoresGranularity.PER_TENSOR,
|
263
207
|
target_nodes=self.interest_points,
|
264
208
|
data_loader=fw_dataloader,
|
265
|
-
n_samples=self.
|
209
|
+
n_samples=self.mp_config.num_of_images)
|
266
210
|
|
267
211
|
# Fetch the Hessian approximation scores for the current interest point
|
268
212
|
nodes_approximations = self.hessian_info_service.fetch_hessian(request=hessian_info_request)
|
269
|
-
approx_by_image = np.stack([nodes_approximations[n.name] for n in self.interest_points],
|
213
|
+
approx_by_image = np.stack([nodes_approximations[n.name] for n in self.interest_points],
|
214
|
+
axis=1) # samples X nodes
|
270
215
|
|
271
216
|
# Return the mean approximation value across all images for each interest point
|
272
217
|
return np.mean(approx_by_image, axis=0)
|
273
218
|
|
274
|
-
@contextlib.contextmanager
|
275
|
-
def _configured_mp_model(self, mp_a_cfg: Dict[str, Optional[int]], mp_w_cfg: Dict[str, Optional[int]]):
|
276
|
-
"""
|
277
|
-
Context manager to configure specific configurable layers of the mp model. At exit, configuration is
|
278
|
-
automatically restored to un-quantized.
|
279
|
-
|
280
|
-
Args:
|
281
|
-
mp_a_cfg: Nodes bitwidth indices to configure activation quantizers to.
|
282
|
-
mp_w_cfg: Nodes bitwidth indices to configure weights quantizers to.
|
283
|
-
|
284
|
-
"""
|
285
|
-
if not (mp_a_cfg and any(v is not None for v in mp_a_cfg.values()) or
|
286
|
-
mp_w_cfg and any(v is not None for v in mp_w_cfg.values())):
|
287
|
-
raise ValueError(f'Requested configuration is either empty or contain only None values.')
|
288
|
-
|
289
|
-
# defined here so that it can't be used directly
|
290
|
-
def apply_bitwidth_config(a_cfg, w_cfg):
|
291
|
-
node_names = set(a_cfg.keys()).union(set(w_cfg.keys()))
|
292
|
-
for n in node_names:
|
293
|
-
node_quant_layers = self.conf_node2layers.get(n)
|
294
|
-
if node_quant_layers is None: # pragma: no cover
|
295
|
-
raise ValueError(f"Matching layers for node {n} not found in the mixed precision model configuration.")
|
296
|
-
for qlayer in node_quant_layers:
|
297
|
-
assert isinstance(qlayer, (self.fw_impl.activation_quant_layer_cls,
|
298
|
-
self.fw_impl.weights_quant_layer_cls)), f'Unexpected {type(qlayer)} of node {n}'
|
299
|
-
if isinstance(qlayer, self.fw_impl.activation_quant_layer_cls) and n in a_cfg:
|
300
|
-
set_activation_quant_layer_to_bitwidth(qlayer, a_cfg[n], self.fw_impl)
|
301
|
-
a_cfg.pop(n)
|
302
|
-
elif isinstance(qlayer, self.fw_impl.weights_quant_layer_cls) and n in w_cfg:
|
303
|
-
set_weights_quant_layer_to_bitwidth(qlayer, w_cfg[n], self.fw_impl)
|
304
|
-
w_cfg.pop(n)
|
305
|
-
if a_cfg or w_cfg:
|
306
|
-
raise ValueError(f'Not all mp configs were consumed, remaining activation config {a_cfg}, '
|
307
|
-
f'weights config {w_cfg}.')
|
308
|
-
|
309
|
-
apply_bitwidth_config(mp_a_cfg.copy(), mp_w_cfg.copy())
|
310
|
-
try:
|
311
|
-
yield
|
312
|
-
finally:
|
313
|
-
apply_bitwidth_config({n: None for n in mp_a_cfg}, {n: None for n in mp_w_cfg})
|
314
|
-
|
315
219
|
def _compute_points_distance(self,
|
316
220
|
baseline_tensors: List[Any],
|
317
221
|
mp_tensors: List[Any],
|
@@ -338,7 +242,7 @@ class SensitivityEvaluation:
|
|
338
242
|
|
339
243
|
return np.asarray(distance_v)
|
340
244
|
|
341
|
-
def _compute_distance(self) -> Tuple[np.ndarray, np.ndarray]:
|
245
|
+
def _compute_distance(self, mp_model) -> Tuple[np.ndarray, np.ndarray]:
|
342
246
|
"""
|
343
247
|
Computing the interest points distance and the output points distance, and using them to build a
|
344
248
|
unified distance vector.
|
@@ -352,7 +256,7 @@ class SensitivityEvaluation:
|
|
352
256
|
# Compute the distance matrix for num_of_images images.
|
353
257
|
for images, baseline_tensors in zip(self.images_batches, self.baseline_tensors_list):
|
354
258
|
# when using model.predict(), it does not use the QuantizeWrapper functionality
|
355
|
-
mp_tensors = self.fw_impl.sensitivity_eval_inference(
|
259
|
+
mp_tensors = self.fw_impl.sensitivity_eval_inference(mp_model, images)
|
356
260
|
mp_tensors = self.fw_impl.to_numpy(mp_tensors)
|
357
261
|
|
358
262
|
# Compute distance: similarity between the baseline model to the float model
|
@@ -440,77 +344,78 @@ class SensitivityEvaluation:
|
|
440
344
|
samples_count += batch_size
|
441
345
|
else:
|
442
346
|
if samples_count < num_of_images:
|
443
|
-
Logger.warning(
|
444
|
-
|
347
|
+
Logger.warning(
|
348
|
+
f'Not enough images in representative dataset to generate {num_of_images} data points, '
|
349
|
+
f'only {samples_count} were generated')
|
445
350
|
return images_batches
|
446
351
|
|
352
|
+
@classmethod
|
353
|
+
def get_mp_interest_points(cls, graph: Graph,
|
354
|
+
interest_points_classifier: Callable,
|
355
|
+
num_ip_factor: float) -> List[BaseNode]:
|
356
|
+
"""
|
357
|
+
Gets a list of interest points for the mixed precision metric computation.
|
358
|
+
The list is constructed from a filtered set of nodes in the graph.
|
359
|
+
Note that the output layers are separated from the interest point set for metric computation purposes.
|
447
360
|
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
The list is constructed from a filtered set of nodes in the graph.
|
454
|
-
Note that the output layers are separated from the interest point set for metric computation purposes.
|
455
|
-
|
456
|
-
Args:
|
457
|
-
graph: Graph to search for its MP configuration.
|
458
|
-
interest_points_classifier: A function that indicates whether a given node in considered as a potential
|
459
|
-
interest point for mp metric computation purposes.
|
460
|
-
num_ip_factor: Percentage out of the total set of interest points that we want to actually use.
|
461
|
-
|
462
|
-
Returns: A list of interest points (nodes in the graph).
|
463
|
-
|
464
|
-
"""
|
465
|
-
sorted_nodes = graph.get_topo_sorted_nodes()
|
466
|
-
ip_nodes = [n for n in sorted_nodes if interest_points_classifier(n)]
|
467
|
-
|
468
|
-
interest_points_nodes = bound_num_interest_points(ip_nodes, num_ip_factor)
|
361
|
+
Args:
|
362
|
+
graph: Graph to search for its MP configuration.
|
363
|
+
interest_points_classifier: A function that indicates whether a given node in considered as a potential
|
364
|
+
interest point for mp metric computation purposes.
|
365
|
+
num_ip_factor: Percentage out of the total set of interest points that we want to actually use.
|
469
366
|
|
470
|
-
|
471
|
-
output_nodes = [n.node for n in graph.get_outputs()]
|
367
|
+
Returns: A list of interest points (nodes in the graph).
|
472
368
|
|
473
|
-
|
369
|
+
"""
|
370
|
+
sorted_nodes = graph.get_topo_sorted_nodes()
|
371
|
+
ip_nodes = [n for n in sorted_nodes if interest_points_classifier(n)]
|
474
372
|
|
475
|
-
|
373
|
+
interest_points_nodes = cls.bound_num_interest_points(ip_nodes, num_ip_factor)
|
476
374
|
|
375
|
+
# We exclude output nodes from the set of interest points since they are used separately in the sensitivity evaluation.
|
376
|
+
output_nodes = [n.node for n in graph.get_outputs()]
|
477
377
|
|
478
|
-
|
479
|
-
"""
|
480
|
-
Returns a list of output nodes that are also quantized (either kernel weights attribute or activation)
|
481
|
-
to be used as a set of output points in the distance metric computation.
|
378
|
+
interest_points = [n for n in interest_points_nodes if n not in output_nodes]
|
482
379
|
|
483
|
-
|
484
|
-
graph: Graph to search for its MP configuration.
|
380
|
+
return interest_points
|
485
381
|
|
486
|
-
|
382
|
+
@staticmethod
|
383
|
+
def get_output_nodes_for_metric(graph: Graph) -> List[BaseNode]:
|
384
|
+
"""
|
385
|
+
Returns a list of output nodes that are also quantized (either kernel weights attribute or activation)
|
386
|
+
to be used as a set of output points in the distance metric computation.
|
487
387
|
|
488
|
-
|
388
|
+
Args:
|
389
|
+
graph: Graph to search for its MP configuration.
|
489
390
|
|
490
|
-
|
491
|
-
if (graph.fw_info.is_kernel_op(n.node.type) and
|
492
|
-
n.node.is_weights_quantization_enabled(graph.fw_info.get_kernel_op_attributes(n.node.type)[0])) or
|
493
|
-
n.node.is_activation_quantization_enabled()]
|
391
|
+
Returns: A list of output nodes.
|
494
392
|
|
393
|
+
"""
|
495
394
|
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
395
|
+
return [n.node for n in graph.get_outputs()
|
396
|
+
if (graph.fw_info.is_kernel_op(n.node.type) and
|
397
|
+
n.node.is_weights_quantization_enabled(graph.fw_info.get_kernel_op_attributes(n.node.type)[0])) or
|
398
|
+
n.node.is_activation_quantization_enabled()]
|
500
399
|
|
501
|
-
|
502
|
-
|
503
|
-
|
400
|
+
@staticmethod
|
401
|
+
def bound_num_interest_points(sorted_ip_list: List[BaseNode], num_ip_factor: float) -> List[BaseNode]:
|
402
|
+
"""
|
403
|
+
Filters the list of interest points and returns a shorter list with number of interest points smaller than some
|
404
|
+
default threshold.
|
504
405
|
|
505
|
-
|
406
|
+
Args:
|
407
|
+
sorted_ip_list: List of nodes which are considered as interest points for the metric computation.
|
408
|
+
num_ip_factor: Percentage out of the total set of interest points that we want to actually use.
|
506
409
|
|
507
|
-
|
508
|
-
if num_ip_factor < 1.0:
|
509
|
-
num_interest_points = int(num_ip_factor * len(sorted_ip_list))
|
510
|
-
Logger.info(f'Using {num_interest_points} for mixed-precision metric evaluation out of total '
|
511
|
-
f'{len(sorted_ip_list)} potential interest points.')
|
512
|
-
# Take num_interest_points evenly spaced interest points from the original list
|
513
|
-
indices = np.round(np.linspace(0, len(sorted_ip_list) - 1, num_interest_points)).astype(int)
|
514
|
-
return [sorted_ip_list[i] for i in indices]
|
410
|
+
Returns: A new list of interest points (list of nodes).
|
515
411
|
|
516
|
-
|
412
|
+
"""
|
413
|
+
if num_ip_factor < 1.0:
|
414
|
+
num_interest_points = int(num_ip_factor * len(sorted_ip_list))
|
415
|
+
Logger.info(f'Using {num_interest_points} for mixed-precision metric evaluation out of total '
|
416
|
+
f'{len(sorted_ip_list)} potential interest points.')
|
417
|
+
# Take num_interest_points evenly spaced interest points from the original list
|
418
|
+
indices = np.round(np.linspace(0, len(sorted_ip_list) - 1, num_interest_points)).astype(int)
|
419
|
+
return [sorted_ip_list[i] for i in indices]
|
420
|
+
|
421
|
+
return sorted_ip_list
|
model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py
ADDED
@@ -0,0 +1,168 @@
|
|
1
|
+
# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
import contextlib
|
16
|
+
import copy
|
17
|
+
import itertools
|
18
|
+
|
19
|
+
from typing import Callable, Any, Tuple, Dict, Optional
|
20
|
+
|
21
|
+
from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig
|
22
|
+
from model_compression_toolkit.core.common import Graph
|
23
|
+
from model_compression_toolkit.core.common.mixed_precision.sensitivity_eval.metric_calculators import \
|
24
|
+
CustomMetricCalculator, DistanceMetricCalculator
|
25
|
+
from model_compression_toolkit.core.common.mixed_precision.sensitivity_eval.set_layer_to_bitwidth import \
|
26
|
+
set_activation_quant_layer_to_bitwidth, set_weights_quant_layer_to_bitwidth
|
27
|
+
from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
|
28
|
+
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
|
29
|
+
from model_compression_toolkit.core.common.hessian import HessianInfoService
|
30
|
+
|
31
|
+
|
32
|
+
class SensitivityEvaluation:
|
33
|
+
"""
|
34
|
+
Sensitivity evaluation of a bit-width configuration for Mixed Precision search.
|
35
|
+
"""
|
36
|
+
|
37
|
+
def __init__(self,
|
38
|
+
graph: Graph,
|
39
|
+
mp_config: MixedPrecisionQuantizationConfig,
|
40
|
+
representative_data_gen: Callable,
|
41
|
+
fw_info: FrameworkInfo,
|
42
|
+
fw_impl: Any,
|
43
|
+
disable_activation_for_metric: bool = False,
|
44
|
+
hessian_info_service: HessianInfoService = None
|
45
|
+
):
|
46
|
+
"""
|
47
|
+
Args:
|
48
|
+
graph: Graph to search for its MP configuration.
|
49
|
+
fw_info: FrameworkInfo object about the specific framework
|
50
|
+
(e.g., attributes of different layers' weights to quantize).
|
51
|
+
mp_config: MP Quantization configuration for how the graph should be quantized.
|
52
|
+
representative_data_gen: Dataset used for getting batches for inference.
|
53
|
+
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
54
|
+
disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
|
55
|
+
hessian_info_service: HessianInfoService to fetch Hessian approximation information.
|
56
|
+
|
57
|
+
"""
|
58
|
+
self.mp_config = mp_config
|
59
|
+
self.representative_data_gen = representative_data_gen
|
60
|
+
self.fw_info = fw_info
|
61
|
+
self.fw_impl = fw_impl
|
62
|
+
|
63
|
+
if self.mp_config.custom_metric_fn:
|
64
|
+
self.metric_calculator = CustomMetricCalculator(graph, self.mp_config.custom_metric_fn)
|
65
|
+
else:
|
66
|
+
self.metric_calculator = DistanceMetricCalculator(graph, mp_config, representative_data_gen,
|
67
|
+
fw_info=fw_info, fw_impl=fw_impl,
|
68
|
+
hessian_info_service=hessian_info_service)
|
69
|
+
|
70
|
+
# Build a mixed-precision model which can be configured to use different bitwidth in different layers.
|
71
|
+
# Also, returns a mapping between a configurable graph's node and its matching layer(s) in the built MP model.
|
72
|
+
self.mp_model, self.conf_node2layers = self._build_mp_model(graph, self.metric_calculator.all_interest_points,
|
73
|
+
disable_activation_for_metric)
|
74
|
+
|
75
|
+
def compute_metric(self, mp_a_cfg: Dict[str, Optional[int]], mp_w_cfg: Dict[str, Optional[int]]) -> float:
|
76
|
+
"""
|
77
|
+
Compute the sensitivity metric of the MP model for a given configuration.
|
78
|
+
Quantization for any configurable activation / weight that were not passed is disabled.
|
79
|
+
|
80
|
+
Args:
|
81
|
+
mp_a_cfg: Bitwidth activations configuration for the MP model.
|
82
|
+
mp_w_cfg: Bitwidth weights configuration for the MP model.
|
83
|
+
|
84
|
+
Returns:
|
85
|
+
The sensitivity metric of the MP model for a given configuration.
|
86
|
+
"""
|
87
|
+
with self._configured_mp_model(mp_a_cfg, mp_w_cfg):
|
88
|
+
sensitivity_metric = self.metric_calculator.compute(self.mp_model)
|
89
|
+
|
90
|
+
return sensitivity_metric
|
91
|
+
|
92
|
+
def _build_mp_model(self, graph, outputs, disable_activations: bool) -> Tuple[Any, dict]:
|
93
|
+
"""
|
94
|
+
Builds an MP model with configurable layers.
|
95
|
+
|
96
|
+
Returns:
|
97
|
+
MP model and a mapping from configurable graph nodes to their corresponding quantization layer(s)
|
98
|
+
in the MP model.
|
99
|
+
"""
|
100
|
+
evaluation_graph = copy.deepcopy(graph)
|
101
|
+
|
102
|
+
# Disable quantization for non-configurable nodes, and, if requested, for all activations (quantizers won't
|
103
|
+
# be added to the model).
|
104
|
+
for n in evaluation_graph.get_topo_sorted_nodes():
|
105
|
+
if disable_activations or not n.has_configurable_activation():
|
106
|
+
for c in n.candidates_quantization_cfg:
|
107
|
+
c.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
|
108
|
+
if not n.has_any_configurable_weight():
|
109
|
+
for c in n.candidates_quantization_cfg:
|
110
|
+
c.weights_quantization_cfg.disable_all_weights_quantization()
|
111
|
+
|
112
|
+
model_mp, _, conf_node2layers = self.fw_impl.model_builder(evaluation_graph,
|
113
|
+
mode=ModelBuilderMode.MIXEDPRECISION,
|
114
|
+
append2output=outputs,
|
115
|
+
fw_info=self.fw_info)
|
116
|
+
|
117
|
+
# Disable all configurable quantizers. They will be activated one at a time during sensitivity evaluation.
|
118
|
+
for layer in itertools.chain(*conf_node2layers.values()):
|
119
|
+
if isinstance(layer, self.fw_impl.activation_quant_layer_cls):
|
120
|
+
set_activation_quant_layer_to_bitwidth(layer, None, self.fw_impl)
|
121
|
+
else:
|
122
|
+
assert isinstance(layer, self.fw_impl.weights_quant_layer_cls)
|
123
|
+
set_weights_quant_layer_to_bitwidth(layer, None, self.fw_impl)
|
124
|
+
|
125
|
+
return model_mp, conf_node2layers
|
126
|
+
|
127
|
+
@contextlib.contextmanager
|
128
|
+
def _configured_mp_model(self, mp_a_cfg: Dict[str, Optional[int]], mp_w_cfg: Dict[str, Optional[int]]):
|
129
|
+
"""
|
130
|
+
Context manager to configure specific configurable layers of the mp model. At exit, configuration is
|
131
|
+
automatically restored to un-quantized.
|
132
|
+
|
133
|
+
Args:
|
134
|
+
mp_a_cfg: Nodes bitwidth indices to configure activation quantizers to.
|
135
|
+
mp_w_cfg: Nodes bitwidth indices to configure weights quantizers to.
|
136
|
+
|
137
|
+
"""
|
138
|
+
if not (mp_a_cfg and any(v is not None for v in mp_a_cfg.values()) or
|
139
|
+
mp_w_cfg and any(v is not None for v in mp_w_cfg.values())):
|
140
|
+
raise ValueError(f'Requested configuration is either empty or contain only None values.')
|
141
|
+
|
142
|
+
# defined here so that it can't be used directly
|
143
|
+
def apply_bitwidth_config(a_cfg, w_cfg):
|
144
|
+
node_names = set(a_cfg.keys()).union(set(w_cfg.keys()))
|
145
|
+
for n in node_names:
|
146
|
+
node_quant_layers = self.conf_node2layers.get(n)
|
147
|
+
if node_quant_layers is None: # pragma: no cover
|
148
|
+
raise ValueError(f"Matching layers for node {n} not found in the mixed precision model configuration.")
|
149
|
+
for qlayer in node_quant_layers:
|
150
|
+
assert isinstance(qlayer, (self.fw_impl.activation_quant_layer_cls,
|
151
|
+
self.fw_impl.weights_quant_layer_cls)), f'Unexpected {type(qlayer)} of node {n}'
|
152
|
+
if isinstance(qlayer, self.fw_impl.activation_quant_layer_cls) and n in a_cfg:
|
153
|
+
set_activation_quant_layer_to_bitwidth(qlayer, a_cfg[n], self.fw_impl)
|
154
|
+
a_cfg.pop(n)
|
155
|
+
elif isinstance(qlayer, self.fw_impl.weights_quant_layer_cls) and n in w_cfg:
|
156
|
+
set_weights_quant_layer_to_bitwidth(qlayer, w_cfg[n], self.fw_impl)
|
157
|
+
w_cfg.pop(n)
|
158
|
+
if a_cfg or w_cfg:
|
159
|
+
raise ValueError(f'Not all mp configs were consumed, remaining activation config {a_cfg}, '
|
160
|
+
f'weights config {w_cfg}.')
|
161
|
+
|
162
|
+
apply_bitwidth_config(mp_a_cfg.copy(), mp_w_cfg.copy())
|
163
|
+
try:
|
164
|
+
yield
|
165
|
+
finally:
|
166
|
+
apply_bitwidth_config({n: None for n in mp_a_cfg}, {n: None for n in mp_w_cfg})
|
167
|
+
|
168
|
+
|
@@ -42,13 +42,9 @@ def get_previous_node_with_activation_quantization(linear_node: BaseNode,
|
|
42
42
|
|
43
43
|
prev_node = prev_nodes[0]
|
44
44
|
|
45
|
-
|
45
|
+
prev_quant_node = graph.retrieve_preserved_quantization_node(prev_node)
|
46
46
|
|
47
|
-
|
48
|
-
if activation_quantization_config.enable_activation_quantization:
|
49
|
-
return prev_node
|
50
|
-
else:
|
51
|
-
return get_previous_node_with_activation_quantization(prev_node, graph)
|
47
|
+
return prev_quant_node if prev_quant_node.is_activation_quantization_enabled() else None
|
52
48
|
|
53
49
|
|
54
50
|
def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray:
|
File without changes
|
File without changes
|
{mct_nightly-2.3.0.20250525.629.dist-info → mct_nightly-2.3.0.20250527.555.dist-info}/top_level.txt
RENAMED
File without changes
|
File without changes
|