mct-nightly 2.3.0.20250526.601__py3-none-any.whl → 2.3.0.20250527.555__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (18) hide show
  1. {mct_nightly-2.3.0.20250526.601.dist-info → mct_nightly-2.3.0.20250527.555.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.3.0.20250526.601.dist-info → mct_nightly-2.3.0.20250527.555.dist-info}/RECORD +18 -16
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/__init__.py +1 -1
  5. model_compression_toolkit/core/common/mixed_precision/__init__.py +1 -0
  6. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +2 -1
  7. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +1 -1
  8. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +6 -11
  9. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +2 -2
  10. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/__init__.py +14 -0
  11. model_compression_toolkit/core/common/mixed_precision/{sensitivity_evaluation.py → sensitivity_eval/metric_calculators.py} +149 -244
  12. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +168 -0
  13. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +2 -6
  14. {mct_nightly-2.3.0.20250526.601.dist-info → mct_nightly-2.3.0.20250527.555.dist-info}/WHEEL +0 -0
  15. {mct_nightly-2.3.0.20250526.601.dist-info → mct_nightly-2.3.0.20250527.555.dist-info}/licenses/LICENSE.md +0 -0
  16. {mct_nightly-2.3.0.20250526.601.dist-info → mct_nightly-2.3.0.20250527.555.dist-info}/top_level.txt +0 -0
  17. /model_compression_toolkit/core/common/mixed_precision/{distance_weighting.py → sensitivity_eval/distance_weighting.py} +0 -0
  18. /model_compression_toolkit/core/common/mixed_precision/{set_layer_to_bitwidth.py → sensitivity_eval/set_layer_to_bitwidth.py} +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mct-nightly
3
- Version: 2.3.0.20250526.601
3
+ Version: 2.3.0.20250527.555
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Author-email: ssi-dnn-dev@sony.com
6
6
  Classifier: Programming Language :: Python :: 3
@@ -1,11 +1,11 @@
1
- mct_nightly-2.3.0.20250526.601.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
- model_compression_toolkit/__init__.py,sha256=h0yrmmeo04GsUcV-lK41wbKSmLv-C_RXbP5Bgqo0EOA,1557
1
+ mct_nightly-2.3.0.20250527.555.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
+ model_compression_toolkit/__init__.py,sha256=ac_6iGXJR83ii1qVJhusgDNmQ7il3U3QYpAm-wdLf14,1557
3
3
  model_compression_toolkit/constants.py,sha256=KNgiNLpsMgSYyXMNEbHXd4bFNerQc1D6HH3vpbUq_Gs,4086
4
4
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
5
5
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
6
6
  model_compression_toolkit/metadata.py,sha256=x_Bk4VpzILdsFax6--CZ3X18qUTP28sbF_AhoQW8dNc,4003
7
7
  model_compression_toolkit/verify_packages.py,sha256=l0neIRr8q_QwxmuiTI4vyCMDISDedK0EihjEQUe66tE,1319
8
- model_compression_toolkit/core/__init__.py,sha256=8a0wUNBKwTdJGDk_Ho6WQAXjGuCqQZG1FUxxJlAV8L8,2096
8
+ model_compression_toolkit/core/__init__.py,sha256=phfdtc09uruSyOpWRaUMUeMNRSwYB5q9NBus3cqcjIM,2113
9
9
  model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
10
10
  model_compression_toolkit/core/graph_prep_runner.py,sha256=C6eUTd-fcgxk0LUbt51gFZwmyDDDEB8-9Q4kr9ujYvI,11555
11
11
  model_compression_toolkit/core/quantization_prep_runner.py,sha256=DPevqQ8brkdut8K5f5v9g5lbT3r1GSmhLAk3NkL40Fg,6593
@@ -60,18 +60,15 @@ model_compression_toolkit/core/common/matchers/edge_matcher.py,sha256=bS9KIBhB6Y
60
60
  model_compression_toolkit/core/common/matchers/function.py,sha256=kMwcinxn_PInvetNh_L_lqGXT1hoi3f97PqBpjqfXoA,1773
61
61
  model_compression_toolkit/core/common/matchers/node_matcher.py,sha256=63cMwa5YbQ5LKZy8-KFmdchVc3N7mpDJ6fNDt_uAQsk,2745
62
62
  model_compression_toolkit/core/common/matchers/walk_matcher.py,sha256=xqfLKk6xZt72hSnND_HoX5ESOooNMypb5VOZkVsJ_nw,1111
63
- model_compression_toolkit/core/common/mixed_precision/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
64
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py,sha256=lB3cxQPQqpAH5tP6kqOxqv7RmOtf1YciIkvr9irvKq0,7084
63
+ model_compression_toolkit/core/common/mixed_precision/__init__.py,sha256=Jm6pls3QUCMQ9d86KOYxOq05br_k130ByGHLCojIZ_M,766
64
+ model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py,sha256=npqLPyk5xXR11M_zdImtSALc5vJv9N4fEapaludKLBw,7139
65
65
  model_compression_toolkit/core/common/mixed_precision/configurable_quant_id.py,sha256=LLDguK7afsbN742ucLpmJr5TUfTyFpK1vbf2bpVr1v0,882
66
66
  model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py,sha256=7dKMi5S0zQZ16m8NWn1XIuoXsKuZUg64G4-uK8-j1PQ,5177
67
- model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=-x8edUyudu1EAEM66AuXPtgayLpzbxoLNubfEbFM5kU,2867
68
67
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=6pLUEEIqRTVIlCYQC4JIvY55KAvuBHEX8uTOQ-1Ac4Q,3859
69
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=BO4ouM_UVS9Fg0z95gLJSMz1ep6YQC5za_iXI_qW2yQ,5399
68
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=rdtxPmRhjrC160O3fqAjDzGxpMeM49hYhmlnf_Kwqds,5416
70
69
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=axgAypzsiCOw04ZOtOEjK4riuNsaEU2qU6KkWnEXtMo,4951
71
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=KhiHGpmN5QbpyJQnTZmXigdXFlSlRNqpOOyKGj1Fwek,6412
72
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=MXOK9WPy3fSt5uxsWYMF4szwwqWWgrlzNJdE9VIb-AQ,28145
73
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=4uhUXKgwyMrJqEVK5uJzVr67GI5YzDTHLveV4maB7z0,28079
74
- model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=Zn6SgzGLWWKmuYGHd1YtKxZdYnQWRDeXEkKlBiTbHcs,2929
70
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=1877xOUdgpWrXWyhdX1pJOePuopq43L71WqBFMqzyR4,6418
71
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=TAxA9BKxINwUQfJpmf2Qghz-5DTbesuf1Pe1L0Tc-j4,28157
75
72
  model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=MY8df-c_kITEr_7hOctaxhdiq29hSTA0La9Qo0oTJJY,9678
