mct-nightly 2.1.0.20240616.65727__py3-none-any.whl → 2.1.0.20240618.432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/METADATA +2 -2
  2. {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/RECORD +43 -17
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/graph/functional_node.py +3 -3
  5. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +23 -13
  6. model_compression_toolkit/core/pytorch/constants.py +1 -1
  7. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +3 -3
  8. model_compression_toolkit/core/pytorch/reader/graph_builders.py +12 -6
  9. model_compression_toolkit/gptq/keras/gptq_training.py +1 -1
  10. model_compression_toolkit/gptq/keras/graph_info.py +1 -1
  11. model_compression_toolkit/gptq/pytorch/gptq_training.py +5 -2
  12. model_compression_toolkit/gptq/pytorch/graph_info.py +2 -1
  13. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +3 -2
  14. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +3 -2
  15. model_compression_toolkit/xquant/__init__.py +19 -0
  16. model_compression_toolkit/xquant/common/__init__.py +15 -0
  17. model_compression_toolkit/xquant/common/constants.py +38 -0
  18. model_compression_toolkit/xquant/common/core_report_generator.py +83 -0
  19. model_compression_toolkit/xquant/common/dataset_utils.py +43 -0
  20. model_compression_toolkit/xquant/common/framework_report_utils.py +89 -0
  21. model_compression_toolkit/xquant/common/model_analyzer.py +99 -0
  22. model_compression_toolkit/xquant/common/model_folding_utils.py +104 -0
  23. model_compression_toolkit/xquant/common/similarity_calculator.py +194 -0
  24. model_compression_toolkit/xquant/common/similarity_functions.py +81 -0
  25. model_compression_toolkit/xquant/common/tensorboard_utils.py +101 -0
  26. model_compression_toolkit/xquant/common/xquant_config.py +39 -0
  27. model_compression_toolkit/xquant/keras/__init__.py +15 -0
  28. model_compression_toolkit/xquant/keras/dataset_utils.py +57 -0
  29. model_compression_toolkit/xquant/keras/facade_xquant_report.py +63 -0
  30. model_compression_toolkit/xquant/keras/keras_report_utils.py +60 -0
  31. model_compression_toolkit/xquant/keras/model_analyzer.py +136 -0
  32. model_compression_toolkit/xquant/keras/similarity_functions.py +75 -0
  33. model_compression_toolkit/xquant/keras/tensorboard_utils.py +84 -0
  34. model_compression_toolkit/xquant/pytorch/__init__.py +15 -0
  35. model_compression_toolkit/xquant/pytorch/dataset_utils.py +76 -0
  36. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +62 -0
  37. model_compression_toolkit/xquant/pytorch/model_analyzer.py +132 -0
  38. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +61 -0
  39. model_compression_toolkit/xquant/pytorch/similarity_functions.py +68 -0
  40. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +87 -0
  41. {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/LICENSE.md +0 -0
  42. {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/WHEEL +0 -0
  43. {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,76 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Any, Callable
17
+
18
+ from model_compression_toolkit.xquant.common.dataset_utils import DatasetUtils
19
+ import numpy as np
20
+
21
+ import torch
22
+
23
+
24
class PytorchDatasetUtils(DatasetUtils):
    """
    Class with helpful methods for handling different kinds of Pytorch datasets from the user.
    """
    @staticmethod
    def prepare_dataset(dataset: Callable, is_validation: bool, device: str = None):
        """
        Prepare the dataset so calling it will return only inputs for the model (like in the case
        of the representative dataset). For example, when the validation dataset is used, the labels
        should be removed.

        Args:
            dataset: Callable returning an iterable of samples. For a validation dataset, each sample
                is indexed and only element 0 (the model inputs) is kept; the remaining elements
                (e.g. labels) are dropped. For a non-validation dataset the whole sample is used as
                the model inputs.
            is_validation: Whether it's validation dataset or not.
            device: Device to transfer the data to.

        Returns:
            Generator to use for retrieving the dataset inputs. Each yielded item is a list of
            torch tensors moved to `device`, suitable for unpacking as `model(*inputs)`.
        """

        def transfer_to_device(_data: Any) -> torch.Tensor:
            # numpy arrays are converted to torch tensors first; anything else must support `.to()`.
            if isinstance(_data, np.ndarray):
                return torch.from_numpy(_data).to(device)
            return _data.to(device)

        def to_input_list(inputs: Any) -> list:
            # A list or tuple holds several model inputs. Anything else (e.g. a single tensor/array)
            # is one input and must not be iterated, which would split it along its first dimension.
            if isinstance(inputs, (list, tuple)):
                return [transfer_to_device(t) for t in inputs]
            return [transfer_to_device(inputs)]

        for sample in dataset():
            # Assume sample[0] contains the inputs and sample[1] the labels for validation data.
            inputs = sample[0] if is_validation else sample
            yield to_input_list(inputs)
@@ -0,0 +1,62 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Callable
17
+
18
+ from model_compression_toolkit.constants import FOUND_TORCH
19
+ from model_compression_toolkit.xquant.common.core_report_generator import core_report_generator
20
+ from model_compression_toolkit.xquant import XQuantConfig
21
+ from model_compression_toolkit.logger import Logger
22
+
23
if FOUND_TORCH:
    from model_compression_toolkit.xquant.pytorch.pytorch_report_utils import PytorchReportUtils
    import torch

    def xquant_report_pytorch_experimental(float_model: torch.nn.Module,
                                           quantized_model: torch.nn.Module,
                                           repr_dataset: Callable,
                                           validation_dataset: Callable,
                                           xquant_config: XQuantConfig):
        """
        Build an explainable-quantization (xquant) report comparing a float Pytorch model
        with its quantized counterpart.

        Args:
            float_model (torch.nn.Module): The original floating-point Pytorch model.
            quantized_model (torch.nn.Module): The quantized Pytorch model.
            repr_dataset (Callable): The representative dataset used during quantization.
            validation_dataset (Callable): The validation dataset used for evaluation.
            xquant_config (XQuantConfig): Configuration settings for explainable quantization;
                its `report_dir` is used both for the log file and for the report outputs.

        Returns:
            Whatever `core_report_generator` returns for these inputs (documented there as the
            collected similarity metrics and report data).
        """
        report_dir = xquant_config.report_dir

        # Send log output into the report directory before any report work begins.
        Logger.set_log_file(log_folder=report_dir)

        # The Pytorch-specific utilities are built after the log file is set, then handed to
        # the framework-agnostic report generator.
        return core_report_generator(float_model=float_model,
                                     quantized_model=quantized_model,
                                     repr_dataset=repr_dataset,
                                     validation_dataset=validation_dataset,
                                     fw_report_utils=PytorchReportUtils(report_dir),
                                     xquant_config=xquant_config)

else:
    def xquant_report_pytorch_experimental(*args, **kwargs):
        """
        Stand-in used when torch is not installed: reports the missing dependency via
        `Logger.critical` instead of generating a report.
        """
        Logger.critical("PyTorch must be installed to use 'xquant_report_pytorch_experimental'. "
                        "The 'torch' package is missing.")  # pragma: no cover
@@ -0,0 +1,132 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Dict, List, Tuple
16
+
17
+ import torch
18
+ from mct_quantizers.pytorch.quantize_wrapper import PytorchQuantizationWrapper
19
+ from model_compression_toolkit.xquant.common.constants import MODEL_OUTPUT_KEY
20
+
21
+ from model_compression_toolkit.xquant.common.model_analyzer import ModelAnalyzer
22
+
23
+
24
class PytorchModelAnalyzer(ModelAnalyzer):
    """
    This class provides utilities for analyzing Pytorch models, specifically for
    extracting activations and comparing float and quantized models.
    """

    def extract_model_activations(self,
                                  float_model: torch.nn.Module,
                                  quantized_model: torch.nn.Module,
                                  float_name2quant_name: Dict[str, str],
                                  data: List[torch.Tensor]) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Extracts activations from both the float and quantized models.

        Forward hooks are attached to the requested layers only for the duration of this call
        and are always removed afterwards (also when an error occurs), so repeated calls do not
        accumulate hooks on the models.

        Args:
            float_model (torch.nn.Module): The float model.
            quantized_model (torch.nn.Module): The quantized model.
            float_name2quant_name (Dict[str, str]): A mapping from float model layer names to quantized model layer
                names.
            data (List[torch.Tensor]): Input data for which to compute activations; unpacked as
                positional arguments to both models.

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
                - Dictionary of activations for the float model.
                - Dictionary of activations for the quantized model.
                Both also contain the models' final outputs under MODEL_OUTPUT_KEY.

        Raises:
            KeyError: If a name in `float_name2quant_name` is not a module of the respective model.
        """

        def _compute_activations(name: str, activations: dict):
            """
            Creates a hook function to capture the activations of a layer.

            Args:
                name (str): The name of the layer.
                activations (dict): The dictionary to store the activations.

            Returns:
                hook (function): The hook function to register with the layer.
            """
            def hook(_module, _inputs, output):
                activations[name] = output.detach()

            return hook

        # Initialize dictionaries to store activations for both models
        activations_float = {}
        activations_quant = {}

        # Build each name -> module lookup once rather than once per hooked layer.
        float_modules = dict(float_model.named_modules())
        quant_modules = dict(quantized_model.named_modules())

        # Keep every hook handle so the hooks can be detached after the forward pass; otherwise
        # each call would leave another set of hooks permanently attached to the user's models.
        handles = []
        try:
            # Register hooks for all layers in the float model
            for layer_name in float_name2quant_name.keys():
                handles.append(float_modules[layer_name].register_forward_hook(
                    _compute_activations(layer_name, activations_float)))

            # Register hooks for all layers in the quantized model
            for layer_name in float_name2quant_name.values():
                handles.append(quant_modules[layer_name].register_forward_hook(
                    _compute_activations(layer_name, activations_quant)))

            # Perform a forward pass with the input data and capture activations
            with torch.no_grad():
                float_predictions = float_model(*data)
                quant_predictions = quantized_model(*data)
        finally:
            for handle in handles:
                handle.remove()

        activations_float[MODEL_OUTPUT_KEY] = float_predictions
        activations_quant[MODEL_OUTPUT_KEY] = quant_predictions

        return activations_float, activations_quant

    def identify_quantized_compare_points(self,
                                          quantized_model: torch.nn.Module) -> List[str]:
        """
        Identifies points in the quantized model to compare with the float model.

        Args:
            quantized_model (torch.nn.Module): The quantized model.

        Returns:
            List[str]: Names of all modules in the quantized model that are PytorchQuantizationWrapper instances.
        """
        return [n for n, m in quantized_model.named_modules() if isinstance(m, PytorchQuantizationWrapper)]

    def find_corresponding_float_layer(self,
                                       quant_compare_point: str,
                                       quantized_model: torch.nn.Module) -> str:
        """
        Finds the corresponding float model layer for a given quantized model layer.
        In pytorch, we assume the name is the same in the float model, thus we return quant_compare_point.

        Args:
            quant_compare_point (str): The name of the layer in the quantized model.
            quantized_model (torch.nn.Module): The quantized model (unused here).

        Returns:
            str: The name of the corresponding layer in the float model.
        """
        return quant_compare_point

    def extract_float_layer_names(self,
                                  float_model: torch.nn.Module) -> List[str]:
        """
        Extracts the names of all layers in the float model.

        Args:
            float_model (torch.nn.Module): The float model.

        Returns:
            List[str]: Names of all modules in the float model, as given by `named_modules()`
                (including the root module's empty name).
        """
        return [name for name, _ in float_model.named_modules()]
132
+
@@ -0,0 +1,61 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from model_compression_toolkit.core.pytorch.utils import get_working_device
16
+
17
+ from model_compression_toolkit.ptq.pytorch.quantization_facade import DEFAULT_PYTORCH_TPC
18
+ from model_compression_toolkit.xquant.common.framework_report_utils import FrameworkReportUtils
19
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
20
+ from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
21
+ from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
22
+ from model_compression_toolkit.xquant.common.similarity_calculator import SimilarityCalculator
23
+ from model_compression_toolkit.xquant.pytorch.dataset_utils import PytorchDatasetUtils
24
+ from model_compression_toolkit.xquant.pytorch.model_analyzer import PytorchModelAnalyzer
25
+ from model_compression_toolkit.xquant.pytorch.similarity_functions import PytorchSimilarityFunctions
26
+ from model_compression_toolkit.xquant.pytorch.tensorboard_utils import PytorchTensorboardUtils
27
+
28
+
29
class PytorchReportUtils(FrameworkReportUtils):
    """
    Collects the Pytorch-specific helpers needed to build an xquant report: framework info and
    implementation, dataset handling, model folding, similarity computation and TensorBoard utilities.
    """
    def __init__(self, report_dir: str):
        """
        Args:
            report_dir: Logging dir path.
        """
        fw_info, fw_impl = DEFAULT_PYTORCH_INFO, PytorchImplementation()

        dataset_utils = PytorchDatasetUtils()
        folding_utils = ModelFoldingUtils(fw_info=fw_info,
                                          fw_impl=fw_impl,
                                          fw_default_tpc=DEFAULT_PYTORCH_TPC)

        # Components of the similarity calculator, created in the same order as they are passed.
        similarity_fns = PytorchSimilarityFunctions()
        analyzer = PytorchModelAnalyzer()
        working_device = get_working_device()
        sim_calc = SimilarityCalculator(dataset_utils=dataset_utils,
                                        model_folding=folding_utils,
                                        similarity_functions=similarity_fns,
                                        model_analyzer_utils=analyzer,
                                        device=working_device)

        tensorboard_utils = PytorchTensorboardUtils(report_dir=report_dir,
                                                    fw_impl=fw_impl,
                                                    fw_info=fw_info)

        super().__init__(fw_info=fw_info,
                         fw_impl=fw_impl,
                         tb_utils=tensorboard_utils,
                         dataset_utils=dataset_utils,
                         similarity_calculator=sim_calc,
                         model_folding_utils=folding_utils)
@@ -0,0 +1,68 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ from model_compression_toolkit.xquant.common.similarity_functions import SimilarityFunctions
18
+ import torch
19
+
20
class PytorchSimilarityFunctions(SimilarityFunctions):
    """
    Pytorch implementations of the similarity metrics used to compare a float tensor
    with its quantized counterpart. Every metric returns a plain Python float.
    """

    @staticmethod
    def compute_mse(x: torch.Tensor, y: torch.Tensor) -> float:
        """
        Computes the Mean Squared Error between two tensors (usually, the float and quantized tensors).

        Args:
            x: Float model predictions.
            y: Quantized model predictions.

        Returns:
            Mean Squared Error as a float (computed with `torch.nn.functional.mse_loss`).
        """
        return torch.nn.functional.mse_loss(x, y).item()

    @staticmethod
    def compute_cs(x: torch.Tensor, y: torch.Tensor) -> float:
        """
        Computes the Cosine Similarity between two tensors (usually, the float and quantized tensors).
        Both tensors are flattened first, so the result is a single similarity over all elements.

        Args:
            x: Float model predictions.
            y: Quantized model predictions.

        Returns:
            Cosine Similarity as a float.
        """
        x_flat = x.flatten()
        y_flat = y.flatten()
        return torch.nn.functional.cosine_similarity(x_flat, y_flat, dim=0).item()

    @staticmethod
    def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
        """
        Computes the Signal-to-Quantization-Noise Ratio between two tensors (usually, the float
        and quantized tensors), as mean(x**2) / mean((x - y)**2). The ratio is linear (not in dB).
        For floating-point inputs, identical tensors give a zero noise power and therefore an
        inf (or nan, if x is all zeros) result.

        Args:
            x: Float model predictions.
            y: Quantized model predictions.

        Returns:
            Signal-to-Quantization-Noise Ratio as a float.
        """
        quantization_noise = x - y
        ratio = torch.mean(x ** 2) / torch.mean(quantization_noise ** 2)
        return ratio.item()
68
+
@@ -0,0 +1,87 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
17
+ from typing import Dict, Any, Callable
18
+
19
+ import torch
20
+
21
+ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
22
+ from model_compression_toolkit.core.pytorch.reader.reader import model_reader
23
+ from model_compression_toolkit.xquant.common.constants import XQUANT_REPR, INTERMEDIATE_SIMILARITY_METRICS_REPR, XQUANT_VAL, INTERMEDIATE_SIMILARITY_METRICS_VAL
24
+ from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
25
+ from model_compression_toolkit.xquant.common.tensorboard_utils import TensorboardUtils
26
+
27
class PytorchTensorboardUtils(TensorboardUtils):
    """
    Utility class for handling PyTorch models with TensorBoard. Inherits from TensorboardUtils.
    This class provides functionalities to display quantized model graphs on TensorBoard.
    """

    def __init__(self,
                 report_dir: str,
                 fw_info: FrameworkInfo,
                 fw_impl: FrameworkImplementation):
        """
        Initialize the PytorchTensorboardUtils instance.

        Args:
            report_dir: Directory where the reports are stored.
            fw_info: Information about the framework being used.
            fw_impl: Implementation methods for the framework.
        """
        super().__init__(report_dir,
                         fw_info,
                         fw_impl)

    def get_graph_for_tensorboard_display(self,
                                          quantized_model: torch.nn.Module,
                                          similarity_metrics: Dict[str, Any],
                                          repr_dataset: Callable):
        """
        Get the graph to display on TensorBoard. The graph represents the quantized model
        with the similarity metrics that were measured.

        For every graph node, the per-layer representative-dataset metrics are stored in
        `node.framework_attr[XQUANT_REPR]` and the validation-dataset metrics in
        `node.framework_attr[XQUANT_VAL]`. A node is matched by its exact name first and, failing
        that, by its name with a trailing "_layer" suffix removed.

        Args:
            quantized_model: The quantized model to be displayed on TensorBoard.
            similarity_metrics: Dictionary containing the collected similarity metrics values. Must contain
                the INTERMEDIATE_SIMILARITY_METRICS_REPR and INTERMEDIATE_SIMILARITY_METRICS_VAL entries,
                each mapping layer names to their metrics.
            repr_dataset: Callable that generates the representative dataset used during graph building.

        Returns:
            The updated quantized model graph with similarity metrics embedded.
        """
        # Read the model and generate a graph representation
        quant_graph = model_reader(quantized_model,
                                   representative_data_gen=repr_dataset,
                                   to_tensor=self.fw_impl.to_tensor,
                                   to_numpy=self.fw_impl.to_numpy)

        # (key of the per-layer metrics in similarity_metrics, node attribute to store them in),
        # processed in this order for every node.
        metrics_to_node_attrs = ((INTERMEDIATE_SIMILARITY_METRICS_REPR, XQUANT_REPR),
                                 (INTERMEDIATE_SIMILARITY_METRICS_VAL, XQUANT_VAL))

        for node in quant_graph.nodes:
            # Fallback lookup name for metrics keyed without the "_layer" suffix.
            stripped_name = node.name.removesuffix("_layer")
            for metrics_key, attr_key in metrics_to_node_attrs:
                layer_metrics = similarity_metrics[metrics_key]
                if node.name in layer_metrics:
                    node.framework_attr[attr_key] = layer_metrics[node.name]
                elif stripped_name in layer_metrics:
                    node.framework_attr[attr_key] = layer_metrics[stripped_name]

        return quant_graph