mct-nightly 2.1.0.20240616.65727__py3-none-any.whl → 2.1.0.20240618.432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/METADATA +2 -2
- {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/RECORD +43 -17
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/functional_node.py +3 -3
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +23 -13
- model_compression_toolkit/core/pytorch/constants.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +3 -3
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +12 -6
- model_compression_toolkit/gptq/keras/gptq_training.py +1 -1
- model_compression_toolkit/gptq/keras/graph_info.py +1 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +5 -2
- model_compression_toolkit/gptq/pytorch/graph_info.py +2 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +3 -2
- model_compression_toolkit/xquant/__init__.py +19 -0
- model_compression_toolkit/xquant/common/__init__.py +15 -0
- model_compression_toolkit/xquant/common/constants.py +38 -0
- model_compression_toolkit/xquant/common/core_report_generator.py +83 -0
- model_compression_toolkit/xquant/common/dataset_utils.py +43 -0
- model_compression_toolkit/xquant/common/framework_report_utils.py +89 -0
- model_compression_toolkit/xquant/common/model_analyzer.py +99 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +104 -0
- model_compression_toolkit/xquant/common/similarity_calculator.py +194 -0
- model_compression_toolkit/xquant/common/similarity_functions.py +81 -0
- model_compression_toolkit/xquant/common/tensorboard_utils.py +101 -0
- model_compression_toolkit/xquant/common/xquant_config.py +39 -0
- model_compression_toolkit/xquant/keras/__init__.py +15 -0
- model_compression_toolkit/xquant/keras/dataset_utils.py +57 -0
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +63 -0
- model_compression_toolkit/xquant/keras/keras_report_utils.py +60 -0
- model_compression_toolkit/xquant/keras/model_analyzer.py +136 -0
- model_compression_toolkit/xquant/keras/similarity_functions.py +75 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +84 -0
- model_compression_toolkit/xquant/pytorch/__init__.py +15 -0
- model_compression_toolkit/xquant/pytorch/dataset_utils.py +76 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +62 -0
- model_compression_toolkit/xquant/pytorch/model_analyzer.py +132 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +61 -0
- model_compression_toolkit/xquant/pytorch/similarity_functions.py +68 -0
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +87 -0
- {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/WHEEL +0 -0
- {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,101 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from model_compression_toolkit.core.common import Graph
|
17
|
+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
18
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
19
|
+
|
20
|
+
|
21
|
+
from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
|
22
|
+
from model_compression_toolkit.xquant.common.constants import TENSORBOARD_DEFAULT_TAG
|
23
|
+
from model_compression_toolkit.logger import Logger
|
24
|
+
|
25
|
+
|
26
|
+
from typing import Any, Dict, Callable
|
27
|
+
|
28
|
+
|
29
|
+
class TensorboardUtils:
    """
    Base helper for the Tensorboard side of the report: it writes the graph to
    display and the histograms collected on the float model.
    """

    def __init__(self,
                 report_dir: str,
                 fw_info: FrameworkInfo,
                 fw_impl: FrameworkImplementation):
        """
        Keep the framework handles and open the Tensorboard writer.

        Args:
            report_dir (str): Directory handed to the TensorboardWriter.
            fw_info (FrameworkInfo): Framework-specific information.
            fw_impl (FrameworkImplementation): Framework-specific implementation.
        """
        self.fw_info = fw_info
        self.fw_impl = fw_impl
        self.tb_writer = TensorboardWriter(report_dir, fw_info)
        # Point the user at the directory the writer actually uses.
        Logger.info(f"Please run: tensorboard --logdir {self.tb_writer.dir_path}")

    def get_graph_for_tensorboard_display(self,
                                          quantized_model: Any,
                                          similarity_metrics: Dict[str, Any],
                                          repr_dataset: Callable) -> Graph:
        """
        Build the graph that is shown in Tensorboard.

        This base version only reports a critical error through the Logger;
        framework-specific subclasses (like KerasTensorboardUtils and
        PytorchTensorboardUtils) are expected to override it, since the way the
        similarity metrics are merged into the graph differs between frameworks.

        Args:
            quantized_model (Any): The quantized model.
            similarity_metrics (Dict[str, Any]): Metrics for model similarity.
            repr_dataset (Callable): Representative dataset function.

        Returns:
            Graph: The generated graph for Tensorboard display.
        """
        Logger.critical("This method should be implemented by the framework-specific TensorboardUtils.")  # pragma: no cover

    def add_histograms_to_tensorboard(self,
                                      graph: Graph):
        """
        Write the histograms held by a graph to Tensorboard under the default tag.

        Args:
            graph (Graph): Graph with histograms to add to the tensorboard.
        """
        self.tb_writer.add_histograms(graph, TENSORBOARD_DEFAULT_TAG)

    def add_graph_to_tensorboard(self,
                                 quantized_model: Any,
                                 similarity_metrics: Dict[str, Any],
                                 repr_dataset: Callable):
        """
        Write the quantized graph, annotated with the similarity metrics measured
        at its nodes, to Tensorboard under the default tag.

        Args:
            quantized_model (Any): The quantized model.
            similarity_metrics (Dict[str, Any]): The similarity metrics that were collected.
            repr_dataset (Callable): Representative dataset to use (if needed, like in pytorch case).
        """
        # The framework-specific override produces the annotated graph; it is
        # forwarded to the writer as is.
        self.tb_writer.add_graph(
            self.get_graph_for_tensorboard_display(quantized_model=quantized_model,
                                                   similarity_metrics=similarity_metrics,
                                                   repr_dataset=repr_dataset),
            TENSORBOARD_DEFAULT_TAG)
|
100
|
+
|
101
|
+
|
@@ -0,0 +1,39 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from typing import Callable, Dict, Optional
|
17
|
+
|
18
|
+
|
19
|
+
class XQuantConfig:
    """
    Configuration for generating the report.
    It allows to set the log dir that the report will be saved in and to add similarity metrics
    to measure between tensors of the two models.
    """

    def __init__(self,
                 report_dir: str,
                 custom_similarity_metrics: Optional[Dict[str, Callable]] = None):
        """
        Initializes the configuration for explainable quantization.

        Args:
            report_dir (str): Directory where the reports will be saved.
            custom_similarity_metrics (Optional[Dict[str, Callable]]): Custom similarity metrics to be computed
                between tensors of the two models. The dictionary keys are similarity metric names and the values
                are callables that implement the similarity metric computation. Defaults to None, which is
                stored unchanged.
        """
        self.report_dir = report_dir
        # Stored as given (including None); no copy or validation is done here.
        self.custom_similarity_metrics = custom_similarity_metrics
|
@@ -0,0 +1,15 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
@@ -0,0 +1,57 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from typing import Any
|
17
|
+
|
18
|
+
from model_compression_toolkit.xquant.common.dataset_utils import DatasetUtils
|
19
|
+
|
20
|
+
|
21
|
+
class KerasDatasetUtils(DatasetUtils):
    """
    Class with helpful methods for handling different kinds of Keras datasets from the user.
    """

    @staticmethod
    def prepare_dataset(dataset: Any, is_validation: bool, device: str = None):
        """
        Prepare the dataset so calling it will return only inputs for the model (like in the case
        of the representative dataset). For example, when the validation dataset is used, the labels
        should be removed.

        Args:
            dataset: Callable that returns an iterable of samples when invoked.
            is_validation: Whether it's validation dataset or not. When True, each sample is
                assumed to hold the inputs at index 0 (and the labels at index 1), and only
                the inputs are yielded.
            device: Not used by this Keras implementation; the samples are yielded as they are.

        Returns:
            Generator to use for retrieving the dataset inputs.
        """
        for sample in dataset():
            # Validation samples carry labels next to the inputs; keep the inputs only.
            yield sample[0] if is_validation else sample
|
57
|
+
|
@@ -0,0 +1,63 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from typing import Callable, Dict, Any
|
17
|
+
|
18
|
+
from model_compression_toolkit.constants import FOUND_TF
|
19
|
+
from model_compression_toolkit.xquant.common.core_report_generator import core_report_generator
|
20
|
+
from model_compression_toolkit.xquant import XQuantConfig
|
21
|
+
from model_compression_toolkit.logger import Logger
|
22
|
+
|
23
|
+
if FOUND_TF:
    import keras
    from model_compression_toolkit.xquant.keras.keras_report_utils import KerasReportUtils

    def xquant_report_keras_experimental(float_model: keras.Model,
                                         quantized_model: keras.Model,
                                         repr_dataset: Callable,
                                         validation_dataset: Callable,
                                         xquant_config: XQuantConfig) -> Dict[str, Any]:
        """
        Build the explainable-quantization report for a quantized Keras model.

        Args:
            float_model (keras.Model): The original floating-point Keras model.
            quantized_model (keras.Model): The quantized Keras model.
            repr_dataset (Callable): The representative dataset used during quantization for similarity metrics computation.
            validation_dataset (Callable): The validation dataset used for evaluation for similarity metrics computation.
            xquant_config (XQuantConfig): Configuration settings for explainable quantization.

        Returns:
            Dict[str, Any]: A dictionary containing the collected similarity metrics and report data.
        """
        report_dir = xquant_config.report_dir

        # Route the log file to the report directory before anything else is created.
        Logger.set_log_file(log_folder=report_dir)

        # Bundle of the Keras-specific helpers the generic generator relies on.
        fw_report_utils = KerasReportUtils(report_dir)

        # The framework-agnostic generator does the actual collection and
        # returns the gathered data.
        return core_report_generator(float_model=float_model,
                                     quantized_model=quantized_model,
                                     repr_dataset=repr_dataset,
                                     validation_dataset=validation_dataset,
                                     fw_report_utils=fw_report_utils,
                                     xquant_config=xquant_config)
else:
    def xquant_report_keras_experimental(*args, **kwargs):
        """
        Placeholder used when TensorFlow was not found: reports a critical error
        through the Logger instead of generating a report.
        """
        Logger.critical("Tensorflow must be installed to use xquant_report_keras_experimental. "
                        "The 'tensorflow' package is missing.")  # pragma: no cover
|
@@ -0,0 +1,60 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
18
|
+
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
|
19
|
+
from model_compression_toolkit.xquant.common.framework_report_utils import FrameworkReportUtils
|
20
|
+
from model_compression_toolkit.ptq.keras.quantization_facade import DEFAULT_KERAS_TPC
|
21
|
+
from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
|
22
|
+
from model_compression_toolkit.xquant.common.similarity_calculator import SimilarityCalculator
|
23
|
+
from model_compression_toolkit.xquant.keras.dataset_utils import KerasDatasetUtils
|
24
|
+
from model_compression_toolkit.xquant.keras.model_analyzer import KerasModelAnalyzer
|
25
|
+
|
26
|
+
from model_compression_toolkit.xquant.keras.similarity_functions import KerasSimilarityFunctions
|
27
|
+
from model_compression_toolkit.xquant.keras.tensorboard_utils import KerasTensorboardUtils
|
28
|
+
|
29
|
+
|
30
|
+
class KerasReportUtils(FrameworkReportUtils):
    """
    Bundle of the utility components needed to generate the report for a Keras model.
    """

    def __init__(self, report_dir: str):
        """
        Create the Keras flavour of every component and pass them to the base class.

        Args:
            report_dir: Logging dir path.
        """
        keras_fw_info = DEFAULT_KERAS_INFO
        keras_fw_impl = KerasImplementation()

        keras_dataset_utils = KerasDatasetUtils()
        keras_model_folding = ModelFoldingUtils(fw_info=keras_fw_info,
                                                fw_impl=keras_fw_impl,
                                                fw_default_tpc=DEFAULT_KERAS_TPC)

        # The calculator shares the dataset and folding helpers created above.
        keras_similarity_functions = KerasSimilarityFunctions()
        keras_model_analyzer = KerasModelAnalyzer()
        keras_similarity_calculator = SimilarityCalculator(dataset_utils=keras_dataset_utils,
                                                           model_folding=keras_model_folding,
                                                           similarity_functions=keras_similarity_functions,
                                                           model_analyzer_utils=keras_model_analyzer)

        keras_tb_utils = KerasTensorboardUtils(report_dir=report_dir,
                                               fw_impl=keras_fw_impl,
                                               fw_info=keras_fw_info)

        super().__init__(keras_fw_info,
                         keras_fw_impl,
                         keras_similarity_calculator,
                         keras_dataset_utils,
                         keras_model_folding,
                         keras_tb_utils)
|
@@ -0,0 +1,136 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
from typing import List, Tuple, Dict
|
16
|
+
|
17
|
+
from mct_quantizers import KerasQuantizationWrapper
|
18
|
+
from model_compression_toolkit.xquant.common.constants import MODEL_OUTPUT_KEY
|
19
|
+
from model_compression_toolkit.xquant.common.model_analyzer import ModelAnalyzer
|
20
|
+
import keras
|
21
|
+
import numpy as np
|
22
|
+
|
23
|
+
|
24
|
+
class KerasModelAnalyzer(ModelAnalyzer):
    """
    This class provides utilities for analyzing Keras models, specifically for
    extracting activations and comparing float and quantized models.
    """

    @staticmethod
    def _compute_activations(model: keras.Model,
                             layer_names: List[str],
                             data: List[np.ndarray]) -> Dict[str, np.ndarray]:
        """
        Computes the activations for the specified layers of the model, and the model's output.

        Args:
            model (keras.Model): The model from which to extract activations.
            layer_names (List[str]): Names of the layers for which to compute activations.
            data (List[np.ndarray]): Input data for the model.

        Returns:
            Dict[str, np.ndarray]: Dictionary mapping layer names to their corresponding
            activations. The model's output is stored using the key MODEL_OUTPUT_KEY; when the
            model output is a list or tuple of tensors, a list of arrays is stored instead.
        """
        # Outputs of the requested layers, followed by the model output as the last entry.
        requested_outputs = [model.get_layer(name).output for name in layer_names]
        probe_model = keras.Model(inputs=model.input, outputs=requested_outputs + [model.output])
        predictions = probe_model(data)

        # The last prediction is the model output, so zip stops before reaching it.
        activations = {name: tensor.numpy() for name, tensor in zip(layer_names, predictions)}

        model_output = predictions[-1]
        if isinstance(model_output, (list, tuple)):
            # A multi-output model yields a sequence here, which has no .numpy();
            # convert element-wise so the caller can concatenate the arrays.
            activations[MODEL_OUTPUT_KEY] = [tensor.numpy() for tensor in model_output]
        else:
            activations[MODEL_OUTPUT_KEY] = model_output.numpy()
        return activations

    def extract_model_activations(self,
                                  float_model: keras.Model,
                                  quantized_model: keras.Model,
                                  float_name2quant_name: Dict[str, str],
                                  data: List[np.ndarray]) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
        """
        Extracts activations from both the float and quantized models.

        Args:
            float_model (keras.Model): The float model.
            quantized_model (keras.Model): The quantized model.
            float_name2quant_name (Dict[str, str]): A mapping from float model layer names to quantized model layer
                names.
            data (List[np.ndarray]): Input data for which to compute activations.

        Returns:
            Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
                - Dictionary of activations for the float model.
                - Dictionary of activations for the quantized model.
        """
        # Compute activations for the quantized model
        quant_activations = self._compute_activations(quantized_model,
                                                      list(float_name2quant_name.values()),
                                                      data)
        # Compute activations for the float model
        float_activations = self._compute_activations(float_model,
                                                      list(float_name2quant_name.keys()),
                                                      data)

        # Concatenate predictions if they are lists (multi-output models).
        for activations in (quant_activations, float_activations):
            if isinstance(activations[MODEL_OUTPUT_KEY], list):
                activations[MODEL_OUTPUT_KEY] = np.concatenate(activations[MODEL_OUTPUT_KEY])

        return float_activations, quant_activations

    def identify_quantized_compare_points(self, quantized_model: keras.Model) -> List[str]:
        """
        Identifies the layers in the quantized model that are wrapped with the quantization wrapper.
        These layers will serve as comparison points.

        Notes:
            This currently means that the quantized compare points are the linear layers that are wrapped,
            but this may be changed in the future.

        Args:
            quantized_model (keras.Model): The quantized model from which to identify comparison points.

        Returns:
            List[str]: Names of the layers wrapped with the quantization wrapper.
        """
        return [layer.name for layer in quantized_model.layers if isinstance(layer, KerasQuantizationWrapper)]

    def find_corresponding_float_layer(self,
                                       quant_compare_point: str,
                                       quantized_model: keras.Model) -> str:
        """
        Finds the corresponding float model layer for a given quantized model layer.

        Args:
            quant_compare_point (str): The name of the quantized model layer.
            quantized_model (keras.Model): The quantized model.

        Returns:
            str: The name of the corresponding layer in the float model, taken from the
            `layer` attribute of the quantized layer.
        """
        return quantized_model.get_layer(quant_compare_point).layer.name

    def extract_float_layer_names(self, float_model: keras.Model) -> List[str]:
        """
        Extracts the names of all layers in the float model.

        Args:
            float_model (keras.Model): The float model from which to extract layer names.

        Returns:
            List[str]: Names of all layers in the float model.
        """
        return [layer.name for layer in float_model.layers]
|
@@ -0,0 +1,75 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
import tensorflow as tf
|
18
|
+
|
19
|
+
from model_compression_toolkit.xquant.common.similarity_functions import SimilarityFunctions
|
20
|
+
|
21
|
+
class KerasSimilarityFunctions(SimilarityFunctions):
    """
    A class that extends SimilarityFunctions to implement similarity metrics using Keras.
    Even though the names referred to are quantized and float, it can help compare between
    tensors of any two models.
    """

    @staticmethod
    def compute_mse(x: np.ndarray, y: np.ndarray) -> float:
        """
        Compute the Mean Squared Error (MSE) between two tensors (usually, the float and quantized tensors).

        Args:
            x (np.ndarray): First tensor to compare.
            y (np.ndarray): Second tensor to compare.

        Returns:
            float: The computed MSE value.
        """
        mse = tf.keras.losses.MeanSquaredError()(x, y)
        return float(mse.numpy())

    @staticmethod
    def compute_cs(x: np.ndarray, y: np.ndarray) -> float:
        """
        Compute the Cosine Similarity (CS) between two tensors (usually, the float and quantized tensors).

        Both tensors are flattened before the comparison.

        Args:
            x (np.ndarray): First tensor to compare.
            y (np.ndarray): Second tensor to compare.

        Returns:
            float: The computed CS value (closer to 1 means more similar).
        """
        # tf.keras.losses.CosineSimilarity is a loss: it returns the NEGATIVE cosine
        # similarity so that minimizing it maximizes similarity. Flip the sign back to
        # report the similarity itself.
        cs_loss = tf.keras.losses.CosineSimilarity()(x.flatten(), y.flatten())
        return -float(cs_loss.numpy())

    @staticmethod
    def compute_sqnr(x: np.ndarray, y: np.ndarray) -> float:
        """
        Compute the Signal-to-Quantization-Noise Ratio (SQNR) between two tensors (usually, the float and quantized tensors).

        The value is the plain ratio mean(x^2) / mean((x - y)^2); it is not converted to decibels.

        Args:
            x (np.ndarray): First tensor to compare (treated as the signal).
            y (np.ndarray): Second tensor to compare.

        Returns:
            float: The computed SQNR value.
        """
        signal_power = tf.reduce_mean(tf.square(x))
        # NOTE(review): identical tensors give zero noise power, so the ratio below is not
        # finite in that case - confirm callers tolerate it.
        noise_power = tf.reduce_mean(tf.square(x - y))
        sqnr = signal_power / noise_power
        return float(sqnr.numpy())
|
74
|
+
|
75
|
+
|
@@ -0,0 +1,84 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
from typing import Dict, Callable
|
16
|
+
|
17
|
+
import keras
|
18
|
+
|
19
|
+
from model_compression_toolkit.core.common import Graph
|
20
|
+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
21
|
+
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
22
|
+
|
23
|
+
from model_compression_toolkit.core.keras.reader.reader import model_reader
|
24
|
+
|
25
|
+
from model_compression_toolkit.xquant.common.constants import XQUANT_REPR, INTERMEDIATE_SIMILARITY_METRICS_REPR, XQUANT_VAL, INTERMEDIATE_SIMILARITY_METRICS_VAL
|
26
|
+
from model_compression_toolkit.xquant.common.tensorboard_utils import TensorboardUtils
|
27
|
+
|
28
|
+
|
29
|
+
class KerasTensorboardUtils(TensorboardUtils):
    """
    Keras flavour of TensorboardUtils. It supplies the Keras-specific way of turning a
    quantized model and its similarity metrics into a graph that can be visualized
    in TensorBoard.
    """

    def __init__(self, report_dir: str,
                 fw_info: FrameworkInfo,
                 fw_impl: FrameworkImplementation):
        """
        Forward the arguments to the generic TensorboardUtils initializer.

        Args:
            report_dir (str): Directory where the TensorBoard files will be stored.
            fw_info (FrameworkInfo): Information about the framework being used.
            fw_impl (FrameworkImplementation): Implementation functions for the framework.
        """
        super().__init__(report_dir, fw_info, fw_impl)

    def get_graph_for_tensorboard_display(self,
                                          quantized_model: keras.Model,
                                          similarity_metrics: Dict[str, Dict[str, float]],
                                          repr_dataset: Callable) -> Graph:
        """
        Read the quantized model into a graph and attach the per-node similarity metrics
        to the nodes' framework attributes.

        Args:
            quantized_model (keras.Model): The quantized Keras model for which the graph is to be created.
            similarity_metrics (Dict[str, Dict[str, float]]): A dictionary containing similarity metrics
                for different nodes in the model.
            repr_dataset (Callable): A function or callable that provides the representative dataset.
                It is not used by this Keras implementation.

        Returns:
            Graph: A graph object representing the quantized model, annotated with similarity metrics.
        """
        annotated_graph = model_reader(quantized_model)

        # (attribute written on the node, key of the per-node metrics it comes from),
        # in the order they are applied: representative dataset first, then validation.
        attr_to_metrics_key = ((XQUANT_REPR, INTERMEDIATE_SIMILARITY_METRICS_REPR),
                               (XQUANT_VAL, INTERMEDIATE_SIMILARITY_METRICS_VAL))

        for graph_node in annotated_graph.nodes:
            for attr_name, metrics_key in attr_to_metrics_key:
                per_node_metrics = similarity_metrics[metrics_key]
                # Only nodes that were measured get the attribute.
                if graph_node.name in per_node_metrics:
                    graph_node.framework_attr[attr_name] = per_node_metrics[graph_node.name]

        return annotated_graph
|
@@ -0,0 +1,15 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
#
|