76
73
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
77
74
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
@@ -79,6 +76,11 @@ model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools
79
76
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=ZY5yFIDzbaqIk0UzakDBObfsVevn4fydqAfAm4RCikY,4058
80
77
  model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
81
78
  model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=6Z6nQL9UH7B8dbcUR0cuCTEYFOKZAlvOb-SCk_cAZFA,6670
79
+ model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/__init__.py,sha256=5yxITHNJcCfeGKdIpAYbNbKDoXUSvENuRQm3OQu8Qf4,697
80
+ model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/distance_weighting.py,sha256=-x8edUyudu1EAEM66AuXPtgayLpzbxoLNubfEbFM5kU,2867
81
+ model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py,sha256=W4CySFtN874npcM9j9wu1PVrv7IZHLyKdLOPrTsCNQg,22209
82
+ model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py,sha256=5l0qP0mZ061xh3rjqTJZcLD2mMKC-hfSnNAN0OmSusk,8938
83
+ model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/set_layer_to_bitwidth.py,sha256=Zn6SgzGLWWKmuYGHd1YtKxZdYnQWRDeXEkKlBiTbHcs,2929
82
84
  model_compression_toolkit/core/common/network_editors/__init__.py,sha256=vZmu55bYqiaOQs3AjfwWDXHmuKZcLHt-wm7uR5fPEqg,1307
83
85
  model_compression_toolkit/core/common/network_editors/actions.py,sha256=nid0_j-Cn10xvmztT8yCKW_6uA7JEnom9SW9syx7wc0,19594
84
86
  model_compression_toolkit/core/common/network_editors/edit_network.py,sha256=dfgawi-nB0ocAJ0xcGn9E-Zv203oUnQLuMiXpX8vTgA,1748
@@ -132,7 +134,7 @@ model_compression_toolkit/core/common/statistics_correction/__init__.py,sha256=s
132
134
  model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py,sha256=b05ZwQ2CwG0Q-yqs9A1uHfP8o17aGEZFCeJNP1p4IWk,4450
133
135
  model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py,sha256=b5clhUWGoDaQLn2pDCeYkV0FomVebcKS8pMXtQTTzIg,4679
134
136
  model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py,sha256=C_nwhhitTd1pCto0nHZPn3fjIMOeDD7VIciumTR3s6k,5641
135
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py,sha256=F8kK8yoYCGeTdXUsHGcM3T2tRdjSlcWg3UToGtovNOs,9196
137
+ model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py,sha256=zIkhOPF6K5aIgMExpD7HFT9UZSDpvXh51F6V-qZ7H-4,9048
136
138
  model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py,sha256=LaGhYES7HgIDf9Bi2KAG_mBzAWuum0J6AGmAFPC8wwo,10478
137
139
  model_compression_toolkit/core/common/statistics_correction/statistics_correction.py,sha256=E0ZA4edimJwpHh9twI5gafcoJ9fX5F1JX2QUOkUOKEw,6250
138
140
  model_compression_toolkit/core/common/substitutions/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
@@ -528,7 +530,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
528
530
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
529
531
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
530
532
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
531
- mct_nightly-2.3.0.20250526.601.dist-info/METADATA,sha256=y5pozmwxQDw3vKdFGMhDfkQCjfugDgQPgIU_V58eWNw,25135
532
- mct_nightly-2.3.0.20250526.601.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
533
- mct_nightly-2.3.0.20250526.601.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
534
- mct_nightly-2.3.0.20250526.601.dist-info/RECORD,,
533
+ mct_nightly-2.3.0.20250527.555.dist-info/METADATA,sha256=m5m0MizrO50qbrB0RkMCLt9s317qhSe3TcCD9otx0lQ,25135
534
+ mct_nightly-2.3.0.20250527.555.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
535
+ mct_nightly-2.3.0.20250527.555.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
536
+ mct_nightly-2.3.0.20250527.555.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.3.0.20250526.000601"
30
+ __version__ = "2.3.0.20250527.000555"
@@ -25,5 +25,5 @@ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_
25
25
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
26
26
  from model_compression_toolkit.core.keras.resource_utilization_data_facade import keras_resource_utilization_data
27
27
  from model_compression_toolkit.core.pytorch.resource_utilization_data_facade import pytorch_resource_utilization_data
28
- from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
28
+ from model_compression_toolkit.core.common.mixed_precision.sensitivity_eval.distance_weighting import MpDistanceWeighting
29
29
 
