mct-nightly 2.1.0.20240617.451__py3-none-any.whl → 2.1.0.20240619.429__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {mct_nightly-2.1.0.20240617.451.dist-info → mct_nightly-2.1.0.20240619.429.dist-info}/METADATA +2 -2
  2. {mct_nightly-2.1.0.20240617.451.dist-info → mct_nightly-2.1.0.20240619.429.dist-info}/RECORD +38 -12
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/gptq/keras/gptq_training.py +1 -1
  5. model_compression_toolkit/gptq/keras/graph_info.py +1 -1
  6. model_compression_toolkit/gptq/pytorch/gptq_training.py +5 -2
  7. model_compression_toolkit/gptq/pytorch/graph_info.py +2 -1
  8. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +3 -2
  9. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +3 -2
  10. model_compression_toolkit/xquant/__init__.py +19 -0
  11. model_compression_toolkit/xquant/common/__init__.py +15 -0
  12. model_compression_toolkit/xquant/common/constants.py +38 -0
  13. model_compression_toolkit/xquant/common/core_report_generator.py +83 -0
  14. model_compression_toolkit/xquant/common/dataset_utils.py +43 -0
  15. model_compression_toolkit/xquant/common/framework_report_utils.py +89 -0
  16. model_compression_toolkit/xquant/common/model_analyzer.py +99 -0
  17. model_compression_toolkit/xquant/common/model_folding_utils.py +104 -0
  18. model_compression_toolkit/xquant/common/similarity_calculator.py +194 -0
  19. model_compression_toolkit/xquant/common/similarity_functions.py +81 -0
  20. model_compression_toolkit/xquant/common/tensorboard_utils.py +101 -0
  21. model_compression_toolkit/xquant/common/xquant_config.py +39 -0
  22. model_compression_toolkit/xquant/keras/__init__.py +15 -0
  23. model_compression_toolkit/xquant/keras/dataset_utils.py +57 -0
  24. model_compression_toolkit/xquant/keras/facade_xquant_report.py +63 -0
  25. model_compression_toolkit/xquant/keras/keras_report_utils.py +60 -0
  26. model_compression_toolkit/xquant/keras/model_analyzer.py +136 -0
  27. model_compression_toolkit/xquant/keras/similarity_functions.py +75 -0
  28. model_compression_toolkit/xquant/keras/tensorboard_utils.py +84 -0
  29. model_compression_toolkit/xquant/pytorch/__init__.py +15 -0
  30. model_compression_toolkit/xquant/pytorch/dataset_utils.py +76 -0
  31. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +62 -0
  32. model_compression_toolkit/xquant/pytorch/model_analyzer.py +132 -0
  33. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +61 -0
  34. model_compression_toolkit/xquant/pytorch/similarity_functions.py +68 -0
  35. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +87 -0
  36. {mct_nightly-2.1.0.20240617.451.dist-info → mct_nightly-2.1.0.20240619.429.dist-info}/LICENSE.md +0 -0
  37. {mct_nightly-2.1.0.20240617.451.dist-info → mct_nightly-2.1.0.20240619.429.dist-info}/WHEEL +0 -0
  38. {mct_nightly-2.1.0.20240617.451.dist-info → mct_nightly-2.1.0.20240619.429.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,60 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
18
+ from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
19
+ from model_compression_toolkit.xquant.common.framework_report_utils import FrameworkReportUtils
20
+ from model_compression_toolkit.ptq.keras.quantization_facade import DEFAULT_KERAS_TPC
21
+ from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
22
+ from model_compression_toolkit.xquant.common.similarity_calculator import SimilarityCalculator
23
+ from model_compression_toolkit.xquant.keras.dataset_utils import KerasDatasetUtils
24
+ from model_compression_toolkit.xquant.keras.model_analyzer import KerasModelAnalyzer
25
+
26
+ from model_compression_toolkit.xquant.keras.similarity_functions import KerasSimilarityFunctions
27
+ from model_compression_toolkit.xquant.keras.tensorboard_utils import KerasTensorboardUtils
28
+
29
+
30
class KerasReportUtils(FrameworkReportUtils):
    """
    Bundles the Keras-specific helpers (framework info and implementation, dataset handling,
    model folding, similarity computation and TensorBoard utilities) that are needed to
    generate the report for a Keras model.
    """

    def __init__(self, report_dir: str):
        """
        Args:
            report_dir: Logging directory path; it is handed to the TensorBoard utilities.
        """
        keras_info = DEFAULT_KERAS_INFO
        keras_impl = KerasImplementation()

        data_helpers = KerasDatasetUtils()
        folding = ModelFoldingUtils(fw_info=keras_info,
                                    fw_impl=keras_impl,
                                    fw_default_tpc=DEFAULT_KERAS_TPC)

        # The calculator compares float vs. quantized tensors using Keras-based metrics and
        # the Keras model analyzer to locate comparison points.
        calculator = SimilarityCalculator(dataset_utils=data_helpers,
                                          model_folding=folding,
                                          similarity_functions=KerasSimilarityFunctions(),
                                          model_analyzer_utils=KerasModelAnalyzer())

        tensorboard = KerasTensorboardUtils(report_dir=report_dir,
                                            fw_impl=keras_impl,
                                            fw_info=keras_info)

        super().__init__(keras_info, keras_impl, calculator, data_helpers, folding, tensorboard)
@@ -0,0 +1,136 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import List, Tuple, Dict
16
+
17
+ from mct_quantizers import KerasQuantizationWrapper
18
+ from model_compression_toolkit.xquant.common.constants import MODEL_OUTPUT_KEY
19
+ from model_compression_toolkit.xquant.common.model_analyzer import ModelAnalyzer
20
+ import keras
21
+ import numpy as np
22
+
23
+
24
class KerasModelAnalyzer(ModelAnalyzer):
    """
    This class provides utilities for analyzing Keras models, specifically for
    extracting activations and comparing float and quantized models.
    """

    def extract_model_activations(self,
                                  float_model: keras.Model,
                                  quantized_model: keras.Model,
                                  float_name2quant_name: Dict[str, str],
                                  data: List[np.ndarray]) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
        """
        Extracts activations from both the float and quantized models.

        Args:
            float_model (keras.Model): The float model.
            quantized_model (keras.Model): The quantized model.
            float_name2quant_name (Dict[str, str]): A mapping from float model layer names to quantized model layer
                names.
            data (List[np.ndarray]): Input data for which to compute activations.

        Returns:
            Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
                - Dictionary of activations for the float model.
                - Dictionary of activations for the quantized model.
                The model output is stored under MODEL_OUTPUT_KEY; for a model with several outputs
                the per-output arrays are concatenated with np.concatenate.
        """

        def _to_numpy(output):
            """
            Convert a model output to numpy. A model with several outputs yields a (possibly nested)
            list/tuple of tensors, which is converted element-wise into a list of numpy arrays.
            """
            if isinstance(output, (list, tuple)):
                return [_to_numpy(o) for o in output]
            return output.numpy()

        def _compute_activations(model: keras.Model, layer_names: List[str], data: List[np.ndarray]) -> Dict[str, np.ndarray]:
            """
            Computes the activations for the specified layers of the model, and the model's output.

            Args:
                model (keras.Model): The model from which to extract activations.
                layer_names (List[str]): Names of the layers for which to compute activations.
                data (List[np.ndarray]): Input data for the model.

            Returns:
                Dict[str, np.ndarray]:
                    - Dictionary mapping layer names to their corresponding activations. The model's output is
                      stored using the key MODEL_OUTPUT_KEY (as a list of arrays for multi-output models).
            """
            # Extract the outputs of the specified layers plus the model output. For a multi-output
            # model, model.output is itself a list, so the last entry is nested.
            _model_outputs = [model.get_layer(name).output for name in layer_names] + [model.output]

            # Create a new model that outputs the intermediate and final layer outputs
            intermediate_layer_model = keras.Model(inputs=model.input, outputs=_model_outputs)
            predictions = intermediate_layer_model(data)

            # The first len(layer_names) predictions are the requested layers' activations; zip stops
            # before the trailing model-output entry.
            activation_tensors = {layer_name: layer_output.numpy()
                                  for layer_name, layer_output in zip(layer_names, predictions)}
            # The model output may be a list of tensors (multi-output model), so convert it
            # element-wise instead of calling .numpy() on it directly; the caller concatenates lists.
            activation_tensors[MODEL_OUTPUT_KEY] = _to_numpy(predictions[-1])
            return activation_tensors

        # Compute activations for the quantized model
        quant_activations = _compute_activations(quantized_model,
                                                 list(float_name2quant_name.values()),
                                                 data)
        # Compute activations for the float model
        float_activations = _compute_activations(float_model,
                                                 list(float_name2quant_name.keys()),
                                                 data)

        # Concatenate predictions if they are lists (multi-output models).
        if isinstance(quant_activations[MODEL_OUTPUT_KEY], list):
            quant_activations[MODEL_OUTPUT_KEY] = np.concatenate(quant_activations[MODEL_OUTPUT_KEY])
        if isinstance(float_activations[MODEL_OUTPUT_KEY], list):
            float_activations[MODEL_OUTPUT_KEY] = np.concatenate(float_activations[MODEL_OUTPUT_KEY])

        return float_activations, quant_activations

    def identify_quantized_compare_points(self, quantized_model: keras.Model) -> List[str]:
        """
        Identifies the layers in the quantized model that are wrapped with the quantization wrapper.
        These layers will serve as comparison points.

        Notes:
            This currently means that the quantized compare points are the linear layers that are wrapped,
            but this may be changed in the future.

        Args:
            quantized_model (keras.Model): The quantized model from which to identify comparison points.

        Returns:
            List[str]: Names of the layers wrapped with the quantization wrapper.
        """
        return [layer.name for layer in quantized_model.layers if isinstance(layer, KerasQuantizationWrapper)]

    def find_corresponding_float_layer(self,
                                       quant_compare_point: str,
                                       quantized_model: keras.Model) -> str:
        """
        Finds the corresponding float model layer for a given quantized model layer.

        Args:
            quant_compare_point (str): The name of the quantized model layer.
            quantized_model (keras.Model): The quantized model.

        Returns:
            str: The name of the layer held inside the quantization wrapper, which is used as the
            name of the corresponding layer in the float model.
        """
        return quantized_model.get_layer(quant_compare_point).layer.name

    def extract_float_layer_names(self, float_model: keras.Model) -> List[str]:
        """
        Extracts the names of all layers in the float model.

        Args:
            float_model (keras.Model): The float model from which to extract layer names.

        Returns:
            List[str]: Names of all layers in the float model.
        """
        return [layer.name for layer in float_model.layers]
@@ -0,0 +1,75 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import numpy as np
17
+ import tensorflow as tf
18
+
19
+ from model_compression_toolkit.xquant.common.similarity_functions import SimilarityFunctions
20
+
21
class KerasSimilarityFunctions(SimilarityFunctions):
    """
    A class that extends SimilarityFunctions to implement similarity metrics using Keras.
    Even though the names referred to are quantized and float, it can help compare between
    tensors of any two models.
    """

    @staticmethod
    def compute_mse(x: np.ndarray, y: np.ndarray) -> float:
        """
        Compute the Mean Squared Error (MSE) between two tensors (usually, the float and quantized tensors).

        Args:
            x (np.ndarray): First tensor to compare.
            y (np.ndarray): Second tensor to compare.

        Returns:
            float: The computed MSE value.
        """
        mse = tf.keras.losses.MeanSquaredError()(x, y)
        return float(mse.numpy())

    @staticmethod
    def compute_cs(x: np.ndarray, y: np.ndarray) -> float:
        """
        Compute the Cosine Similarity (CS) between two tensors (usually, the float and quantized tensors).
        Both tensors are flattened first, so the result is a single value in [-1, 1], where 1 means the
        tensors point in the same direction.

        Args:
            x (np.ndarray): First tensor to compare.
            y (np.ndarray): Second tensor to compare.

        Returns:
            float: The computed CS value.
        """
        # tf.keras.losses.CosineSimilarity is a *loss*: it returns the negated cosine similarity
        # (-1 for identical directions). Negate it back to obtain the actual similarity.
        cs_loss = tf.keras.losses.CosineSimilarity()(x.flatten(), y.flatten())
        return -float(cs_loss.numpy())

    @staticmethod
    def compute_sqnr(x: np.ndarray, y: np.ndarray) -> float:
        """
        Compute the Signal-to-Quantization-Noise Ratio (SQNR) between two tensors (usually, the float and quantized tensors).
        The ratio is linear (not in dB): mean(x^2) / mean((x - y)^2).

        Args:
            x (np.ndarray): First tensor to compare.
            y (np.ndarray): Second tensor to compare.

        Returns:
            float: The computed SQNR value. When x and y are identical the noise power is zero and the
            TensorFlow division yields inf (or nan if x is all zeros).
        """
        signal_power = tf.reduce_mean(tf.square(x))
        noise_power = tf.reduce_mean(tf.square(x - y))
        sqnr = signal_power / noise_power
        return float(sqnr.numpy())
74
+
75
+
@@ -0,0 +1,84 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Dict, Callable
16
+
17
+ import keras
18
+
19
+ from model_compression_toolkit.core.common import Graph
20
+ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
21
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
22
+
23
+ from model_compression_toolkit.core.keras.reader.reader import model_reader
24
+
25
+ from model_compression_toolkit.xquant.common.constants import XQUANT_REPR, INTERMEDIATE_SIMILARITY_METRICS_REPR, XQUANT_VAL, INTERMEDIATE_SIMILARITY_METRICS_VAL
26
+ from model_compression_toolkit.xquant.common.tensorboard_utils import TensorboardUtils
27
+
28
+
29
class KerasTensorboardUtils(TensorboardUtils):
    """
    Keras flavour of TensorboardUtils: builds the graph of a quantized Keras model and
    annotates its nodes with similarity metrics so they can be inspected in TensorBoard.
    """

    def __init__(self, report_dir: str,
                 fw_info: FrameworkInfo,
                 fw_impl: FrameworkImplementation):
        """
        Set up the Keras TensorBoard utilities.

        Args:
            report_dir (str): Directory where the TensorBoard files will be stored.
            fw_info (FrameworkInfo): Information about the framework being used.
            fw_impl (FrameworkImplementation): Implementation functions for the framework.
        """
        super().__init__(report_dir, fw_info, fw_impl)

    def get_graph_for_tensorboard_display(self,
                                          quantized_model: keras.Model,
                                          similarity_metrics: Dict[str, Dict[str, float]],
                                          repr_dataset: Callable) -> Graph:
        """
        Read the quantized model into a Graph and attach per-node similarity metrics to it.

        Args:
            quantized_model (keras.Model): The quantized Keras model to convert into a graph.
            similarity_metrics (Dict[str, Dict[str, float]]): Similarity metrics keyed first by
                metric group (representative / validation intermediate metrics), then by node name.
            repr_dataset (Callable): The representative dataset (not used by this implementation).

        Returns:
            Graph: The quantized model's graph; nodes that have metrics carry them in
            framework_attr under XQUANT_REPR and/or XQUANT_VAL.
        """
        graph = model_reader(quantized_model)

        # (metrics group in similarity_metrics, attribute key written on the node), in the order
        # in which they are applied to each node.
        annotations = ((INTERMEDIATE_SIMILARITY_METRICS_REPR, XQUANT_REPR),
                       (INTERMEDIATE_SIMILARITY_METRICS_VAL, XQUANT_VAL))

        for graph_node in graph.nodes:
            for group_key, attr_key in annotations:
                per_node_metrics = similarity_metrics[group_key]
                if graph_node.name in per_node_metrics:
                    graph_node.framework_attr[attr_key] = per_node_metrics[graph_node.name]

        return graph
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ # #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ # #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ # #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
@@ -0,0 +1,76 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Any, Callable
17
+
18
+ from model_compression_toolkit.xquant.common.dataset_utils import DatasetUtils
19
+ import numpy as np
20
+
21
+ import torch
22
+
23
+
24
class PytorchDatasetUtils(DatasetUtils):
    """
    Class with helpful methods for handling different kinds of Pytorch datasets from the user.
    """
    @staticmethod
    def prepare_dataset(dataset: Callable, is_validation: bool, device: str = None):
        """
        Prepare the dataset so calling it will return only inputs for the model (like in the case
        of the representative dataset). For example, when the validation dataset is used, the labels
        should be removed.

        Args:
            dataset: Dataset to prepare. Calling it must return an iterable of samples.
            is_validation: Whether it's validation dataset or not.
            device: Device to transfer the data to.

        Returns:
            Generator to use for retrieving the dataset inputs. Each yielded item is a list of
            torch tensors (one per model input) on the requested device.

        """

        def process_data(data: Any, is_validation: bool, device: str):
            """
            Processes individual data samples: Transfer them to the device, convert to torch tensors if needed,
            remove labels if this is a validation dataset.

            Args:
                data: The data sample to process. For a validation dataset, data[0] holds the inputs
                    (a single input, or a list/tuple of inputs) and data[1] the labels. Otherwise data is
                    either a list/tuple of inputs or a single tensor/ndarray input.
                is_validation: A flag indicating if this is a validation dataset.
                device: The device to transfer the data to.

            Returns:
                The data as a list of torch tensors on the desired device.
            """

            def transfer_to_device(_data):
                if isinstance(_data, np.ndarray):
                    return torch.from_numpy(_data).to(device)
                return _data.to(device)

            if is_validation:
                inputs = data[0]  # Assume data[0] contains the inputs and data[1] the labels
                model_inputs = inputs if isinstance(inputs, (list, tuple)) else [inputs]
            elif isinstance(data, (np.ndarray, torch.Tensor)):
                # A bare tensor/array is a single model input: wrap it instead of iterating over
                # it, which would split the batch dimension into separate inputs.
                model_inputs = [data]
            else:
                model_inputs = data

            return [transfer_to_device(t) for t in model_inputs]

        for x in dataset():
            yield process_data(x, is_validation, device)
@@ -0,0 +1,62 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Callable
17
+
18
+ from model_compression_toolkit.constants import FOUND_TORCH
19
+ from model_compression_toolkit.xquant.common.core_report_generator import core_report_generator
20
+ from model_compression_toolkit.xquant import XQuantConfig
21
+ from model_compression_toolkit.logger import Logger
22
+
23
if FOUND_TORCH:
    from model_compression_toolkit.xquant.pytorch.pytorch_report_utils import PytorchReportUtils
    import torch

    def xquant_report_pytorch_experimental(float_model: torch.nn.Module,
                                           quantized_model: torch.nn.Module,
                                           repr_dataset: Callable,
                                           validation_dataset: Callable,
                                           xquant_config: XQuantConfig):
        """
        Build an explainable quantization (XQuant) report comparing a float Pytorch model with
        its quantized counterpart.

        Args:
            float_model (torch.nn.Module): The original floating-point Pytorch model.
            quantized_model (torch.nn.Module): The quantized Pytorch model.
            repr_dataset (Callable): The representative dataset used during quantization.
            validation_dataset (Callable): The validation dataset used for evaluation.
            xquant_config (XQuantConfig): Configuration settings for explainable quantization.

        Returns:
            Whatever core_report_generator returns: the collected similarity metrics and
            report data.
        """
        report_dir = xquant_config.report_dir

        # Direct the logger's file output into the report directory before generating the report.
        Logger.set_log_file(log_folder=report_dir)

        return core_report_generator(float_model=float_model,
                                     quantized_model=quantized_model,
                                     repr_dataset=repr_dataset,
                                     validation_dataset=validation_dataset,
                                     fw_report_utils=PytorchReportUtils(report_dir),
                                     xquant_config=xquant_config)

else:
    def xquant_report_pytorch_experimental(*args, **kwargs):
        # Placeholder exposed when torch is unavailable; it reports the missing dependency.
        Logger.critical("PyTorch must be installed to use 'xquant_report_pytorch_experimental'. "
                        "The 'torch' package is missing.")  # pragma: no cover
@@ -0,0 +1,132 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Dict, List, Tuple
16
+
17
+ import torch
18
+ from mct_quantizers.pytorch.quantize_wrapper import PytorchQuantizationWrapper
19
+ from model_compression_toolkit.xquant.common.constants import MODEL_OUTPUT_KEY
20
+
21
+ from model_compression_toolkit.xquant.common.model_analyzer import ModelAnalyzer
22
+
23
+
24
class PytorchModelAnalyzer(ModelAnalyzer):
    """
    This class provides utilities for analyzing Pytorch models, specifically for
    extracting activations and comparing float and quantized models.
    """

    def extract_model_activations(self,
                                  float_model: torch.nn.Module,
                                  quantized_model: torch.nn.Module,
                                  float_name2quant_name: Dict[str, str],
                                  data: List[torch.Tensor]) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Extracts activations from both the float and quantized models.

        Forward hooks are attached temporarily to the requested modules and are always removed
        before returning (also on error), so repeated calls do not accumulate hooks on the models.

        Args:
            float_model (torch.nn.Module): The float model.
            quantized_model (torch.nn.Module): The quantized model.
            float_name2quant_name (Dict[str, str]): A mapping from float model layer names to quantized model layer
                names.
            data (List[torch.Tensor]): Input data for which to compute activations.

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
                - Dictionary of activations for the float model.
                - Dictionary of activations for the quantized model.
                Each dictionary also holds the model's output under MODEL_OUTPUT_KEY.

        Raises:
            KeyError: If a name in float_name2quant_name is not a module of the corresponding model.
        """

        def _compute_activations(name: str, activations: dict):
            """
            Creates a hook function to capture the activations of a layer.

            Args:
                name (str): The name of the layer.
                activations (dict): The dictionary to store the activations.

            Returns:
                hook (function): The hook function to register with the layer.
            """
            def hook(module, inputs, output):
                activations[name] = output.detach()

            return hook

        # Initialize dictionaries to store activations for both models
        activations_float = {}
        activations_quant = {}

        # Build the name -> module lookups once per model instead of once per layer.
        float_modules = dict(float_model.named_modules())
        quant_modules = dict(quantized_model.named_modules())

        hook_handles = []
        try:
            # Register hooks for all layers in the float model
            for layer_name in float_name2quant_name.keys():
                hook_handles.append(float_modules[layer_name].register_forward_hook(
                    _compute_activations(layer_name, activations_float)))

            # Register hooks for all layers in the quantized model
            for layer_name in float_name2quant_name.values():
                hook_handles.append(quant_modules[layer_name].register_forward_hook(
                    _compute_activations(layer_name, activations_quant)))

            # Perform a forward pass with the input data and capture activations
            with torch.no_grad():
                float_predictions = float_model(*data)
                quant_predictions = quantized_model(*data)
        finally:
            # Detach the hooks: otherwise every call would leave them registered, so later forward
            # passes keep writing into stale dictionaries (leaking the tensors they hold).
            for handle in hook_handles:
                handle.remove()

        activations_float[MODEL_OUTPUT_KEY] = float_predictions
        activations_quant[MODEL_OUTPUT_KEY] = quant_predictions

        return activations_float, activations_quant

    def identify_quantized_compare_points(self,
                                          quantized_model: torch.nn.Module) -> List[str]:
        """
        Identifies points in the quantized model to compare with the float model.

        Args:
            quantized_model (torch.nn.Module): The quantized model.

        Returns:
            List[str]: Names of the modules in the quantized model that are PytorchQuantizationWrapper instances.
        """
        return [n for n, m in quantized_model.named_modules() if isinstance(m, PytorchQuantizationWrapper)]

    def find_corresponding_float_layer(self,
                                       quant_compare_point: str,
                                       quantized_model: torch.nn.Module) -> str:
        """
        Finds the corresponding float model layer for a given quantized model layer.
        In pytorch, we assume the name is the same in the float model, thus we return quant_compare_point.

        Args:
            quant_compare_point (str): The name of the layer in the quantized model.
            quantized_model (torch.nn.Module): The quantized model.

        Returns:
            str: The name of the corresponding layer in the float model.
        """
        return quant_compare_point

    def extract_float_layer_names(self,
                                  float_model: torch.nn.Module) -> List[str]:
        """
        Extracts the names of all layers in the float model.

        Args:
            float_model (torch.nn.Module): The float model.

        Returns:
            List[str]: Names of all modules in the float model as given by named_modules()
            (including the root module, whose name is the empty string).
        """
        return [n for n, m in float_model.named_modules()]
132
+