mct-nightly 2.1.0.20240616.65727__py3-none-any.whl → 2.1.0.20240618.432__py3-none-any.whl
This diff shows the changes between two package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/METADATA +2 -2
- {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/RECORD +43 -17
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/functional_node.py +3 -3
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +23 -13
- model_compression_toolkit/core/pytorch/constants.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +3 -3
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +12 -6
- model_compression_toolkit/gptq/keras/gptq_training.py +1 -1
- model_compression_toolkit/gptq/keras/graph_info.py +1 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +5 -2
- model_compression_toolkit/gptq/pytorch/graph_info.py +2 -1
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +3 -2
- model_compression_toolkit/xquant/__init__.py +19 -0
- model_compression_toolkit/xquant/common/__init__.py +15 -0
- model_compression_toolkit/xquant/common/constants.py +38 -0
- model_compression_toolkit/xquant/common/core_report_generator.py +83 -0
- model_compression_toolkit/xquant/common/dataset_utils.py +43 -0
- model_compression_toolkit/xquant/common/framework_report_utils.py +89 -0
- model_compression_toolkit/xquant/common/model_analyzer.py +99 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +104 -0
- model_compression_toolkit/xquant/common/similarity_calculator.py +194 -0
- model_compression_toolkit/xquant/common/similarity_functions.py +81 -0
- model_compression_toolkit/xquant/common/tensorboard_utils.py +101 -0
- model_compression_toolkit/xquant/common/xquant_config.py +39 -0
- model_compression_toolkit/xquant/keras/__init__.py +15 -0
- model_compression_toolkit/xquant/keras/dataset_utils.py +57 -0
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +63 -0
- model_compression_toolkit/xquant/keras/keras_report_utils.py +60 -0
- model_compression_toolkit/xquant/keras/model_analyzer.py +136 -0
- model_compression_toolkit/xquant/keras/similarity_functions.py +75 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +84 -0
- model_compression_toolkit/xquant/pytorch/__init__.py +15 -0
- model_compression_toolkit/xquant/pytorch/dataset_utils.py +76 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +62 -0
- model_compression_toolkit/xquant/pytorch/model_analyzer.py +132 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +61 -0
- model_compression_toolkit/xquant/pytorch/similarity_functions.py +68 -0
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +87 -0
- {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/WHEEL +0 -0
- {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/top_level.txt +0 -0
model_compression_toolkit/xquant/pytorch/dataset_utils.py
ADDED
@@ -0,0 +1,76 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Any, Callable
+
+from model_compression_toolkit.xquant.common.dataset_utils import DatasetUtils
+import numpy as np
+
+import torch
+
+
+class PytorchDatasetUtils(DatasetUtils):
+    """
+    Class with helpful methods for handling different kinds of Pytorch datasets from the user.
+    """
+    @staticmethod
+    def prepare_dataset(dataset: Callable, is_validation: bool, device: str = None):
+        """
+        Prepare the dataset so calling it will return only inputs for the model (like in the case
+        of the representative dataset). For example, when the validation dataset is used, the labels
+        should be removed.
+
+        Args:
+            dataset: Dataset to prepare.
+            is_validation: Whether it's validation dataset or not.
+            device: Device to transfer the data to.
+
+        Returns:
+            Generator to use for retrieving the dataset inputs.
+
+        """
+
+        def process_data(data: Any, is_validation: bool, device: str):
+            """
+            Processes individual data samples: Transfer them to the device, convert to torch tensors if needed,
+            remove labels if this is a validation dataset.
+
+            Args:
+                data: The data sample to process.
+                is_validation: A flag indicating if this is a validation dataset.
+                device: The device to transfer the data to.
+
+            Returns:
+                The data as torch tensors on the desired device.
+            """
+
+            def transfer_to_device(_data):
+                if isinstance(_data, np.ndarray):
+                    return torch.from_numpy(_data).to(device)
+                return _data.to(device)
+
+            if is_validation:
+                inputs = data[0]  # Assume data[0] contains the inputs and data[1] the labels
+                if isinstance(inputs, list):
+                    data = [transfer_to_device(t) for t in inputs]
+                else:
+                    data = [transfer_to_device(inputs)]
+            else:
+                data = [transfer_to_device(t) for t in data]
+
+            return data
+
+        for x in dataset():
+            yield process_data(x, is_validation, device)
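
To see what the new prepare_dataset generator does in practice, here is a minimal sketch (the validation generator below is a hypothetical stand-in, not part of the package): wrapping a labeled validation loader yields device-placed input lists with the labels stripped.

import torch
from model_compression_toolkit.xquant.pytorch.dataset_utils import PytorchDatasetUtils

def validation_dataset():
    # Hypothetical validation generator: yields (inputs, labels) batches.
    for _ in range(2):
        yield torch.randn(4, 3, 8, 8), torch.randint(0, 10, (4,))

# prepare_dataset drops the labels (is_validation=True) and moves the inputs to the
# requested device, so every yielded item is a list of tensors usable as model(*inputs).
for inputs in PytorchDatasetUtils.prepare_dataset(dataset=validation_dataset,
                                                  is_validation=True,
                                                  device='cpu'):
    assert isinstance(inputs, list) and inputs[0].shape == (4, 3, 8, 8)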
model_compression_toolkit/xquant/pytorch/facade_xquant_report.py
ADDED
@@ -0,0 +1,62 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Callable
+
+from model_compression_toolkit.constants import FOUND_TORCH
+from model_compression_toolkit.xquant.common.core_report_generator import core_report_generator
+from model_compression_toolkit.xquant import XQuantConfig
+from model_compression_toolkit.logger import Logger
+
+if FOUND_TORCH:
+    from model_compression_toolkit.xquant.pytorch.pytorch_report_utils import PytorchReportUtils
+    import torch
+
+    def xquant_report_pytorch_experimental(float_model: torch.nn.Module,
+                                           quantized_model: torch.nn.Module,
+                                           repr_dataset: Callable,
+                                           validation_dataset: Callable,
+                                           xquant_config: XQuantConfig):
+        """
+        Generate an explainable quantization report for a quantized Pytorch model.
+
+        Args:
+            float_model (torch.nn.Module): The original floating-point Pytorch model.
+            quantized_model (torch.nn.Module): The quantized Pytorch model.
+            repr_dataset (Callable): The representative dataset used during quantization.
+            validation_dataset (Callable): The validation dataset used for evaluation.
+            xquant_config (XQuantConfig): Configuration settings for explainable quantization.
+
+        Returns:
+            Dict[str, Any]: A dictionary containing the collected similarity metrics and report data.
+        """
+        # Initialize the logger with the report directory.
+        Logger.set_log_file(log_folder=xquant_config.report_dir)
+
+        pytorch_report_utils = PytorchReportUtils(xquant_config.report_dir)
+
+        _collected_data = core_report_generator(float_model=float_model,
+                                                quantized_model=quantized_model,
+                                                repr_dataset=repr_dataset,
+                                                validation_dataset=validation_dataset,
+                                                fw_report_utils=pytorch_report_utils,
+                                                xquant_config=xquant_config)
+
+        return _collected_data
+
+else:
+    def xquant_report_pytorch_experimental(*args, **kwargs):
+        Logger.critical("PyTorch must be installed to use 'xquant_report_pytorch_experimental'. "
+                        "The 'torch' package is missing.")  # pragma: no cover
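
For orientation, a minimal sketch of calling the new facade. It assumes float_model is the original PyTorch module, quantized_model is the model returned by MCT's quantization flow, and that XQuantConfig accepts the report directory as a constructor argument (its definition is not part of this hunk):

import torch
from model_compression_toolkit.xquant import XQuantConfig
from model_compression_toolkit.xquant.pytorch.facade_xquant_report import xquant_report_pytorch_experimental

def repr_dataset():
    # Hypothetical representative-dataset generator: yields lists of input tensors.
    yield [torch.randn(1, 3, 32, 32)]

def validation_dataset():
    # Hypothetical validation generator: yields (inputs, labels) batches.
    yield torch.randn(1, 3, 32, 32), torch.randint(0, 10, (1,))

xquant_config = XQuantConfig(report_dir='./xquant_report')  # assumed constructor signature

report = xquant_report_pytorch_experimental(float_model=float_model,          # original nn.Module (assumed)
                                            quantized_model=quantized_model,  # MCT-quantized nn.Module (assumed)
                                            repr_dataset=repr_dataset,
                                            validation_dataset=validation_dataset,
                                            xquant_config=xquant_config)
# 'report' holds the collected similarity metrics; logs and the annotated model graph
# for TensorBoard display are written under xquant_config.report_dir.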
model_compression_toolkit/xquant/pytorch/model_analyzer.py
ADDED
@@ -0,0 +1,132 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from typing import Dict, List, Tuple
+
+import torch
+from mct_quantizers.pytorch.quantize_wrapper import PytorchQuantizationWrapper
+from model_compression_toolkit.xquant.common.constants import MODEL_OUTPUT_KEY
+
+from model_compression_toolkit.xquant.common.model_analyzer import ModelAnalyzer
+
+
+class PytorchModelAnalyzer(ModelAnalyzer):
+    """
+    This class provides utilities for analyzing Pytorch models, specifically for
+    extracting activations and comparing float and quantized models.
+    """
+
+    def extract_model_activations(self,
+                                  float_model: torch.nn.Module,
+                                  quantized_model: torch.nn.Module,
+                                  float_name2quant_name: Dict[str, str],
+                                  data: List[torch.Tensor]) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
+        """
+        Extracts activations from both the float and quantized models.
+
+        Args:
+            float_model (torch.nn.Module): The float model.
+            quantized_model (torch.nn.Module): The quantized model.
+            float_name2quant_name (Dict[str, str]): A mapping from float model layer names to quantized model layer
+                names.
+            data (List[torch.Tensor]): Input data for which to compute activations.
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
+                - Dictionary of activations for the float model.
+                - Dictionary of activations for the quantized model.
+        """
+
+        def _compute_activations(name: str, activations: dict):
+            """
+            Creates a hook function to capture the activations of a layer.
+
+            Args:
+                name (str): The name of the layer.
+                activations (dict): The dictionary to store the activations.
+
+            Returns:
+                hook (function): The hook function to register with the layer.
+            """
+            def hook(model, input, output):
+                activations[name] = output.detach()
+
+            return hook
+
+        # Initialize dictionaries to store activations for both models
+        activations_float = {}
+        activations_quant = {}
+
+        # Register hooks for all layers in the float model
+        for layer_name in float_name2quant_name.keys():
+            layer = dict([*float_model.named_modules()])[layer_name]
+            layer.register_forward_hook(_compute_activations(layer_name, activations_float))
+
+        # Register hooks for all layers in the quantized model
+        for layer_name in float_name2quant_name.values():
+            layer = dict([*quantized_model.named_modules()])[layer_name]
+            layer.register_forward_hook(_compute_activations(layer_name, activations_quant))
+
+        # Perform a forward pass with the input data and capture activations
+        with torch.no_grad():
+            float_predictions = float_model(*data)
+            quant_predictions = quantized_model(*data)
+
+        activations_float[MODEL_OUTPUT_KEY] = float_predictions
+        activations_quant[MODEL_OUTPUT_KEY] = quant_predictions
+
+        return activations_float, activations_quant
+
+    def identify_quantized_compare_points(self,
+                                          quantized_model: torch.nn.Module) -> List[str]:
+        """
+        Identifies points in the quantized model to compare with the float model.
+
+        Args:
+            quantized_model (torch.nn.Module): The quantized model.
+
+        Returns:
+            List[str]: A list of layer names in the quantized model to compare.
+        """
+        return [n for n, m in quantized_model.named_modules() if isinstance(m, PytorchQuantizationWrapper)]
+
+    def find_corresponding_float_layer(self,
+                                       quant_compare_point: str,
+                                       quantized_model: torch.nn.Module) -> str:
+        """
+        Finds the corresponding float model layer for a given quantized model layer.
+        In pytorch, we assume the name is the same in the float model, thus we return quant_compare_point.
+
+        Args:
+            quant_compare_point (str): The name of the layer in the quantized model.
+            quantized_model (torch.nn.Module): The quantized model.
+
+        Returns:
+            str: The name of the corresponding layer in the float model.
+        """
+        return quant_compare_point
+
+    def extract_float_layer_names(self,
+                                  float_model: torch.nn.Module) -> List[str]:
+        """
+        Extracts the names of all layers in the float model.
+
+        Args:
+            float_model (torch.nn.Module): The float model.
+
+        Returns:
+            List[str]: A list of layer names in the float model.
+        """
+        return [n for n, m in float_model.named_modules()]
+
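
A short sketch of how the analyzer above is meant to be used. The compare points are simply the names of the PytorchQuantizationWrapper sub-modules, and the matching float layers are assumed to carry the same names (see find_corresponding_float_layer); float_model and quantized_model are assumed to exist:

import torch
from model_compression_toolkit.xquant.pytorch.model_analyzer import PytorchModelAnalyzer

analyzer = PytorchModelAnalyzer()

# Names of all PytorchQuantizationWrapper sub-modules in the MCT-quantized model (assumed).
compare_points = analyzer.identify_quantized_compare_points(quantized_model)

# An identity mapping suffices because float and quantized layers share names in PyTorch.
float_acts, quant_acts = analyzer.extract_model_activations(
    float_model=float_model,                              # original nn.Module (assumed)
    quantized_model=quantized_model,
    float_name2quant_name={name: name for name in compare_points},
    data=[torch.randn(1, 3, 32, 32)])                     # hypothetical input batch

# Both dictionaries also hold the model outputs under MODEL_OUTPUT_KEY.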
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py
ADDED
@@ -0,0 +1,61 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from model_compression_toolkit.core.pytorch.utils import get_working_device
+
+from model_compression_toolkit.ptq.pytorch.quantization_facade import DEFAULT_PYTORCH_TPC
+from model_compression_toolkit.xquant.common.framework_report_utils import FrameworkReportUtils
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
+from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
+from model_compression_toolkit.xquant.common.similarity_calculator import SimilarityCalculator
+from model_compression_toolkit.xquant.pytorch.dataset_utils import PytorchDatasetUtils
+from model_compression_toolkit.xquant.pytorch.model_analyzer import PytorchModelAnalyzer
+from model_compression_toolkit.xquant.pytorch.similarity_functions import PytorchSimilarityFunctions
+from model_compression_toolkit.xquant.pytorch.tensorboard_utils import PytorchTensorboardUtils
+
+
+class PytorchReportUtils(FrameworkReportUtils):
+    """
+    Class with various utility components required for generating the report for a Pytorch model.
+    """
+    def __init__(self, report_dir: str):
+        """
+        Args:
+            report_dir: Logging dir path.
+        """
+        fw_info = DEFAULT_PYTORCH_INFO
+        fw_impl = PytorchImplementation()
+
+        dataset_utils = PytorchDatasetUtils()
+        model_folding = ModelFoldingUtils(fw_info=fw_info,
+                                          fw_impl=fw_impl,
+                                          fw_default_tpc=DEFAULT_PYTORCH_TPC)
+
+        similarity_calculator = SimilarityCalculator(dataset_utils=dataset_utils,
+                                                     model_folding=model_folding,
+                                                     similarity_functions=PytorchSimilarityFunctions(),
+                                                     model_analyzer_utils=PytorchModelAnalyzer(),
+                                                     device=get_working_device())
+
+        tb_utils = PytorchTensorboardUtils(report_dir=report_dir,
+                                           fw_impl=fw_impl,
+                                           fw_info=fw_info)
+
+        super().__init__(fw_info=fw_info,
+                         fw_impl=fw_impl,
+                         tb_utils=tb_utils,
+                         dataset_utils=dataset_utils,
+                         similarity_calculator=similarity_calculator,
+                         model_folding_utils=model_folding)
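
PytorchReportUtils mainly wires the pieces above together; a minimal sketch of constructing it directly (normally the facade does this internally):

from model_compression_toolkit.xquant.pytorch.pytorch_report_utils import PytorchReportUtils

# Builds the PyTorch dataset utils, model-folding utils, similarity calculator and
# TensorBoard utils, all rooted at the given report directory.
report_utils = PytorchReportUtils(report_dir='./xquant_report')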
model_compression_toolkit/xquant/pytorch/similarity_functions.py
ADDED
@@ -0,0 +1,68 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from model_compression_toolkit.xquant.common.similarity_functions import SimilarityFunctions
+import torch
+
+class PytorchSimilarityFunctions(SimilarityFunctions):
+
+    @staticmethod
+    def compute_mse(x: torch.Tensor, y: torch.Tensor) -> float:
+        """
+        Computes Mean Squared Error between two tensors (usually, the float and quantized tensors).
+
+        Args:
+            x: Float model predictions.
+            y: Quantized model predictions.
+
+        Returns:
+            Mean Squared Error as a float.
+        """
+        mse = torch.nn.functional.mse_loss(x, y)
+        return mse.item()
+
+    @staticmethod
+    def compute_cs(x: torch.Tensor, y: torch.Tensor) -> float:
+        """
+        Computes Cosine Similarity between two tensors (usually, the float and quantized tensors).
+
+        Args:
+            x: Float model predictions.
+            y: Quantized model predictions.
+
+        Returns:
+            Cosine Similarity as a float.
+        """
+        cs = torch.nn.functional.cosine_similarity(x.flatten(), y.flatten(), dim=0)
+        return cs.item()
+
+    @staticmethod
+    def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
+        """
+        Computes Signal-to-Quantization-Noise Ratio between two tensors (usually, the float and quantized tensors).
+
+        Args:
+            x: Float model predictions.
+            y: Quantized model predictions.
+
+        Returns:
+            Signal-to-Quantization-Noise Ratio as a float.
+        """
+        signal_power = torch.mean(x ** 2)
+        noise_power = torch.mean((x - y) ** 2)
+        sqnr = signal_power / noise_power
+        return sqnr.item()
+
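
The three metrics reduce to simple tensor expressions; a toy check on made-up values:

import torch
from model_compression_toolkit.xquant.pytorch.similarity_functions import PytorchSimilarityFunctions

x = torch.tensor([1.0, 2.0, 3.0, 4.0])   # "float" activations
y = torch.tensor([1.1, 1.9, 3.2, 3.8])   # "quantized" activations

mse = PytorchSimilarityFunctions.compute_mse(x, y)    # mean((x - y) ** 2)
cs = PytorchSimilarityFunctions.compute_cs(x, y)      # cosine similarity of the flattened tensors
sqnr = PytorchSimilarityFunctions.compute_sqnr(x, y)  # mean(x ** 2) / mean((x - y) ** 2)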
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py
ADDED
@@ -0,0 +1,87 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+from typing import Dict, Any, Callable
+
+import torch
+
+from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.core.pytorch.reader.reader import model_reader
+from model_compression_toolkit.xquant.common.constants import XQUANT_REPR, INTERMEDIATE_SIMILARITY_METRICS_REPR, XQUANT_VAL, INTERMEDIATE_SIMILARITY_METRICS_VAL
+from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
+from model_compression_toolkit.xquant.common.tensorboard_utils import TensorboardUtils
+
+class PytorchTensorboardUtils(TensorboardUtils):
+    """
+    Utility class for handling PyTorch models with TensorBoard. Inherits from TensorboardUtils.
+    This class provides functionalities to display quantized model graphs on TensorBoard.
+    """
+
+    def __init__(self,
+                 report_dir: str,
+                 fw_info: FrameworkInfo,
+                 fw_impl: FrameworkImplementation):
+        """
+        Initialize the PytorchTensorboardUtils instance.
+
+        Args:
+            report_dir: Directory where the reports are stored.
+            fw_info: Information about the framework being used.
+            fw_impl: Implementation methods for the framework.
+        """
+        super().__init__(report_dir,
+                         fw_info,
+                         fw_impl)
+
+    def get_graph_for_tensorboard_display(self,
+                                          quantized_model: torch.nn.Module,
+                                          similarity_metrics: Dict[str, Any],
+                                          repr_dataset: Callable):
+        """
+        Get the graph to display on TensorBoard. The graph represents the quantized model
+        with the similarity metrics that were measured.
+
+        Args:
+            quantized_model: The quantized model to be displayed on TensorBoard.
+            similarity_metrics: Dictionary containing the collected similarity metrics values.
+            repr_dataset: Callable that generates the representative dataset used during graph building.
+
+        Returns:
+            The updated quantized model graph with similarity metrics embedded.
+        """
+        # Read the model and generate a graph representation
+        quant_graph = model_reader(quantized_model,
+                                   representative_data_gen=repr_dataset,
+                                   to_tensor=self.fw_impl.to_tensor,
+                                   to_numpy=self.fw_impl.to_numpy)
+
+        # Iterate through each node in the graph
+        for node in quant_graph.nodes:
+            # Check and add similarity metrics for each node in the graph
+            if node.name in similarity_metrics[INTERMEDIATE_SIMILARITY_METRICS_REPR].keys():
+                node.framework_attr[XQUANT_REPR] = similarity_metrics[INTERMEDIATE_SIMILARITY_METRICS_REPR][f"{node.name}"]
+            elif node.name.removesuffix("_layer") in similarity_metrics[INTERMEDIATE_SIMILARITY_METRICS_REPR].keys():
+                node.framework_attr[XQUANT_REPR] = similarity_metrics[INTERMEDIATE_SIMILARITY_METRICS_REPR][
+                    node.name.removesuffix("_layer")]
+
+            # Check and add validation similarity metrics for each node in the graph
+            if node.name in similarity_metrics[INTERMEDIATE_SIMILARITY_METRICS_VAL].keys():
+                node.framework_attr[XQUANT_VAL] = similarity_metrics[INTERMEDIATE_SIMILARITY_METRICS_VAL][f"{node.name}"]
+            elif node.name.removesuffix("_layer") in similarity_metrics[INTERMEDIATE_SIMILARITY_METRICS_VAL].keys():
+                node.framework_attr[XQUANT_VAL] = similarity_metrics[INTERMEDIATE_SIMILARITY_METRICS_VAL][
+                    node.name.removesuffix("_layer")]
+
+        return quant_graph
{mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/LICENSE.md
RENAMED
File without changes

{mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/WHEEL
RENAMED
File without changes

{mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/top_level.txt
RENAMED
File without changes