@@ -12,3 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ from .sensitivity_eval.distance_weighting import MpDistanceWeighting
@@ -76,7 +76,8 @@ def set_bit_widths(mixed_precision_enable: bool,
76
76
  for n in graph.nodes:
77
77
  assert len(n.candidates_quantization_cfg) == 1
78
78
  n.final_weights_quantization_cfg = copy.deepcopy(n.candidates_quantization_cfg[0].weights_quantization_cfg)
79
- n.final_activation_quantization_cfg = copy.deepcopy(n.candidates_quantization_cfg[0].activation_quantization_cfg)
79
+ if not n.is_quantization_preserving():
80
+ n.final_activation_quantization_cfg = copy.deepcopy(n.candidates_quantization_cfg[0].activation_quantization_cfg)
80
81
 
81
82
  return graph
82
83
 
@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
17
17
  from enum import Enum
18
18
  from typing import List, Callable, Optional
19
19
  from model_compression_toolkit.constants import MP_DEFAULT_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
20
- from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
20
+ from model_compression_toolkit.core.common.mixed_precision.sensitivity_eval.distance_weighting import MpDistanceWeighting
21
21
 
22
22
 
23
23
  class MpMetricNormalization(Enum):
@@ -14,10 +14,10 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from enum import Enum
17
- from typing import List, Callable, Dict
17
+ from typing import List, Callable
18
18
 
19
19
  from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
20
- from model_compression_toolkit.core.common import Graph, BaseNode
20
+ from model_compression_toolkit.core.common import Graph
21
21
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
22
22
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
23
23
  from model_compression_toolkit.core.common.hessian import HessianInfoService
@@ -25,7 +25,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_searc
25
25
  MixedPrecisionSearchManager
26
26
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
27
27
  ResourceUtilization
28
- from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
28
+ from model_compression_toolkit.core.common.mixed_precision.sensitivity_eval.sensitivity_evaluation import SensitivityEvaluation
29
29
  from model_compression_toolkit.core.common.mixed_precision.solution_refinement_procedure import \
30
30
  greedy_solution_refinement_procedure
31
31
 
@@ -79,14 +79,9 @@ def search_bit_width(graph: Graph,
79
79
 
80
80
  # Set Sensitivity Evaluator for MP search. It should always work with the original MP graph,
81
81
  # even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
82
- se = SensitivityEvaluation(
83
- graph,
84
- mp_config,
85
- representative_data_gen=representative_data_gen,
86
- fw_info=fw_info,
87
- fw_impl=fw_impl,
88
- disable_activation_for_metric=disable_activation_for_metric,
89
- hessian_info_service=hessian_info_service)
82
+ se = SensitivityEvaluation(graph, mp_config, representative_data_gen=representative_data_gen, fw_info=fw_info,
83
+ fw_impl=fw_impl, disable_activation_for_metric=disable_activation_for_metric,
84
+ hessian_info_service=hessian_info_service)
90
85
 
91
86
  if search_method != BitWidthSearchMethod.INTEGER_PROGRAMMING:
92
87
  raise NotImplementedError()
@@ -21,7 +21,7 @@ from collections import defaultdict
21
21
 
22
22
  from tqdm import tqdm
23
23
 
24
- from typing import Dict, List, Tuple, Optional, Set
24
+ from typing import Dict, List, Tuple, Optional
25
25
 
26
26
  import numpy as np
27
27
 
@@ -39,7 +39,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_ru_he
39
39
  MixedPrecisionRUHelper
40
40
  from model_compression_toolkit.core.common.mixed_precision.search_methods.linear_programming import \
41
41
  MixedPrecisionIntegerLPSolver
42
- from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
42
+ from model_compression_toolkit.core.common.mixed_precision.sensitivity_eval.sensitivity_evaluation import SensitivityEvaluation
43
43
  from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
44
44
  from model_compression_toolkit.logger import Logger
45
45
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
@@ -0,0 +1,14 @@
1
+ # Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -1,4 +1,4 @@
1
- # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
1
+ # Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,40 +12,59 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- import contextlib
16
- import copy
17
- import itertools
18
-
19
15
  import numpy as np
20
- from typing import Callable, Any, List, Tuple, Dict, Optional
16
+ from typing import runtime_checkable, Protocol, Callable, Any, List, Tuple
21
17
 
22
- from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig
18
+ from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, FrameworkInfo
23
19
  from model_compression_toolkit.core.common import Graph, BaseNode
24
- from model_compression_toolkit.core.common.mixed_precision.set_layer_to_bitwidth import \
25
- set_activation_quant_layer_to_bitwidth, set_weights_quant_layer_to_bitwidth
26
- from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
27
- from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence
20
+ from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
21
+ HessianScoresGranularity
28
22
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
23
+ from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence
29
24
  from model_compression_toolkit.logger import Logger
30
- from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode, \
31
- HessianScoresGranularity, HessianInfoService
32
25
 
33
26
 
34
- class SensitivityEvaluation:
35
- """
36
- Class to wrap and manage the computation on distance metric for Mixed-Precision quantization search.
37
- It provides a function that evaluates the sensitivity of a bit-width configuration for the MP model.
38
- """
27
+ @runtime_checkable
28
+ class MetricCalculator(Protocol):
29
+ """ Abstract class for metric calculators. """
30
+ # all interest points (including graph outputs)
31
+ all_interest_points: list
32
+
33
+ def compute(self, mp_model) -> float:
34
+ """ Compute the metric for the given model. """
35
+ ...
36
+
39
37
 
38
+ class CustomMetricCalculator(MetricCalculator):
39
+ """ Calculate metric with custom function applied on graph outputs. """
40
+
41
+ def __init__(self, graph: Graph, custom_metric_fn: Callable):
42
+ """
43
+ Args:
44
+ graph: input graph.
45
+ custom_metric_fn: custom metric function, that accepts the model as input and return float scalar metric.
46
+ """
47
+ self.all_interest_points = [n.node for n in graph.get_outputs()]
48
+ self.metric_fn = custom_metric_fn
49
+
50
+ def compute(self, mp_model: Any) -> float:
51
+ """ Compute the metric for the given model. """
52
+ sensitivity_metric = self.metric_fn(mp_model)
53
+ if not isinstance(sensitivity_metric, (float, np.floating)):
54
+ raise TypeError(
55
+ f'The custom_metric_fn is expected to return float or numpy float, got {type(sensitivity_metric).__name__}')
56
+ return sensitivity_metric
57
+
58
+
59
+ class DistanceMetricCalculator(MetricCalculator):
60
+ """ Calculator for distance-based metrics. """
40
61
  def __init__(self,
41
62
  graph: Graph,
42
- quant_config: MixedPrecisionQuantizationConfig,
63
+ mp_config: MixedPrecisionQuantizationConfig,
43
64
  representative_data_gen: Callable,
44
65
  fw_info: FrameworkInfo,
45
66
  fw_impl: Any,
46
- disable_activation_for_metric: bool = False,
47
- hessian_info_service: HessianInfoService = None
48
- ):
67
+ hessian_info_service: HessianInfoService = None):
49
68
  """
50
69
  Initiates all relevant objects to manage a sensitivity evaluation for MP search.
51
70
  Create an object that allows to compute the sensitivity metric of an MP model (the sensitivity
@@ -59,23 +78,21 @@ class SensitivityEvaluation:
59
78
 
60
79
  Args:
61
80
  graph: Graph to search for its MP configuration.
81
+ mp_config: MP Quantization configuration for how the graph should be quantized.
62
82
  fw_info: FrameworkInfo object about the specific framework
63
83
  (e.g., attributes of different layers' weights to quantize).
64
- quant_config: MP Quantization configuration for how the graph should be quantized.
65
- representative_data_gen: Dataset used for getting batches for inference.
66
84
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
67
- disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
85
+ representative_data_gen: Dataset used for getting batches for inference.
68
86
  hessian_info_service: HessianInfoService to fetch Hessian approximation information.
69
-
70
87
  """
71
88
  self.graph = graph
72
- self.quant_config = quant_config
89
+ self.mp_config = mp_config
73
90
  self.representative_data_gen = representative_data_gen
74
91
  self.fw_info = fw_info
75
92
  self.fw_impl = fw_impl
76
- self.disable_activation_for_metric = disable_activation_for_metric
77
- if self.quant_config.use_hessian_based_scores:
78
- if not isinstance(hessian_info_service, HessianInfoService): # pragma: no cover
93
+
94
+ if self.mp_config.use_hessian_based_scores:
95
+ if not isinstance(hessian_info_service, HessianInfoService): # pragma: no cover
79
96
  Logger.critical(
80
97
  f"When using Hessian-based approximations for sensitivity evaluation, a valid HessianInfoService object is required; found {type(hessian_info_service)}.")
81
98
  self.hessian_info_service = hessian_info_service
@@ -84,42 +101,35 @@ class SensitivityEvaluation:
84
101
 
85
102
  # Get interest points and output points set for distance measurement and set other helper datasets
86
103
  # We define a separate set of output nodes of the model for the purpose of sensitivity computation.
87
- self.interest_points = get_mp_interest_points(graph,
88
- fw_impl.count_node_for_mixed_precision_interest_points,
89
- quant_config.num_interest_points_factor)
90
- # If using a custom metric - return only model outputs
91
- if self.quant_config.custom_metric_fn is not None:
92
- self.interest_points = []
104
+ self.interest_points = self.get_mp_interest_points(graph,
105
+ fw_impl.count_node_for_mixed_precision_interest_points,
106
+ mp_config.num_interest_points_factor)
93
107
 
94
108
  # We use normalized MSE when not running hessian-based. For Hessian-based normalized MSE is not needed
95
109
  # because hessian weights already do normalization.
96
- use_normalized_mse = self.quant_config.use_hessian_based_scores is False
97
- self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points, use_normalized_mse)
98
-
99
- self.output_points = get_output_nodes_for_metric(graph)
100
- # If using a custom metric - return all model outputs
101
- if self.quant_config.custom_metric_fn is not None:
102
- self.output_points = [n.node for n in graph.get_outputs()]
103
- self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(self.output_points,
110
+ use_normalized_mse = self.mp_config.use_hessian_based_scores is False
111
+ self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points,
112
+ use_normalized_mse)
113
+
114
+ output_points = self.get_output_nodes_for_metric(graph)
115
+ self.all_interest_points = self.interest_points + output_points
116
+ self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(output_points,
104
117
  use_normalized_mse)
