mct-nightly 2.1.0.20240616.65727__py3-none-any.whl → 2.1.0.20240618.432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/METADATA +2 -2
  2. {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/RECORD +43 -17
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/graph/functional_node.py +3 -3
  5. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +23 -13
  6. model_compression_toolkit/core/pytorch/constants.py +1 -1
  7. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +3 -3
  8. model_compression_toolkit/core/pytorch/reader/graph_builders.py +12 -6
  9. model_compression_toolkit/gptq/keras/gptq_training.py +1 -1
  10. model_compression_toolkit/gptq/keras/graph_info.py +1 -1
  11. model_compression_toolkit/gptq/pytorch/gptq_training.py +5 -2
  12. model_compression_toolkit/gptq/pytorch/graph_info.py +2 -1
  13. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +3 -2
  14. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +3 -2
  15. model_compression_toolkit/xquant/__init__.py +19 -0
  16. model_compression_toolkit/xquant/common/__init__.py +15 -0
  17. model_compression_toolkit/xquant/common/constants.py +38 -0
  18. model_compression_toolkit/xquant/common/core_report_generator.py +83 -0
  19. model_compression_toolkit/xquant/common/dataset_utils.py +43 -0
  20. model_compression_toolkit/xquant/common/framework_report_utils.py +89 -0
  21. model_compression_toolkit/xquant/common/model_analyzer.py +99 -0
  22. model_compression_toolkit/xquant/common/model_folding_utils.py +104 -0
  23. model_compression_toolkit/xquant/common/similarity_calculator.py +194 -0
  24. model_compression_toolkit/xquant/common/similarity_functions.py +81 -0
  25. model_compression_toolkit/xquant/common/tensorboard_utils.py +101 -0
  26. model_compression_toolkit/xquant/common/xquant_config.py +39 -0
  27. model_compression_toolkit/xquant/keras/__init__.py +15 -0
  28. model_compression_toolkit/xquant/keras/dataset_utils.py +57 -0
  29. model_compression_toolkit/xquant/keras/facade_xquant_report.py +63 -0
  30. model_compression_toolkit/xquant/keras/keras_report_utils.py +60 -0
  31. model_compression_toolkit/xquant/keras/model_analyzer.py +136 -0
  32. model_compression_toolkit/xquant/keras/similarity_functions.py +75 -0
  33. model_compression_toolkit/xquant/keras/tensorboard_utils.py +84 -0
  34. model_compression_toolkit/xquant/pytorch/__init__.py +15 -0
  35. model_compression_toolkit/xquant/pytorch/dataset_utils.py +76 -0
  36. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +62 -0
  37. model_compression_toolkit/xquant/pytorch/model_analyzer.py +132 -0
  38. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +61 -0
  39. model_compression_toolkit/xquant/pytorch/similarity_functions.py +68 -0
  40. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +87 -0
  41. {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/LICENSE.md +0 -0
  42. {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/WHEEL +0 -0
  43. {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,101 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from model_compression_toolkit.core.common import Graph
17
+ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
18
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
19
+
20
+
21
+ from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
22
+ from model_compression_toolkit.xquant.common.constants import TENSORBOARD_DEFAULT_TAG
23
+ from model_compression_toolkit.logger import Logger
24
+
25
+
26
+ from typing import Any, Dict, Callable
27
+
28
+
29
class TensorboardUtils:
    """
    Base helper for writing xquant artifacts to Tensorboard: the quantized graph
    annotated with similarity metrics, and histograms collected on the float model.
    Framework-specific subclasses supply the graph-building step.
    """

    def __init__(self,
                 report_dir: str,
                 fw_info: FrameworkInfo,
                 fw_impl: FrameworkImplementation):
        """
        Create the Tensorboard writer and log the command for viewing its output.

        Args:
            report_dir (str): Directory in which the Tensorboard logs are written.
            fw_info (FrameworkInfo): Framework-specific information.
            fw_impl (FrameworkImplementation): Framework-specific implementation.
        """
        self.fw_impl = fw_impl
        self.fw_info = fw_info
        writer = TensorboardWriter(report_dir, fw_info)
        self.tb_writer = writer
        Logger.info(f"Please run: tensorboard --logdir {writer.dir_path}")

    def get_graph_for_tensorboard_display(self,
                                          quantized_model: Any,
                                          similarity_metrics: Dict[str, Any],
                                          repr_dataset: Callable) -> Graph:
        """
        Build the graph shown in Tensorboard. Subclasses (e.g. the Keras and PyTorch
        variants) override this, since the way similarity metrics are attached to the
        graph is framework dependent.

        Args:
            quantized_model (Any): The quantized model.
            similarity_metrics (Dict[str, Any]): Metrics for model similarity.
            repr_dataset (Callable): Representative dataset function.

        Returns:
            Graph: The graph to display in Tensorboard.
        """
        Logger.critical("This method should be implemented by the framework-specific TensorboardUtils.")  # pragma: no cover

    def add_histograms_to_tensorboard(self,
                                      graph: Graph):
        """
        Write the histograms stored on a graph to Tensorboard.

        Args:
            graph (Graph): Graph holding the histogram statistics.
        """
        self.tb_writer.add_histograms(graph, TENSORBOARD_DEFAULT_TAG)

    def add_graph_to_tensorboard(self,
                                 quantized_model: Any,
                                 similarity_metrics: Dict[str, Any],
                                 repr_dataset: Callable):
        """
        Write the quantized graph, annotated with the measured similarity metrics,
        to Tensorboard.

        Args:
            quantized_model (Any): The quantized model.
            similarity_metrics (Dict[str, Any]): The collected similarity metrics.
            repr_dataset (Callable): Representative dataset (needed by some frameworks,
                e.g. PyTorch).
        """
        annotated_graph = self.get_graph_for_tensorboard_display(quantized_model=quantized_model,
                                                                 similarity_metrics=similarity_metrics,
                                                                 repr_dataset=repr_dataset)
        self.tb_writer.add_graph(annotated_graph, TENSORBOARD_DEFAULT_TAG)
100
+
101
+
@@ -0,0 +1,39 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Dict, Callable
17
+
18
+
19
class XQuantConfig:
    """
    Settings for the explainable-quantization report: where the report is saved and
    which additional, user-defined similarity metrics to compute between tensors of
    the two compared models.
    """

    def __init__(self,
                 report_dir: str,
                 custom_similarity_metrics: Dict[str, Callable] = None):
        """
        Store the explainable-quantization report settings.

        Args:
            report_dir (str): Directory the report is written to.
            custom_similarity_metrics (Dict[str, Callable]): Optional mapping from a metric
                name to a callable computing that similarity metric between tensors of the
                two models. Defaults to None.
        """
        self.report_dir, self.custom_similarity_metrics = report_dir, custom_similarity_metrics
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
@@ -0,0 +1,57 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Any
17
+
18
+ from model_compression_toolkit.xquant.common.dataset_utils import DatasetUtils
19
+
20
+
21
class KerasDatasetUtils(DatasetUtils):
    """
    Class with helpful methods for handling different kinds of Keras datasets from the user.
    """

    @staticmethod
    def prepare_dataset(dataset: Any, is_validation: bool, device: str = None):
        """
        Prepare the dataset so iterating it yields only model inputs (like the
        representative dataset does). For a validation dataset, the labels are dropped.

        Args:
            dataset: Callable that, when called, returns an iterable of samples.
            is_validation: If True, each sample is indexed and only element 0 is
                yielded (assumed to be the inputs, with the labels following).
            device: Not used by the Keras implementation (presumably kept so the
                signature matches other frameworks' DatasetUtils — confirm).

        Returns:
            Generator yielding the dataset inputs.
        """
        def process_data(x: Any, is_validation: bool):
            """
            Drop the labels from a validation sample; return other samples unchanged.
            No device transfer or tensor conversion is performed.

            Args:
                x: A single sample produced by the dataset.
                is_validation: Whether the sample comes from a validation dataset.

            Returns:
                x[0] for validation samples, otherwise x itself.
            """
            # Assumes validation samples are ordered as (inputs, labels).
            return x[0] if is_validation else x

        for x in dataset():
            yield process_data(x, is_validation)
57
+
@@ -0,0 +1,63 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Callable, Dict, Any
17
+
18
+ from model_compression_toolkit.constants import FOUND_TF
19
+ from model_compression_toolkit.xquant.common.core_report_generator import core_report_generator
20
+ from model_compression_toolkit.xquant import XQuantConfig
21
+ from model_compression_toolkit.logger import Logger
22
+
23
if not FOUND_TF:
    def xquant_report_keras_experimental(*args, **kwargs):
        """
        Placeholder used when TensorFlow is unavailable; reports the missing
        dependency through Logger.critical.
        """
        Logger.critical("Tensorflow must be installed to use xquant_report_keras_experimental. "
                        "The 'tensorflow' package is missing.")  # pragma: no cover
else:
    import keras
    from model_compression_toolkit.xquant.keras.keras_report_utils import KerasReportUtils

    def xquant_report_keras_experimental(float_model: keras.Model,
                                         quantized_model: keras.Model,
                                         repr_dataset: Callable,
                                         validation_dataset: Callable,
                                         xquant_config: XQuantConfig) -> Dict[str, Any]:
        """
        Produce an explainable-quantization report comparing a float Keras model
        with its quantized counterpart.

        Args:
            float_model (keras.Model): The original floating-point Keras model.
            quantized_model (keras.Model): The quantized Keras model.
            repr_dataset (Callable): Representative dataset used for similarity metrics computation.
            validation_dataset (Callable): Validation dataset used for similarity metrics computation.
            xquant_config (XQuantConfig): Report configuration (output dir, custom metrics).

        Returns:
            Dict[str, Any]: The data collected by core_report_generator (similarity metrics
            and report data).
        """
        report_dir = xquant_config.report_dir

        # Send log output to the report directory first.
        Logger.set_log_file(log_folder=report_dir)

        # Keras-specific helpers (folding, similarity, Tensorboard, ...).
        report_utils = KerasReportUtils(report_dir)

        # Collect histograms/similarity metrics and build the report.
        return core_report_generator(float_model=float_model,
                                     quantized_model=quantized_model,
                                     repr_dataset=repr_dataset,
                                     validation_dataset=validation_dataset,
                                     fw_report_utils=report_utils,
                                     xquant_config=xquant_config)
@@ -0,0 +1,60 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
18
+ from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
19
+ from model_compression_toolkit.xquant.common.framework_report_utils import FrameworkReportUtils
20
+ from model_compression_toolkit.ptq.keras.quantization_facade import DEFAULT_KERAS_TPC
21
+ from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
22
+ from model_compression_toolkit.xquant.common.similarity_calculator import SimilarityCalculator
23
+ from model_compression_toolkit.xquant.keras.dataset_utils import KerasDatasetUtils
24
+ from model_compression_toolkit.xquant.keras.model_analyzer import KerasModelAnalyzer
25
+
26
+ from model_compression_toolkit.xquant.keras.similarity_functions import KerasSimilarityFunctions
27
+ from model_compression_toolkit.xquant.keras.tensorboard_utils import KerasTensorboardUtils
28
+
29
+
30
class KerasReportUtils(FrameworkReportUtils):
    """
    Keras flavour of the report utilities. It wires the Keras framework info and
    implementation to the dataset, model-folding, similarity and Tensorboard helpers
    the report generator needs.
    """
    def __init__(self, report_dir: str):
        """
        Args:
            report_dir: Logging dir path.
        """
        fw_info = DEFAULT_KERAS_INFO
        fw_impl = KerasImplementation()

        dataset_utils = KerasDatasetUtils()
        model_folding = ModelFoldingUtils(fw_info=fw_info,
                                          fw_impl=fw_impl,
                                          fw_default_tpc=DEFAULT_KERAS_TPC)

        # Built in the same order as they are consumed by SimilarityCalculator.
        similarity_functions = KerasSimilarityFunctions()
        model_analyzer = KerasModelAnalyzer()
        similarity_calculator = SimilarityCalculator(dataset_utils=dataset_utils,
                                                     model_folding=model_folding,
                                                     similarity_functions=similarity_functions,
                                                     model_analyzer_utils=model_analyzer)

        tb_utils = KerasTensorboardUtils(report_dir=report_dir,
                                         fw_impl=fw_impl,
                                         fw_info=fw_info)

        super().__init__(fw_info, fw_impl, similarity_calculator, dataset_utils, model_folding, tb_utils)
@@ -0,0 +1,136 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import List, Tuple, Dict
16
+
17
+ from mct_quantizers import KerasQuantizationWrapper
18
+ from model_compression_toolkit.xquant.common.constants import MODEL_OUTPUT_KEY
19
+ from model_compression_toolkit.xquant.common.model_analyzer import ModelAnalyzer
20
+ import keras
21
+ import numpy as np
22
+
23
+
24
class KerasModelAnalyzer(ModelAnalyzer):
    """
    This class provides utilities for analyzing Keras models, specifically for
    extracting activations and comparing float and quantized models.
    """

    def extract_model_activations(self,
                                  float_model: keras.Model,
                                  quantized_model: keras.Model,
                                  float_name2quant_name: Dict[str, str],
                                  data: List[np.ndarray]) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
        """
        Extracts activations from both the float and quantized models.

        Args:
            float_model (keras.Model): The float model.
            quantized_model (keras.Model): The quantized model.
            float_name2quant_name (Dict[str, str]): A mapping from float model layer names to quantized model layer
            names.
            data (List[np.ndarray]): Input data for which to compute activations.

        Returns:
            Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
                - Dictionary of activations for the float model.
                - Dictionary of activations for the quantized model.
            In both, the model output is stored under MODEL_OUTPUT_KEY; for models with
            several outputs, those outputs are concatenated with np.concatenate.
        """

        def _to_numpy(prediction):
            """Convert a tensor, or a list/tuple of tensors (multi-output model), to numpy."""
            if isinstance(prediction, (list, tuple)):
                return [p.numpy() for p in prediction]
            return prediction.numpy()

        def _compute_activations(model: keras.Model, layer_names: List[str], data: List[np.ndarray]) -> Dict[str, np.ndarray]:
            """
            Computes the activations for the specified layers of the model, and the model's output.

            Args:
                model (keras.Model): The model from which to extract activations.
                layer_names (List[str]): Names of the layers for which to compute activations.
                data (List[np.ndarray]): Input data for the model.

            Returns:
                Dict[str, np.ndarray]: Layer names mapped to their activations. The model's output is
                stored under MODEL_OUTPUT_KEY (as a list of arrays when the model has several outputs).
            """
            # Outputs of the requested layers followed by the model output. For a multi-output
            # model, model.output is itself a list, so the last entry is a nested list.
            _model_outputs = [model.get_layer(name).output for name in layer_names] + [model.output]

            # Create a new model that outputs the intermediate and final layer outputs
            intermediate_layer_model = keras.Model(inputs=model.input, outputs=_model_outputs)
            predictions = intermediate_layer_model(data)

            # With a single requested output (no layer names, single-output model), Keras may
            # return a bare tensor instead of a one-element list; normalize so indexing below
            # addresses outputs rather than the batch dimension.
            if not isinstance(predictions, (list, tuple)):
                predictions = [predictions]

            # Map layer names to their corresponding activations and add the model output
            activation_tensors = {layer_name: predictions[i].numpy() for i, layer_name in enumerate(layer_names)}
            activation_tensors[MODEL_OUTPUT_KEY] = _to_numpy(predictions[-1])
            return activation_tensors

        # Compute activations for the quantized model
        quant_activations = _compute_activations(quantized_model,
                                                 list(float_name2quant_name.values()),
                                                 data)
        # Compute activations for the float model
        float_activations = _compute_activations(float_model,
                                                 list(float_name2quant_name.keys()),
                                                 data)

        # Multi-output models yield a list of arrays for the output; concatenate them.
        if isinstance(quant_activations[MODEL_OUTPUT_KEY], list):
            quant_activations[MODEL_OUTPUT_KEY] = np.concatenate(quant_activations[MODEL_OUTPUT_KEY])
        if isinstance(float_activations[MODEL_OUTPUT_KEY], list):
            float_activations[MODEL_OUTPUT_KEY] = np.concatenate(float_activations[MODEL_OUTPUT_KEY])

        return float_activations, quant_activations

    def identify_quantized_compare_points(self, quantized_model: keras.Model) -> List[str]:
        """
        Identifies the layers in the quantized model that are wrapped with the quantization wrapper.
        These layers will serve as comparison points.

        Notes:
            This currently means that the quantized compare points are the linear layers that are wrapped,
            but this may be changed in the future.

        Args:
            quantized_model (keras.Model): The quantized model from which to identify comparison points.

        Returns:
            List[str]: Names of the layers wrapped with the quantization wrapper.
        """
        return [layer.name for layer in quantized_model.layers if isinstance(layer, KerasQuantizationWrapper)]

    def find_corresponding_float_layer(self,
                                       quant_compare_point: str,
                                       quantized_model: keras.Model) -> str:
        """
        Finds the corresponding float model layer for a given quantized model layer.

        Args:
            quant_compare_point (str): The name of the quantized model layer.
            quantized_model (keras.Model): The quantized model.

        Returns:
            str: The name of the layer held by the quantization wrapper, which is used as
            the name of the corresponding float layer.
        """
        return quantized_model.get_layer(quant_compare_point).layer.name

    def extract_float_layer_names(self, float_model: keras.Model) -> List[str]:
        """
        Extracts the names of all layers in the float model.

        Args:
            float_model (keras.Model): The float model from which to extract layer names.

        Returns:
            List[str]: Names of all layers in the float model.
        """
        return [layer.name for layer in float_model.layers]
@@ -0,0 +1,75 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import numpy as np
17
+ import tensorflow as tf
18
+
19
+ from model_compression_toolkit.xquant.common.similarity_functions import SimilarityFunctions
20
+
21
class KerasSimilarityFunctions(SimilarityFunctions):
    """
    A class that extends SimilarityFunctions to implement similarity metrics using Keras.
    Even though the names referred to are quantized and float, it can help compare between
    tensors of any two models.
    """

    @staticmethod
    def compute_mse(x: np.ndarray, y: np.ndarray) -> float:
        """
        Compute the Mean Squared Error (MSE) between two tensors (usually, the float and quantized tensors).

        Args:
            x (np.ndarray): First tensor to compare.
            y (np.ndarray): Second tensor to compare.

        Returns:
            float: The computed MSE value.
        """
        mse = tf.keras.losses.MeanSquaredError()(x, y)
        return float(mse.numpy())

    @staticmethod
    def compute_cs(x: np.ndarray, y: np.ndarray) -> float:
        """
        Compute the Cosine Similarity (CS) between two tensors (usually, the float and quantized tensors).
        Both tensors are flattened, so the result is a single similarity over all elements.

        Args:
            x (np.ndarray): First tensor to compare.
            y (np.ndarray): Second tensor to compare.

        Returns:
            float: The computed CS value (1 for identical directions, -1 for opposite ones).
        """
        # tf.keras.losses.CosineSimilarity is a *loss*: it returns the negated cosine
        # similarity. Flip the sign to report the similarity itself.
        cs_loss = tf.keras.losses.CosineSimilarity()(x.flatten(), y.flatten())
        return -float(cs_loss.numpy())

    @staticmethod
    def compute_sqnr(x: np.ndarray, y: np.ndarray) -> float:
        """
        Compute the Signal-to-Quantization-Noise Ratio (SQNR) between two tensors (usually, the float and quantized tensors).

        Args:
            x (np.ndarray): First tensor to compare (treated as the signal).
            y (np.ndarray): Second tensor to compare.

        Returns:
            float: The computed SQNR value (a linear power ratio, not in dB).
        """
        signal_power = tf.reduce_mean(tf.square(x))
        noise_power = tf.reduce_mean(tf.square(x - y))
        sqnr = signal_power / noise_power
        return float(sqnr.numpy())
74
+
75
+
@@ -0,0 +1,84 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Dict, Callable
16
+
17
+ import keras
18
+
19
+ from model_compression_toolkit.core.common import Graph
20
+ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
21
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
22
+
23
+ from model_compression_toolkit.core.keras.reader.reader import model_reader
24
+
25
+ from model_compression_toolkit.xquant.common.constants import XQUANT_REPR, INTERMEDIATE_SIMILARITY_METRICS_REPR, XQUANT_VAL, INTERMEDIATE_SIMILARITY_METRICS_VAL
26
+ from model_compression_toolkit.xquant.common.tensorboard_utils import TensorboardUtils
27
+
28
+
29
class KerasTensorboardUtils(TensorboardUtils):
    """
    Keras-specific Tensorboard helper. It reads a quantized Keras model into a graph
    and attaches the per-node similarity metrics so they are shown in Tensorboard.
    """

    def __init__(self, report_dir: str,
                 fw_info: FrameworkInfo,
                 fw_impl: FrameworkImplementation):
        """
        Set up the Keras Tensorboard helper.

        Args:
            report_dir (str): Directory where the Tensorboard files are stored.
            fw_info (FrameworkInfo): Framework information.
            fw_impl (FrameworkImplementation): Framework implementation functions.
        """
        super().__init__(report_dir, fw_info, fw_impl)

    def get_graph_for_tensorboard_display(self,
                                          quantized_model: keras.Model,
                                          similarity_metrics: Dict[str, Dict[str, float]],
                                          repr_dataset: Callable) -> Graph:
        """
        Read the quantized model into a graph and annotate its nodes with similarity metrics.

        Args:
            quantized_model (keras.Model): The quantized Keras model to read.
            similarity_metrics (Dict[str, Dict[str, float]]): Must contain the intermediate
                representative-dataset and validation-dataset sections, each mapping node
                names to that node's metrics.
            repr_dataset (Callable): Representative dataset (not used for Keras).

        Returns:
            Graph: The quantized-model graph whose matching nodes carry the metrics in
            their framework_attr.
        """
        quant_graph = model_reader(quantized_model)

        # (section of similarity_metrics, framework_attr key to store it under),
        # applied in this order to every node.
        section_to_attr = ((INTERMEDIATE_SIMILARITY_METRICS_REPR, XQUANT_REPR),
                           (INTERMEDIATE_SIMILARITY_METRICS_VAL, XQUANT_VAL))

        for node in quant_graph.nodes:
            for section_key, attr_key in section_to_attr:
                per_node_metrics = similarity_metrics[section_key]
                if node.name in per_node_metrics:
                    node.framework_attr[attr_key] = per_node_metrics[node.name]

        return quant_graph
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ # #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ # #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ # #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #