mct-nightly 2.1.0.20240617.451__py3-none-any.whl → 2.1.0.20240618.432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.1.0.20240617.451.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/METADATA +2 -2
- {mct_nightly-2.1.0.20240617.451.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/RECORD +38 -12
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +1 -1
- model_compression_toolkit/gptq/keras/graph_info.py +1 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +5 -2
- model_compression_toolkit/gptq/pytorch/graph_info.py +2 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +3 -2
- model_compression_toolkit/xquant/__init__.py +19 -0
- model_compression_toolkit/xquant/common/__init__.py +15 -0
- model_compression_toolkit/xquant/common/constants.py +38 -0
- model_compression_toolkit/xquant/common/core_report_generator.py +83 -0
- model_compression_toolkit/xquant/common/dataset_utils.py +43 -0
- model_compression_toolkit/xquant/common/framework_report_utils.py +89 -0
- model_compression_toolkit/xquant/common/model_analyzer.py +99 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +104 -0
- model_compression_toolkit/xquant/common/similarity_calculator.py +194 -0
- model_compression_toolkit/xquant/common/similarity_functions.py +81 -0
- model_compression_toolkit/xquant/common/tensorboard_utils.py +101 -0
- model_compression_toolkit/xquant/common/xquant_config.py +39 -0
- model_compression_toolkit/xquant/keras/__init__.py +15 -0
- model_compression_toolkit/xquant/keras/dataset_utils.py +57 -0
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +63 -0
- model_compression_toolkit/xquant/keras/keras_report_utils.py +60 -0
- model_compression_toolkit/xquant/keras/model_analyzer.py +136 -0
- model_compression_toolkit/xquant/keras/similarity_functions.py +75 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +84 -0
- model_compression_toolkit/xquant/pytorch/__init__.py +15 -0
- model_compression_toolkit/xquant/pytorch/dataset_utils.py +76 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +62 -0
- model_compression_toolkit/xquant/pytorch/model_analyzer.py +132 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +61 -0
- model_compression_toolkit/xquant/pytorch/similarity_functions.py +68 -0
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +87 -0
- {mct_nightly-2.1.0.20240617.451.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.1.0.20240617.451.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/WHEEL +0 -0
- {mct_nightly-2.1.0.20240617.451.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/top_level.txt +0 -0
model_compression_toolkit/xquant/common/model_folding_utils.py
@@ -0,0 +1,104 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+
+from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
+from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
+
+from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
+from typing import Any, Callable
+
+from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
+
+
+class ModelFoldingUtils:
+    """
+    Utility class for handling model folding operations such as batch normalization (BN) folding,
+    residual collapsing, and other graph optimizations.
+    """
+
+    def __init__(self,
+                 fw_info: FrameworkInfo,
+                 fw_impl: FrameworkImplementation,
+                 fw_default_tpc: TargetPlatformCapabilities):
+        """
+        Initialize the ModelFoldingUtils class with framework-specific information, implementation details,
+        and default TPC.
+
+        Args:
+            fw_info: Framework-specific information.
+            fw_impl: Implementation functions for the framework.
+            fw_default_tpc: Default target platform capabilities for the handled framework.
+        """
+        self.fw_info = fw_info
+        self.fw_impl = fw_impl
+        self.fw_default_tpc = fw_default_tpc
+
+    def create_float_folded_model(self, float_model: Any, representative_dataset: Any = None) -> Any:
+        """
+        Create folded version of the model like MCT does (bn folding, residual collapsing, etc.).
+        This is needed since we need the models we compare to have the same architecture for
+        comparing tensors in different points of the models.
+
+        Args:
+            float_model: The floating-point model to be folded.
+            representative_dataset: A callable for generating representative data.
+
+        Returns:
+            The folded floating-point model.
+
+        """
+        float_graph = self.create_float_folded_graph(model=float_model,
+                                                     repr_dataset=representative_dataset)
+        float_folded_model, _ = self.fw_impl.model_builder(
+            float_graph,
+            mode=ModelBuilderMode.FLOAT,
+            append2output=None,
+            fw_info=self.fw_info
+        )
+        return float_folded_model
+
+    def create_float_folded_graph(self, model: Any, repr_dataset: Callable) -> Graph:
+        """
+        Create a folded graph for the float model. This process involves
+        graph optimizations similar to those applied during quantization (e.g., batch normalization folding,
+        residual collapsing).
+
+        Args:
+            model: The floating-point model to be folded into a graph.
+            repr_dataset: A callable that generates representative data.
+
+        Returns:
+            The folded graph.
+        """
+        # TODO:
+        # Consider simplifying graph_preparation_runner by extracting relevant parts to a separate method in MCT.
+        #
+        # Issues:
+        # 1. The quantization config affects how the optimized graph looks (e.g., collapsing).
+        # 2. The back2fw function requires quantization info even for float models.
+        #
+        # Future Considerations:
+        # - Remove quantization config parts related to graph optimizations.
+        # - Update back2fw to handle float models without needing quantization info.
+        graph = graph_preparation_runner(in_model=model,
+                                         representative_data_gen=repr_dataset,
+                                         fw_impl=self.fw_impl,
+                                         fw_info=self.fw_info,
+                                         quantization_config=DEFAULTCONFIG,
+                                         tpc=self.fw_default_tpc)
+        return graph
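The docstrings above explain that the float model must be folded so that its layers line up with the quantized model before activations can be compared by name. A minimal Keras sketch of the motivation (not part of the diff; the layer names and shapes are illustrative, and it assumes TensorFlow is installed):

from tensorflow import keras

# Float model with a Conv2D followed by BatchNormalization.
inputs = keras.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(8, 3, name='conv')(inputs)
x = keras.layers.BatchNormalization(name='bn')(x)
float_model = keras.Model(inputs, x)

# MCT folds the BN parameters into the preceding convolution before quantization, so the
# quantized model has no standalone 'bn' layer. Comparing activations layer by layer
# therefore only works if the float model is folded the same way, which is what
# ModelFoldingUtils.create_float_folded_model provides.
print([layer.name for layer in float_model.layers])  # includes 'conv' and 'bn' before folding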
model_compression_toolkit/xquant/common/similarity_calculator.py
@@ -0,0 +1,194 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from functools import partial
+
+from typing import Tuple, Any, Dict, Callable
+
+from model_compression_toolkit.xquant.common.constants import MODEL_OUTPUT_KEY
+from model_compression_toolkit.xquant.common.dataset_utils import DatasetUtils
+from model_compression_toolkit.xquant.common.model_analyzer import ModelAnalyzer
+from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
+from model_compression_toolkit.xquant.common.similarity_functions import SimilarityFunctions
+from model_compression_toolkit.logger import Logger
+
+class SimilarityCalculator:
+    """
+    A class to calculate the similarity between two models (that are often referred as float
+    and quantized models). It utilizes various utility classes for dataset preparation, model folding,
+    similarity computation, and model analysis.
+    """
+
+    def __init__(self,
+                 dataset_utils: DatasetUtils,
+                 model_folding: ModelFoldingUtils,
+                 similarity_functions: SimilarityFunctions,
+                 model_analyzer_utils: ModelAnalyzer,
+                 device: str = None):
+        """
+        Initialize the SimilarityCalculator with required utilities.
+
+        Args:
+            dataset_utils (DatasetUtils): Utility class for dataset preparation.
+            model_folding (ModelFoldingUtils): Utility class for model folding operations.
+            similarity_functions (SimilarityFunctions): Class containing similarity functions.
+            model_analyzer_utils (ModelAnalyzer): Utility class for model analysis.
+            device (str, optional): Device to perform computations on (e.g., 'cpu', 'cuda'). Defaults to None.
+        """
+        self.dataset_utils = dataset_utils
+        self.model_folding = model_folding
+        self.similarity_functions = similarity_functions
+        self.model_analyzer_utils = model_analyzer_utils
+        self.device = device
+
+    @staticmethod
+    def compute_tensors_similarity(tensors_to_compare: Tuple[Any, Any],
+                                   similarity_metrics: Dict[str, Callable]) -> Dict[str, float]:
+        """
+        Compute the similarity between two tensors using provided similarity metrics.
+
+        Args:
+            tensors_to_compare (Tuple[Any, Any]): Tensors to compare by computing their similarity.
+            similarity_metrics (Dict[str, Callable]): A dictionary with similarity metric names and functions.
+
+        Returns:
+            Dict[str, float]: A dictionary of similarity metric names and their computed values.
+        """
+        x, y = tensors_to_compare
+        similarity_metrics = {k: v(x, y) for k, v in similarity_metrics.items()}
+        return similarity_metrics
+
+    def _get_float_to_quantized_compare_points(self,
+                                               quantized_model: Any,
+                                               float_model: Any) -> Dict[str, str]:
+        """
+        Map corresponding layers between the float and quantized models for comparison.
+
+        Args:
+            quantized_model (Any): The quantized model.
+            float_model (Any): The float model.
+
+        Returns:
+            Dict[str, str]: A dictionary mapping float model layer names to quantized model layer names.
+        """
+        # Identify the points in the quantized model to compare.
+        quant_points_names = self.model_analyzer_utils.identify_quantized_compare_points(quantized_model)
+
+        float_name2quant_name = {}
+
+        # Extract the names of the layers in the float model.
+        float_layers_names = self.model_analyzer_utils.extract_float_layer_names(float_model)
+
+        # Map each quantized layer to the corresponding float layer.
+        for quant_point in quant_points_names:
+            candidate_float_layer_name = self.model_analyzer_utils.find_corresponding_float_layer(
+                quant_compare_point=quant_point, quantized_model=quantized_model)
+
+            if candidate_float_layer_name in float_layers_names:
+                if candidate_float_layer_name not in float_name2quant_name:
+                    float_name2quant_name[candidate_float_layer_name] = quant_point
+                else:
+                    Logger.critical(f"Duplicate mapping found for layer: {candidate_float_layer_name}.")
+            else:
+                Logger.warning(
+                    f"Could not find a matching layer in the float model for layer with name {quant_point}, "
+                    f"skipping it in similarity metrics comparison points computation.")
+
+        return float_name2quant_name
+
+    def compute_similarity_metrics(self,
+                                   float_model: Any,
+                                   quantized_model: Any,
+                                   dataset: Callable,
+                                   custom_similarity_metrics: Dict[str, Callable] = None,
+                                   is_validation: bool = False) -> Tuple[Dict[str, float], Dict[str, Dict[str, float]]]:
+        """
+        Compute the similarity metrics between the two models (usually, float and quantized models).
+
+        Args:
+            float_model (Any): The float model.
+            quantized_model (Any): The quantized model.
+            dataset (Callable): A callable to provide the dataset.
+            custom_similarity_metrics (Dict[str, Callable], optional): Custom similarity metrics. Defaults to None.
+            is_validation (bool, optional): Flag to indicate if the dataset is for validation. Defaults to False.
+
+        Returns:
+            Tuple[Dict[str, float], Dict[str, Dict[str, float]]]: Aggregated output similarity metrics and
+            intermediate similarity metrics for each layer.
+        """
+        # Prepare the dataset such that the rest of operations are indistinguishable between the representative
+        # dataset and the validation dataset.
+        dataset = partial(self.dataset_utils.prepare_dataset,
+                          dataset=dataset,
+                          is_validation=is_validation,
+                          device=self.device)
+
+        # Create a folded version of the float model.
+        float_model = self.model_folding.create_float_folded_model(float_model=float_model,
+                                                                   representative_dataset=dataset)
+
+        # Gather similarity metrics to compute (default and custom).
+        similarity_metrics_to_compute = self.similarity_functions.get_default_similarity_metrics()
+        if custom_similarity_metrics:
+            if not isinstance(custom_similarity_metrics, dict):
+                Logger.critical(
+                    f"custom_similarity_metrics should be a dictionary but is of type "
+                    f"{type(custom_similarity_metrics)}.")
+            similarity_metrics_to_compute.update(custom_similarity_metrics)
+
+        # Map float model layers to quantized model layers for comparison.
+        float_name2quant_name = self._get_float_to_quantized_compare_points(float_model=float_model,
+                                                                            quantized_model=quantized_model)
+
+        # Initialize dictionaries to store similarity metrics.
+        output_similarity_metrics = {key: [] for key in similarity_metrics_to_compute.keys()}
+        intermediate_similarity_metrics = {layer: {key: [] for key in similarity_metrics_to_compute.keys()} for layer in
+                                           float_name2quant_name.values()}
+
+        # Iterate over the dataset and compute similarity metrics.
+        for x in dataset():
+            # Extract activations and predictions from both models.
+            float_activations, quant_activations = (
+                self.model_analyzer_utils.extract_model_activations(
+                    float_model, quantized_model, float_name2quant_name, x))
+
+            float_predictions = float_activations[MODEL_OUTPUT_KEY]
+            quant_predictions = quant_activations[MODEL_OUTPUT_KEY]
+
+            # Compute similarity metrics for the output predictions.
+            output_results = self.compute_tensors_similarity((float_predictions, quant_predictions),
+                                                             similarity_metrics_to_compute)
+            for key in output_similarity_metrics:
+                output_similarity_metrics[key].append(output_results[key])
+
+            # Compute similarity metrics for each intermediate layer.
+            for float_layer, quant_layer in float_name2quant_name.items():
+                intermediate_results = self.compute_tensors_similarity(
+                    (float_activations[float_layer], quant_activations[quant_layer]),
+                    similarity_metrics_to_compute)
+                for key in intermediate_similarity_metrics[quant_layer]:
+                    intermediate_similarity_metrics[quant_layer][key].append(intermediate_results[key])
+
+        # Aggregate the output similarity metrics.
+        aggregated_output_similarity_metrics = {key: sum(value) / len(value) for key, value in
+                                                output_similarity_metrics.items()}
+
+        # Aggregate the intermediate similarity metrics for each layer.
+        for layer_name, layer_similarity_metrics in intermediate_similarity_metrics.items():
+            for similarity_name, similarity_values_list in layer_similarity_metrics.items():
+                if len(similarity_values_list) == 0:
+                    Logger.critical(f"Can not average similarities of an empty list.")
+                intermediate_similarity_metrics[layer_name][similarity_name] = sum(similarity_values_list) / len(similarity_values_list)
+
+        return aggregated_output_similarity_metrics, intermediate_similarity_metrics
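For reference, compute_tensors_similarity simply applies every metric callable to the tensor pair. A small sketch with NumPy stand-ins for the framework-specific metrics (not part of the diff; it assumes the mct-nightly package is installed):

import numpy as np
from model_compression_toolkit.xquant.common.similarity_calculator import SimilarityCalculator

float_out = np.array([0.10, 0.25, 0.70])
quant_out = np.array([0.12, 0.24, 0.68])

# Metric names and callables are arbitrary here; the real defaults come from SimilarityFunctions.
metrics = {
    'mse': lambda x, y: float(np.mean((x - y) ** 2)),
    'cs': lambda x, y: float(np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))),
}

# compute_tensors_similarity is a staticmethod, so no SimilarityCalculator instance is needed.
print(SimilarityCalculator.compute_tensors_similarity((float_out, quant_out), metrics))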
model_compression_toolkit/xquant/common/similarity_functions.py
@@ -0,0 +1,81 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Any, Dict, Callable
+
+from model_compression_toolkit.xquant.common.constants import CS_SIMILARITY_METRIC_NAME, SQNR_SIMILARITY_METRIC_NAME, MSE_SIMILARITY_METRIC_NAME
+
+DEFAULT_SIMILARITY_METRICS_NAMES = [CS_SIMILARITY_METRIC_NAME, MSE_SIMILARITY_METRIC_NAME, SQNR_SIMILARITY_METRIC_NAME]
+
+class SimilarityFunctions:
+    """
+    A class that provides various static methods to compute similarity metrics between tensors.
+    """
+
+    @staticmethod
+    def compute_mse(x: Any, y: Any) -> float:
+        """
+        Compute the Mean Squared Error (MSE) between two tensors (usually, the float and quantized predictions).
+
+        Args:
+            x (Any): First tensor to compare.
+            y (Any): Second tensor to compare.
+
+        Returns:
+            float: The computed MSE value.
+        """
+        raise NotImplemented  # pragma: no cover
+
+    @staticmethod
+    def compute_cs(x: Any, y: Any) -> float:
+        """
+        Compute the Cosine Similarity (CS) between two tensors (usually, the float and quantized predictions).
+
+        Args:
+            x (Any): First tensor to compare.
+            y (Any): Second tensor to compare.
+
+        Returns:
+            float: The computed CS value.
+        """
+        raise NotImplemented  # pragma: no cover
+
+    @staticmethod
+    def compute_sqnr(x: Any, y: Any) -> float:
+        """
+        Compute the Signal-to-Quantization-Noise Ratio (SQNR) between two tensors (usually, the float and quantized predictions).
+
+        Args:
+            x (Any): First tensor to compare.
+            y (Any): Second tensor to compare.
+
+        Returns:
+            float: The computed SQNR value.
+        """
+        raise NotImplemented  # pragma: no cover
+
+    def get_default_similarity_metrics(self) -> Dict[str, Callable]:
+        """
+        Get the default similarity metrics to compute.
+
+        Returns:
+            Dict[str, Callable]: A dictionary where the keys are similarity metric names and the values are the corresponding functions.
+        """
+        return {
+            MSE_SIMILARITY_METRIC_NAME: self.compute_mse,
+            CS_SIMILARITY_METRIC_NAME: self.compute_cs,
+            SQNR_SIMILARITY_METRIC_NAME: self.compute_sqnr
+        }
+
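The three compute_* methods are left unimplemented here and are overridden per framework (see keras/similarity_functions.py and pytorch/similarity_functions.py in the file list above). A hypothetical NumPy-based subclass, just to show the expected contract (not part of the package):

import numpy as np
from model_compression_toolkit.xquant.common.similarity_functions import SimilarityFunctions

class NumpySimilarityFunctions(SimilarityFunctions):

    @staticmethod
    def compute_mse(x, y) -> float:
        return float(np.mean((np.asarray(x) - np.asarray(y)) ** 2))

    @staticmethod
    def compute_cs(x, y) -> float:
        x, y = np.ravel(x), np.ravel(y)
        return float(np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y)))

    @staticmethod
    def compute_sqnr(x, y) -> float:
        # One common SQNR definition: signal power divided by quantization-noise power.
        signal = np.mean(np.asarray(x) ** 2)
        noise = np.mean((np.asarray(x) - np.asarray(y)) ** 2)
        return float(signal / noise)

# get_default_similarity_metrics is inherited, so the default MSE/CS/SQNR names map to the
# implementations above.
default_metrics = NumpySimilarityFunctions().get_default_similarity_metrics()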
model_compression_toolkit/xquant/common/tensorboard_utils.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+
+
+from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
+from model_compression_toolkit.xquant.common.constants import TENSORBOARD_DEFAULT_TAG
+from model_compression_toolkit.logger import Logger
+
+
+from typing import Any, Dict, Callable
+
+
+class TensorboardUtils:
+    """
+    Utility class for handling Tensorboard operations like adding graph to display
+    and histograms on the float model.
+    """
+
+    def __init__(self,
+                 report_dir: str,
+                 fw_info: FrameworkInfo,
+                 fw_impl: FrameworkImplementation):
+        """
+        Initialize the TensorboardUtils.
+
+        Args:
+            report_dir (str): Directory where Tensorboard logs will be stored.
+            fw_info (FrameworkInfo): Framework-specific information.
+            fw_impl (FrameworkImplementation): Framework-specific implementation.
+        """
+        self.fw_impl = fw_impl
+        self.fw_info = fw_info
+        self.tb_writer = TensorboardWriter(report_dir, fw_info)
+        Logger.info(f"Please run: tensorboard --logdir {self.tb_writer.dir_path}")
+
+    def get_graph_for_tensorboard_display(self,
+                                          quantized_model: Any,
+                                          similarity_metrics: Dict[str, Any],
+                                          repr_dataset: Callable) -> Graph:
+        """
+        Get the graph for Tensorboard display. The framework-specific implementations
+        (like KerasTensorboardUtils and PytorchTensorboardUtils) should implement this
+        as it differs between them when combining the similarity metrics into the graph.
+
+        Args:
+            quantized_model (Any): The quantized model.
+            similarity_metrics (Dict[str, Any]): Metrics for model similarity.
+            repr_dataset (Callable): Representative dataset function.
+
+        Returns:
+            Graph: The generated graph for Tensorboard display.
+        """
+        Logger.critical("This method should be implemented by the framework-specific TensorboardUtils.")  # pragma: no cover
+
+    def add_histograms_to_tensorboard(self,
+                                      graph: Graph):
+        """
+        Add histograms to Tensorboard from a graph that holds these statistics.
+
+        Args:
+            graph (Graph): Graph with histograms to add to the tensorboard.
+        """
+        self.tb_writer.add_histograms(graph, TENSORBOARD_DEFAULT_TAG)
+
+    def add_graph_to_tensorboard(self,
+                                 quantized_model: Any,
+                                 similarity_metrics: Dict[str, Any],
+                                 repr_dataset: Callable):
+        """
+        Add a graph to Tensorboard. The graph represents the quantized graph
+        with the similarity metrics that were measured in different nodes.
+
+        Args:
+            quantized_model (Any): The quantized model.
+            similarity_metrics (Dict[str, Any]): The similarity metrics that were collected.
+            repr_dataset (Callable): Representative dataset to use (if needed, like in pytorch case).
+        """
+        # Generate the quantized graph with similarity metrics.
+        tb_graph = self.get_graph_for_tensorboard_display(quantized_model=quantized_model,
+                                                          similarity_metrics=similarity_metrics,
+                                                          repr_dataset=repr_dataset)
+
+        self.tb_writer.add_graph(tb_graph, TENSORBOARD_DEFAULT_TAG)
+
+
model_compression_toolkit/xquant/common/xquant_config.py
@@ -0,0 +1,39 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Dict, Callable
+
+
+class XQuantConfig:
+    """
+    Configuration for generating the report.
+    It allows to set the log dir that the report will be saved in and to add similarity metrics
+    to measure between tensors of the two models.
+    """
+
+    def __init__(self,
+                 report_dir: str,
+                 custom_similarity_metrics: Dict[str, Callable] = None):
+        """
+        Initializes the configuration for explainable quantization.
+
+        Args:
+            report_dir (str): Directory where the reports will be saved.
+            custom_similarity_metrics (Dict[str, Callable]): Custom similarity metrics to be computed between tensors
+                of the two models. The dictionary keys are similarity metric names and the values are callables that implement the
+                similarity metric computation.
+        """
+        self.report_dir = report_dir
+        self.custom_similarity_metrics = custom_similarity_metrics
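A usage sketch for XQuantConfig (not part of the diff; it assumes the mct-nightly package is installed): the report directory is mandatory, and a custom metric can be added on top of the default MSE/CS/SQNR metrics.

import numpy as np
from model_compression_toolkit.xquant import XQuantConfig

def mae(x, y) -> float:
    # Hypothetical custom metric: mean absolute error between float and quantized tensors.
    return float(np.mean(np.abs(np.asarray(x) - np.asarray(y))))

xquant_config = XQuantConfig(report_dir='./xquant_report',
                             custom_similarity_metrics={'mae': mae})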
model_compression_toolkit/xquant/keras/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
model_compression_toolkit/xquant/keras/dataset_utils.py
@@ -0,0 +1,57 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Any
+
+from model_compression_toolkit.xquant.common.dataset_utils import DatasetUtils
+
+
+class KerasDatasetUtils(DatasetUtils):
+    """
+    Class with helpful methods for handling different kinds of Keras datasets from the user.
+    """
+
+    @staticmethod
+    def prepare_dataset(dataset: Any, is_validation: bool, device: str = None):
+        """
+        Prepare the dataset so calling it will return only inputs for the model (like in the case
+        of the representative dataset). For example, when the validation dataset is used, the labels
+        should be removed.
+
+        Args:
+            dataset: Dataset to prepare.
+            is_validation: Whether it's validation dataset or not.
+            device: Device to transfer the data to.
+
+        Returns:
+            Generator to use for retrieving the dataset inputs.
+        """
+        def process_data(x: Any, is_validation: bool):
+            """
+            Processes individual data samples to transfer them to the device and convert to torch tensors if needed.
+
+            Args:
+                data: The data sample to process.
+                is_validation: A flag indicating if this is a validation dataset.
+                device: The device to transfer the data to.
+
+            Returns:
+                The data as torch tensors on the desired device.
+            """
+            return x[0] if is_validation else x  # Assume data[0] contains the inputs and data[1] the labels
+
+        for x in dataset():
+            yield process_data(x, is_validation)
+
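A sketch of the two dataset flavours prepare_dataset accepts (not part of the diff; shapes and batch counts are illustrative): a representative dataset that yields model inputs only, and a validation dataset that yields (inputs, labels) pairs from which only the inputs are kept.

import numpy as np
from model_compression_toolkit.xquant.keras.dataset_utils import KerasDatasetUtils

def repr_dataset():
    for _ in range(2):
        yield [np.random.rand(4, 32, 32, 3).astype(np.float32)]

def validation_dataset():
    for _ in range(2):
        images = np.random.rand(4, 32, 32, 3).astype(np.float32)
        labels = np.random.randint(0, 10, size=(4,))
        yield images, labels

# Both calls return generators that yield model inputs only, so downstream code can treat
# the representative and validation datasets identically.
repr_inputs = KerasDatasetUtils.prepare_dataset(dataset=repr_dataset, is_validation=False)
val_inputs = KerasDatasetUtils.prepare_dataset(dataset=validation_dataset, is_validation=True)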
model_compression_toolkit/xquant/keras/facade_xquant_report.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Callable, Dict, Any
+
+from model_compression_toolkit.constants import FOUND_TF
+from model_compression_toolkit.xquant.common.core_report_generator import core_report_generator
+from model_compression_toolkit.xquant import XQuantConfig
+from model_compression_toolkit.logger import Logger
+
+if FOUND_TF:
+    import keras
+    from model_compression_toolkit.xquant.keras.keras_report_utils import KerasReportUtils
+
+    def xquant_report_keras_experimental(float_model: keras.Model,
+                                         quantized_model: keras.Model,
+                                         repr_dataset: Callable,
+                                         validation_dataset: Callable,
+                                         xquant_config: XQuantConfig) -> Dict[str, Any]:
+        """
+        Generate an explainable quantization report for a quantized Keras model.
+
+        Args:
+            float_model (keras.Model): The original floating-point Keras model.
+            quantized_model (keras.Model): The quantized Keras model.
+            repr_dataset (Callable): The representative dataset used during quantization for similarity metrics computation.
+            validation_dataset (Callable): The validation dataset used for evaluation for similarity metrics computation.
+            xquant_config (XQuantConfig): Configuration settings for explainable quantization.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the collected similarity metrics and report data.
+        """
+        # Initialize the logger with the report directory.
+        Logger.set_log_file(log_folder=xquant_config.report_dir)
+
+        # Initialize a utility class for handling Keras-specific reporting tasks.
+        keras_report_utils = KerasReportUtils(xquant_config.report_dir)
+
+        # Create the report after collecting useful data like histograms and similarity metrics.
+        _collected_data = core_report_generator(float_model=float_model,
+                                                quantized_model=quantized_model,
+                                                repr_dataset=repr_dataset,
+                                                validation_dataset=validation_dataset,
+                                                fw_report_utils=keras_report_utils,
+                                                xquant_config=xquant_config)
+
+        return _collected_data
+else:
+    def xquant_report_keras_experimental(*args, **kwargs):
+        Logger.critical("Tensorflow must be installed to use xquant_report_keras_experimental. "
+                        "The 'tensorflow' package is missing.")  # pragma: no cover
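An end-to-end usage sketch for the Keras facade (not part of the diff). It assumes TensorFlow and the mct-nightly package are installed; the toy model, the dataset generators, and the PTQ call (mct.ptq.keras_post_training_quantization) are illustrative assumptions rather than content of this release.

import numpy as np
import keras
import model_compression_toolkit as mct
from model_compression_toolkit.xquant import XQuantConfig
from model_compression_toolkit.xquant.keras.facade_xquant_report import xquant_report_keras_experimental

# Toy float model (illustrative).
float_model = keras.Sequential([keras.layers.Conv2D(8, 3, input_shape=(32, 32, 3)),
                                keras.layers.ReLU()])

def repr_dataset():
    for _ in range(4):
        yield [np.random.rand(1, 32, 32, 3).astype(np.float32)]

def validation_dataset():
    for _ in range(4):
        yield np.random.rand(1, 32, 32, 3).astype(np.float32), np.zeros((1,), dtype=np.int64)

# Quantize the model first (API name assumed from MCT's PTQ facade).
quantized_model, _ = mct.ptq.keras_post_training_quantization(float_model, repr_dataset)

# Generate the explainable-quantization report; similarity metrics and Tensorboard artifacts
# are written under the configured report directory.
result = xquant_report_keras_experimental(float_model=float_model,
                                          quantized_model=quantized_model,
                                          repr_dataset=repr_dataset,
                                          validation_dataset=validation_dataset,
                                          xquant_config=XQuantConfig(report_dir='./xquant_report'))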