105
118
 
119
+ self.ref_model, _ = fw_impl.model_builder(graph, mode=ModelBuilderMode.FLOAT,
120
+ append2output=self.all_interest_points)
121
+
106
122
  # Setting lists with relative position of the interest points
107
123
  # and output points in the list of all mp model activation tensors
108
124
  graph_sorted_nodes = self.graph.get_topo_sorted_nodes()
109
- all_out_tensors_indices = [graph_sorted_nodes.index(n) for n in self.interest_points + self.output_points]
125
+ all_out_tensors_indices = [graph_sorted_nodes.index(n) for n in self.all_interest_points]
110
126
  global_ipts_indices = [graph_sorted_nodes.index(n) for n in self.interest_points]
111
- global_out_pts_indices = [graph_sorted_nodes.index(n) for n in self.output_points]
127
+ global_out_pts_indices = [graph_sorted_nodes.index(n) for n in output_points]
112
128
  self.ips_act_indices = [all_out_tensors_indices.index(i) for i in global_ipts_indices]
113
129
  self.out_ps_act_indices = [all_out_tensors_indices.index(i) for i in global_out_pts_indices]
114
130
 
115
- # Build a mixed-precision model which can be configured to use different bitwidth in different layers.
116
- # And a baseline model.
117
- # Also, returns a mapping between a configurable graph's node and its matching layer(s)
118
- # in the new built MP model.
119
- self.baseline_model, self.model_mp, self.conf_node2layers = self._build_models()
120
-
121
131
  # Build images batches for inference comparison and cat to framework type
122
- images_batches = self._get_images_batches(quant_config.num_of_images)
132
+ images_batches = self._get_images_batches(mp_config.num_of_images)
123
133
  self.images_batches = [self.fw_impl.to_tensor(img) for img in images_batches]
124
134
 
125
135
  # Initiating baseline_tensors_list since it is not initiated in SensitivityEvaluationManager init.
@@ -128,12 +138,28 @@ class SensitivityEvaluation:
128
138
  # Computing Hessian-based scores for weighted average distance metric computation (only if requested),
129
139
  # and assigning distance_weighting method accordingly.
130
140
  self.interest_points_hessians = None
131
- if self.quant_config.use_hessian_based_scores is True:
141
+ if self.mp_config.use_hessian_based_scores is True:
132
142
  self.interest_points_hessians = self._compute_hessian_based_scores()
133
- self.quant_config.distance_weighting_method = lambda d: self.interest_points_hessians
143
+ self.mp_config.distance_weighting_method = lambda d: self.interest_points_hessians
134
144
 
135
- def _init_metric_points_lists(self, points: List[BaseNode], norm_mse: bool = False) -> Tuple[
136
- List[Callable], List[int]]:
145
+ def compute(self, mp_model) -> float:
146
+ """
147
+ Compute the metric for the given model.
148
+
149
+ Args:
150
+ mp_model: MP configured model.
151
+
152
+ Returns:
153
+ Computed metric.
154
+ """
155
+ ipts_distances, out_pts_distances = self._compute_distance(mp_model)
156
+ sensitivity_metric = self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
157
+ self.mp_config.distance_weighting_method)
158
+ return sensitivity_metric
159
+
160
+ def _init_metric_points_lists(self,
161
+ points: List[BaseNode],
162
+ norm_mse: bool = False) -> Tuple[List[Callable], List[int]]:
137
163
  """
138
164
  Initiates required lists for future use when computing the sensitivity metric.
139
165
  Each point on which the metric is computed uses a dedicated distance function based on its type.
@@ -150,101 +176,19 @@ class SensitivityEvaluation:
150
176
  axis_list = []
151
177
  for n in points:
152
178
  distance_fn, axis = self.fw_impl.get_mp_node_distance_fn(n,
153
- compute_distance_fn=self.quant_config.compute_distance_fn,
179
+ compute_distance_fn=self.mp_config.compute_distance_fn,
154
180
  norm_mse=norm_mse)
155
181
  distance_fns_list.append(distance_fn)
156
182
  # Axis is needed only for KL Divergence calculation, otherwise we use per-tensor computation
157
183
  axis_list.append(axis if distance_fn == compute_kl_divergence else None)
158
184
  return distance_fns_list, axis_list
159
185
 
160
- def compute_metric(self, mp_a_cfg: Dict[str, Optional[int]], mp_w_cfg: Dict[str, Optional[int]]) -> float:
161
- """
162
- Compute the sensitivity metric of the MP model for a given configuration (the sensitivity
163
- is computed based on the similarity of the interest points' outputs between the MP model
164
- and the float model or a custom metric if given).
165
- Quantization for any configurable activation / weight that were not passed is disabled.
166
-
167
- Args:
168
- mp_a_cfg: Bitwidth activations configuration for the MP model.
169
- mp_w_cfg: Bitwidth weights configuration for the MP model.
170
-
171
- Returns:
172
- The sensitivity metric of the MP model for a given configuration.
173
- """
174
-
175
- with self._configured_mp_model(mp_a_cfg, mp_w_cfg):
176
- sensitivity_metric = self._compute_metric()
177
-
178
- return sensitivity_metric
179
-
180
- def _compute_metric(self) -> float:
181
- """
182
- Compute sensitivity metric on a configured mp model.
183
-
184
- Returns:
185
- Sensitivity metric.
186
- """
187
- if self.quant_config.custom_metric_fn:
188
- sensitivity_metric = self.quant_config.custom_metric_fn(self.model_mp)
189
- if not isinstance(sensitivity_metric, (float, np.floating)):
190
- raise TypeError(
191
- f'The custom_metric_fn is expected to return float or numpy float, got {type(sensitivity_metric).__name__}')
192
- return sensitivity_metric
193
-
194
- # compute default metric
195
- ipts_distances, out_pts_distances = self._compute_distance()
196
- sensitivity_metric = self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
197
- self.quant_config.distance_weighting_method)
198
- return sensitivity_metric
199
-
200
186
  def _init_baseline_tensors_list(self):
201
187
  """
202
188
  Evaluates the baseline model on all images and returns the obtained lists of tensors in a list for later use.
203
189
  """
204
- return [self.fw_impl.to_numpy(self.fw_impl.sensitivity_eval_inference(self.baseline_model, images))
205
- for images in self.images_batches]
206
-
207
- def _build_models(self) -> Any:
208
- """
209
- Builds two models - an MP model with configurable layers and a baseline, float model.
210
-
211
- Returns: A tuple with two models built from the given graph: a baseline model (with baseline configuration) and
212
- an MP model (which can be configured for a specific bitwidth configuration).
213
- Note that the type of the returned models is dependent on the used framework (TF/Pytorch).
214
- """
215
-
216
- evaluation_graph = copy.deepcopy(self.graph)
217
-
218
- # Disable quantization for non-configurable nodes, and, if requested, for all activations (quantizers won't
219
- # be added to the model).
220
- for n in evaluation_graph.get_topo_sorted_nodes():
221
- if self.disable_activation_for_metric or not n.has_configurable_activation():
222
- for c in n.candidates_quantization_cfg:
223
- c.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
224
- if not n.has_any_configurable_weight():
225
- for c in n.candidates_quantization_cfg:
226
- c.weights_quantization_cfg.disable_all_weights_quantization()
227
-
228
- model_mp, _, conf_node2layers = self.fw_impl.model_builder(evaluation_graph,
229
- mode=ModelBuilderMode.MIXEDPRECISION,
230
- append2output=self.interest_points + self.output_points,
231
- fw_info=self.fw_info)
232
-
233
- # Disable all configurable quantizers. They will be activated one at a time during sensitivity evaluation.
234
- # Note: from this point mp_model is not in sync with graph quantization configuration for configurable nodes.
235
- for layer in itertools.chain(*conf_node2layers.values()):
236
- if isinstance(layer, self.fw_impl.activation_quant_layer_cls):
237
- set_activation_quant_layer_to_bitwidth(layer, None, self.fw_impl)
238
- else:
239
- assert isinstance(layer, self.fw_impl.weights_quant_layer_cls)
240
- set_weights_quant_layer_to_bitwidth(layer, None, self.fw_impl)
241
-
242
- # Build a baseline model (to compute distances from).
243
- baseline_model, _ = self.fw_impl.model_builder(evaluation_graph,
244
- mode=ModelBuilderMode.FLOAT,
245
- append2output=self.interest_points + self.output_points)
246
-
247
- return baseline_model, model_mp, conf_node2layers
190
+ return [self.fw_impl.to_numpy(self.fw_impl.sensitivity_eval_inference(self.ref_model, images))
191
+ for images in self.images_batches]
248
192
 
249
193
  def _compute_hessian_based_scores(self) -> np.ndarray:
250
194
  """
@@ -257,61 +201,21 @@ class SensitivityEvaluation:
257
201
  # Create a request for Hessian approximation scores with specific configurations
258
202
  # (here we use per-tensor approximation of the Hessian's trace w.r.t the node's activations)
259
203
  fw_dataloader = self.fw_impl.convert_data_gen_to_dataloader(self.representative_data_gen,
260
- batch_size=self.quant_config.hessian_batch_size)
204
+ batch_size=self.mp_config.hessian_batch_size)
261
205
  hessian_info_request = HessianScoresRequest(mode=HessianMode.ACTIVATION,
262
206
  granularity=HessianScoresGranularity.PER_TENSOR,
263
207
  target_nodes=self.interest_points,
264
208
  data_loader=fw_dataloader,
265
- n_samples=self.quant_config.num_of_images)
209
+ n_samples=self.mp_config.num_of_images)
266
210
 
267
211
  # Fetch the Hessian approximation scores for the current interest point
268
212
  nodes_approximations = self.hessian_info_service.fetch_hessian(request=hessian_info_request)
269
- approx_by_image = np.stack([nodes_approximations[n.name] for n in self.interest_points], axis=1) # samples X nodes
213
+ approx_by_image = np.stack([nodes_approximations[n.name] for n in self.interest_points],
214
+ axis=1) # samples X nodes
270
215
 
271
216
  # Return the mean approximation value across all images for each interest point
272
217
  return np.mean(approx_by_image, axis=0)
273
218
 
274
- @contextlib.contextmanager
275
- def _configured_mp_model(self, mp_a_cfg: Dict[str, Optional[int]], mp_w_cfg: Dict[str, Optional[int]]):
276
- """
277
- Context manager to configure specific configurable layers of the mp model. At exit, configuration is
278
- automatically restored to un-quantized.
279
-
280
- Args:
281
- mp_a_cfg: Nodes bitwidth indices to configure activation quantizers to.
282
- mp_w_cfg: Nodes bitwidth indices to configure weights quantizers to.
283
-
284
- """
285
- if not (mp_a_cfg and any(v is not None for v in mp_a_cfg.values()) or
286
- mp_w_cfg and any(v is not None for v in mp_w_cfg.values())):
287
- raise ValueError(f'Requested configuration is either empty or contain only None values.')
288
-
289
- # defined here so that it can't be used directly
290
- def apply_bitwidth_config(a_cfg, w_cfg):
291
- node_names = set(a_cfg.keys()).union(set(w_cfg.keys()))
292
- for n in node_names:
293
- node_quant_layers = self.conf_node2layers.get(n)
294
- if node_quant_layers is None: # pragma: no cover
295
- raise ValueError(f"Matching layers for node {n} not found in the mixed precision model configuration.")
296
- for qlayer in node_quant_layers:
297
- assert isinstance(qlayer, (self.fw_impl.activation_quant_layer_cls,
298
- self.fw_impl.weights_quant_layer_cls)), f'Unexpected {type(qlayer)} of node {n}'
299
- if isinstance(qlayer, self.fw_impl.activation_quant_layer_cls) and n in a_cfg:
300
- set_activation_quant_layer_to_bitwidth(qlayer, a_cfg[n], self.fw_impl)
301
- a_cfg.pop(n)
302
- elif isinstance(qlayer, self.fw_impl.weights_quant_layer_cls) and n in w_cfg:
303
- set_weights_quant_layer_to_bitwidth(qlayer, w_cfg[n], self.fw_impl)
304
- w_cfg.pop(n)
305
- if a_cfg or w_cfg:
306
- raise ValueError(f'Not all mp configs were consumed, remaining activation config {a_cfg}, '
307
- f'weights config {w_cfg}.')
308
-
309
- apply_bitwidth_config(mp_a_cfg.copy(), mp_w_cfg.copy())
310
- try:
311
- yield
312
- finally:
313
- apply_bitwidth_config({n: None for n in mp_a_cfg}, {n: None for n in mp_w_cfg})
314
-
315
219
  def _compute_points_distance(self,
316
220
  baseline_tensors: List[Any],
317
221
  mp_tensors: List[Any],
@@ -338,7 +242,7 @@ class SensitivityEvaluation:
338
242
 
339
243
  return np.asarray(distance_v)
340
244
 
341
- def _compute_distance(self) -> Tuple[np.ndarray, np.ndarray]:
245
+ def _compute_distance(self, mp_model) -> Tuple[np.ndarray, np.ndarray]:
342
246
  """
343
247
  Computing the interest points distance and the output points distance, and using them to build a
344
248
  unified distance vector.
@@ -352,7 +256,7 @@ class SensitivityEvaluation:
352
256
  # Compute the distance matrix for num_of_images images.
353
257
  for images, baseline_tensors in zip(self.images_batches, self.baseline_tensors_list):
354
258
  # when using model.predict(), it does not use the QuantizeWrapper functionality
355
- mp_tensors = self.fw_impl.sensitivity_eval_inference(self.model_mp, images)
259
+ mp_tensors = self.fw_impl.sensitivity_eval_inference(mp_model, images)
356
260
  mp_tensors = self.fw_impl.to_numpy(mp_tensors)
357
261
 
358
262
  # Compute distance: similarity between the baseline model to the float model
@@ -440,77 +344,78 @@ class SensitivityEvaluation:
440
344
  samples_count += batch_size
441
345
  else:
442
346
  if samples_count < num_of_images:
443
- Logger.warning(f'Not enough images in representative dataset to generate {num_of_images} data points, '
444
- f'only {samples_count} were generated')
347
+ Logger.warning(
348
+ f'Not enough images in representative dataset to generate {num_of_images} data points, '
349
+ f'only {samples_count} were generated')
445
350
  return images_batches
446
351
 
352
+ @classmethod
353
+ def get_mp_interest_points(cls, graph: Graph,
354
+ interest_points_classifier: Callable,
355
+ num_ip_factor: float) -> List[BaseNode]:
356
+ """
357
+ Gets a list of interest points for the mixed precision metric computation.
358
+ The list is constructed from a filtered set of nodes in the graph.
359
+ Note that the output layers are separated from the interest point set for metric computation purposes.
447
360
 
448
- def get_mp_interest_points(graph: Graph,
449
- interest_points_classifier: Callable,
450
- num_ip_factor: float) -> List[BaseNode]:
451
- """
452
- Gets a list of interest points for the mixed precision metric computation.
453
- The list is constructed from a filtered set of nodes in the graph.
454
- Note that the output layers are separated from the interest point set for metric computation purposes.
455
-
456
- Args:
457
- graph: Graph to search for its MP configuration.
458
- interest_points_classifier: A function that indicates whether a given node in considered as a potential
459
- interest point for mp metric computation purposes.
460
- num_ip_factor: Percentage out of the total set of interest points that we want to actually use.
461
-
462
- Returns: A list of interest points (nodes in the graph).
463
-
464
- """
465
- sorted_nodes = graph.get_topo_sorted_nodes()
466
- ip_nodes = [n for n in sorted_nodes if interest_points_classifier(n)]
467
-
468
- interest_points_nodes = bound_num_interest_points(ip_nodes, num_ip_factor)
361
+ Args:
362
+ graph: Graph to search for its MP configuration.
363
+ interest_points_classifier: A function that indicates whether a given node in considered as a potential
364
+ interest point for mp metric computation purposes.
365
+ num_ip_factor: Percentage out of the total set of interest points that we want to actually use.
469
366
 
470
- # We exclude output nodes from the set of interest points since they are used separately in the sensitivity evaluation.
471
- output_nodes = [n.node for n in graph.get_outputs()]
367
+ Returns: A list of interest points (nodes in the graph).
472
368
 
473
- interest_points = [n for n in interest_points_nodes if n not in output_nodes]
369
+ """
370
+ sorted_nodes = graph.get_topo_sorted_nodes()
371
+ ip_nodes = [n for n in sorted_nodes if interest_points_classifier(n)]
474
372
 
475
- return interest_points
373
+ interest_points_nodes = cls.bound_num_interest_points(ip_nodes, num_ip_factor)
476
374
 
375
+ # We exclude output nodes from the set of interest points since they are used separately in the sensitivity evaluation.
376
+ output_nodes = [n.node for n in graph.get_outputs()]
477
377
 
478
- def get_output_nodes_for_metric(graph: Graph) -> List[BaseNode]:
479
- """
480
- Returns a list of output nodes that are also quantized (either kernel weights attribute or activation)
481
- to be used as a set of output points in the distance metric computation.
378
+ interest_points = [n for n in interest_points_nodes if n not in output_nodes]
482
379
 
483
- Args:
484
- graph: Graph to search for its MP configuration.
380
+ return interest_points
485
381
 
486
- Returns: A list of output nodes.
382
+ @staticmethod
383
+ def get_output_nodes_for_metric(graph: Graph) -> List[BaseNode]:
384
+ """
385
+ Returns a list of output nodes that are also quantized (either kernel weights attribute or activation)
386
+ to be used as a set of output points in the distance metric computation.
487
387
 
488
- """
388
+ Args:
389
+ graph: Graph to search for its MP configuration.
489
390
 
490
- return [n.node for n in graph.get_outputs()
491
- if (graph.fw_info.is_kernel_op(n.node.type) and
492
- n.node.is_weights_quantization_enabled(graph.fw_info.get_kernel_op_attributes(n.node.type)[0])) or
493
- n.node.is_activation_quantization_enabled()]
391
+ Returns: A list of output nodes.
494
392
 
393
+ """
495
394
 
496
- def bound_num_interest_points(sorted_ip_list: List[BaseNode], num_ip_factor: float) -> List[BaseNode]:
497
- """
498
- Filters the list of interest points and returns a shorter list with number of interest points smaller than some
499
- default threshold.
395
+ return [n.node for n in graph.get_outputs()
396
+ if (graph.fw_info.is_kernel_op(n.node.type) and
397
+ n.node.is_weights_quantization_enabled(graph.fw_info.get_kernel_op_attributes(n.node.type)[0])) or
398
+ n.node.is_activation_quantization_enabled()]
500
399
 
501
- Args:
502
- sorted_ip_list: List of nodes which are considered as interest points for the metric computation.
503
- num_ip_factor: Percentage out of the total set of interest points that we want to actually use.
400
+ @staticmethod
401
+ def bound_num_interest_points(sorted_ip_list: List[BaseNode], num_ip_factor: float) -> List[BaseNode]:
402
+ """
403
+ Filters the list of interest points and returns a shorter list with number of interest points smaller than some
404
+ default threshold.
504
405
 
505
- Returns: A new list of interest points (list of nodes).
406
+ Args:
407
+ sorted_ip_list: List of nodes which are considered as interest points for the metric computation.
408
+ num_ip_factor: Percentage out of the total set of interest points that we want to actually use.
506
409
 
507
- """
508
- if num_ip_factor < 1.0:
509
- num_interest_points = int(num_ip_factor * len(sorted_ip_list))
510
- Logger.info(f'Using {num_interest_points} for mixed-precision metric evaluation out of total '
511
- f'{len(sorted_ip_list)} potential interest points.')
512
- # Take num_interest_points evenly spaced interest points from the original list
513
- indices = np.round(np.linspace(0, len(sorted_ip_list) - 1, num_interest_points)).astype(int)
514
- return [sorted_ip_list[i] for i in indices]
410
+ Returns: A new list of interest points (list of nodes).
515
411
 
516
- return sorted_ip_list
412
+ """
413
+ if num_ip_factor < 1.0:
414
+ num_interest_points = int(num_ip_factor * len(sorted_ip_list))
415
+ Logger.info(f'Using {num_interest_points} for mixed-precision metric evaluation out of total '
416
+ f'{len(sorted_ip_list)} potential interest points.')
417
+ # Take num_interest_points evenly spaced interest points from the original list
418
+ indices = np.round(np.linspace(0, len(sorted_ip_list) - 1, num_interest_points)).astype(int)
419
+ return [sorted_ip_list[i] for i in indices]
420
+
421
+ return sorted_ip_list
@@ -0,0 +1,168 @@
1
+ # Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import contextlib
16
+ import copy
17
+ import itertools
18
+
19
+ from typing import Callable, Any, Tuple, Dict, Optional
20
+
21
+ from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig
22
+ from model_compression_toolkit.core.common import Graph
23
+ from model_compression_toolkit.core.common.mixed_precision.sensitivity_eval.metric_calculators import \
24
+ CustomMetricCalculator, DistanceMetricCalculator
25
+ from model_compression_toolkit.core.common.mixed_precision.sensitivity_eval.set_layer_to_bitwidth import \
26
+ set_activation_quant_layer_to_bitwidth, set_weights_quant_layer_to_bitwidth
27
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
28
+ from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
29
+ from model_compression_toolkit.core.common.hessian import HessianInfoService
30
+
31
+
32
+ class SensitivityEvaluation:
33
+ """
34
+ Sensitivity evaluation of a bit-width configuration for Mixed Precision search.
35
+ """
36
+
37
+ def __init__(self,
38
+ graph: Graph,
39
+ mp_config: MixedPrecisionQuantizationConfig,
40
+ representative_data_gen: Callable,
41
+ fw_info: FrameworkInfo,
42
+ fw_impl: Any,
43
+ disable_activation_for_metric: bool = False,
44
+ hessian_info_service: HessianInfoService = None
45
+ ):
46
+ """
47
+ Args:
48
+ graph: Graph to search for its MP configuration.
49
+ fw_info: FrameworkInfo object about the specific framework
50
+ (e.g., attributes of different layers' weights to quantize).
51
+ mp_config: MP Quantization configuration for how the graph should be quantized.
52
+ representative_data_gen: Dataset used for getting batches for inference.
53
+ fw_impl: FrameworkImplementation object with a specific framework methods implementation.
54
+ disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
55
+ hessian_info_service: HessianInfoService to fetch Hessian approximation information.
56
+
57
+ """
58
+ self.mp_config = mp_config
59
+ self.representative_data_gen = representative_data_gen
60
+ self.fw_info = fw_info
61
+ self.fw_impl = fw_impl
62
+
63
+ if self.mp_config.custom_metric_fn:
64
+ self.metric_calculator = CustomMetricCalculator(graph, self.mp_config.custom_metric_fn)
65
+ else:
66
+ self.metric_calculator = DistanceMetricCalculator(graph, mp_config, representative_data_gen,
67
+ fw_info=fw_info, fw_impl=fw_impl,
68
+ hessian_info_service=hessian_info_service)
69
+
70
+ # Build a mixed-precision model which can be configured to use different bitwidth in different layers.
71
+ # Also, returns a mapping between a configurable graph's node and its matching layer(s) in the built MP model.
72
+ self.mp_model, self.conf_node2layers = self._build_mp_model(graph, self.metric_calculator.all_interest_points,
73
+ disable_activation_for_metric)
74
+
75
+ def compute_metric(self, mp_a_cfg: Dict[str, Optional[int]], mp_w_cfg: Dict[str, Optional[int]]) -> float:
76
+ """
77
+ Compute the sensitivity metric of the MP model for a given configuration.
78
+ Quantization for any configurable activation / weight that were not passed is disabled.
79
+
80
+ Args:
81
+ mp_a_cfg: Bitwidth activations configuration for the MP model.
82
+ mp_w_cfg: Bitwidth weights configuration for the MP model.
83
+
84
+ Returns:
85
+ The sensitivity metric of the MP model for a given configuration.
86
+ """
87
+ with self._configured_mp_model(mp_a_cfg, mp_w_cfg):
88
+ sensitivity_metric = self.metric_calculator.compute(self.mp_model)
89
+
90
+ return sensitivity_metric
91
+
92
+ def _build_mp_model(self, graph, outputs, disable_activations: bool) -> Tuple[Any, dict]:
93
+ """
94
+ Builds an MP model with configurable layers.
95
+
96
+ Returns:
97
+ MP model and a mapping from configurable graph nodes to their corresponding quantization layer(s)
98
+ in the MP model.
99
+ """
100
+ evaluation_graph = copy.deepcopy(graph)
101
+
102
+ # Disable quantization for non-configurable nodes, and, if requested, for all activations (quantizers won't
103
+ # be added to the model).
104
+ for n in evaluation_graph.get_topo_sorted_nodes():
105
+ if disable_activations or not n.has_configurable_activation():
106
+ for c in n.candidates_quantization_cfg:
107
+ c.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
108
+ if not n.has_any_configurable_weight():
109
+ for c in n.candidates_quantization_cfg:
110
+ c.weights_quantization_cfg.disable_all_weights_quantization()
111
+
112
+ model_mp, _, conf_node2layers = self.fw_impl.model_builder(evaluation_graph,
113
+ mode=ModelBuilderMode.MIXEDPRECISION,
114
+ append2output=outputs,
115
+ fw_info=self.fw_info)
116
+
117
+ # Disable all configurable quantizers. They will be activated one at a time during sensitivity evaluation.
118
+ for layer in itertools.chain(*conf_node2layers.values()):
119
+ if isinstance(layer, self.fw_impl.activation_quant_layer_cls):
120
+ set_activation_quant_layer_to_bitwidth(layer, None, self.fw_impl)
121
+ else:
122
+ assert isinstance(layer, self.fw_impl.weights_quant_layer_cls)
123
+ set_weights_quant_layer_to_bitwidth(layer, None, self.fw_impl)
124
+
125
+ return model_mp, conf_node2layers
126
+
127
+ @contextlib.contextmanager
128
+ def _configured_mp_model(self, mp_a_cfg: Dict[str, Optional[int]], mp_w_cfg: Dict[str, Optional[int]]):
129
+ """
130
+ Context manager to configure specific configurable layers of the mp model. At exit, configuration is
131
+ automatically restored to un-quantized.
132
+
133
+ Args:
134
+ mp_a_cfg: Nodes bitwidth indices to configure activation quantizers to.
135
+ mp_w_cfg: Nodes bitwidth indices to configure weights quantizers to.
136
+
137
+ """
138
+ if not (mp_a_cfg and any(v is not None for v in mp_a_cfg.values()) or
139
+ mp_w_cfg and any(v is not None for v in mp_w_cfg.values())):
140
+ raise ValueError(f'Requested configuration is either empty or contain only None values.')
141
+
142
+ # defined here so that it can't be used directly
143
+ def apply_bitwidth_config(a_cfg, w_cfg):
144
+ node_names = set(a_cfg.keys()).union(set(w_cfg.keys()))
145
+ for n in node_names:
146
+ node_quant_layers = self.conf_node2layers.get(n)
147
+ if node_quant_layers is None: # pragma: no cover
148
+ raise ValueError(f"Matching layers for node {n} not found in the mixed precision model configuration.")
149
+ for qlayer in node_quant_layers:
150
+ assert isinstance(qlayer, (self.fw_impl.activation_quant_layer_cls,
151
+ self.fw_impl.weights_quant_layer_cls)), f'Unexpected {type(qlayer)} of node {n}'
152
+ if isinstance(qlayer, self.fw_impl.activation_quant_layer_cls) and n in a_cfg:
153
+ set_activation_quant_layer_to_bitwidth(qlayer, a_cfg[n], self.fw_impl)
154
+ a_cfg.pop(n)
155
+ elif isinstance(qlayer, self.fw_impl.weights_quant_layer_cls) and n in w_cfg:
156
+ set_weights_quant_layer_to_bitwidth(qlayer, w_cfg[n], self.fw_impl)
157
+ w_cfg.pop(n)
158
+ if a_cfg or w_cfg:
159
+ raise ValueError(f'Not all mp configs were consumed, remaining activation config {a_cfg}, '
160
+ f'weights config {w_cfg}.')
161
+
162
+ apply_bitwidth_config(mp_a_cfg.copy(), mp_w_cfg.copy())
163
+ try:
164
+ yield
165
+ finally:
166
+ apply_bitwidth_config({n: None for n in mp_a_cfg}, {n: None for n in mp_w_cfg})
167
+
168
+
@@ -42,13 +42,9 @@ def get_previous_node_with_activation_quantization(linear_node: BaseNode,
42
42
 
43
43
  prev_node = prev_nodes[0]
44
44
 
45
- activation_quantization_config = prev_node.final_activation_quantization_cfg
45
+ prev_quant_node = graph.retrieve_preserved_quantization_node(prev_node)
46
46
 
47
- # Search for node with activation quantization
48
- if activation_quantization_config.enable_activation_quantization:
49
- return prev_node
50
- else:
51
- return get_previous_node_with_activation_quantization(prev_node, graph)
47
+ return prev_quant_node if prev_quant_node.is_activation_quantization_enabled() else None
52
48
 
53
49
 
54
50
  def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray: