mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
from abc import ABC, abstractmethod
|
|
18
|
+
|
|
19
|
+
from msprobe.core.config_check.utils.utils import config_checking_print
|
|
20
|
+
from msprobe.core.common.file_utils import FileOpen, load_yaml
|
|
21
|
+
from msprobe.core.common.const import Const, FileCheckConst
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Parser(ABC):
    """Base class for hyperparameter file parsers.

    Subclasses implement :meth:`parse`; callers use :meth:`run`, which never
    propagates a parsing failure.
    """

    @abstractmethod
    def parse(self, file_path: str) -> dict:
        """Parse ``file_path`` and return the extracted values as a dict."""

    def run(self, file_path: str) -> dict:
        """
        Unified external entry point.

        :param file_path: path of the file to parse
        :return: the dict returned by ``parse``, or an empty dict when ``parse`` raised
        """
        try:
            return self.parse(file_path)
        except Exception as exc:
            # Best effort: report the failure and skip this file.
            config_checking_print(f"{self.__class__} parsing error, skip file path: {file_path}, error: {exc}")
            return {}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ShellParser(Parser):
    """Extract ``--option [value]`` hyperparameters from a shell launch script."""

    def parse(self, file_path: str) -> dict:
        """
        Extracts arguments from bash script used to run a model training.

        Comment lines and trailing ``#`` comments are dropped first. The leftmost
        ``msrun``, ``torchrun`` or ``python -m torch.distributed.launch`` command
        (everything up to the next ``|``) is then located, ``*_ARGS="..."`` blocks
        referenced as ``$NAME`` or ``${NAME}`` are spliced into it, and every
        ``--key [value]`` pair is collected.

        :param file_path: path of the shell script to parse
        :return: dict mapping option names (without the leading ``--``) to their
            value string, or to ``True`` for options without a value; an empty
            dict when no launch command is found
        """
        script_lines = []
        with FileOpen(file_path, 'r') as file:
            for line in file:
                if line.lstrip().startswith('#'):
                    continue
                # Drop a trailing comment, keep one newline per retained line.
                line = line.split('#')[0].rstrip() + '\n'
                if line.strip():
                    script_lines.append(line)
        script_content = ''.join(script_lines)

        launch = re.search(r'msrun\s[^|]*|torchrun\s[^|]*|python\d? -m torch.distributed.launch\s[^|]*',
                           script_content,
                           re.DOTALL)
        if not launch:
            # Nothing to extract without a launch command.
            return {}
        command_line = launch.group()

        # Expand NAME_ARGS="..." blocks referenced from the launch command.
        for block_name, block_content in re.findall(r'([a-zA-Z0-9_]{1,20}_ARGS)="(.*?)"',
                                                    script_content, re.DOTALL):
            expanded = block_content.replace('\n', ' ')
            command_line = command_line.replace("${" + block_name + "}", expanded)
            command_line = command_line.replace(f"${block_name}", expanded)

        hyperparameters = {}
        # (?!--) keeps a valueless flag from swallowing the option that follows it
        # on the same line, e.g. "--bf16 --lr 1e-4".
        for key, value in re.findall(r'--([\w-]+)(?:\s+(?!--)([^\s\\]+))?', command_line):
            args_key = re.match(r'\$\{?(\w+)}?', value)
            if args_key:
                # Value is a shell variable: use the text of its last assignment.
                env_vars = re.findall(rf'{args_key.group(1)}=\s*(.+)', script_content)
                if env_vars:
                    value = env_vars[-1]
            hyperparameters[key] = value if value else True

        return hyperparameters
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class YamlParser(Parser):
    """Flatten a YAML configuration file into ``{dotted.key: scalar}`` pairs."""

    # Class-level default kept for backward compatibility; ``parse`` rebinds a
    # fresh dict on the instance so results of different files never mix.
    hyperparameters = {}

    def parse(self, file_path: str) -> dict:
        """
        Load ``file_path`` and return its flattened scalar values.

        :param file_path: path of the YAML file to parse
        :return: a new dict holding only the values found in this file
        """
        ori_hyper = load_yaml(file_path)
        # Start from an empty dict on every call: the factory reuses one parser
        # instance, and a shared dict would leak keys from earlier files.
        self.hyperparameters = {}
        self.recursive_parse_parameters(ori_hyper, "")
        return self.hyperparameters

    def recursive_parse_parameters(self, parameters, prefix):
        """
        Walk ``parameters`` and record every int/str/bool leaf under ``prefix``.

        :param parameters: a nested structure of dicts, lists and scalars
        :param prefix: key path accumulated so far, joined with ``Const.SEP``
        """
        if isinstance(parameters, dict):
            for key, value in parameters.items():
                # f-string formatting tolerates non-string YAML keys in nested paths.
                new_prefix = f"{prefix}{Const.SEP}{key}" if prefix else key
                self.recursive_parse_parameters(value, new_prefix)
        elif isinstance(parameters, list):
            # List items share the list's prefix, so a later scalar item
            # overwrites an earlier one.
            for value in parameters:
                self.recursive_parse_parameters(value, prefix)
        elif isinstance(parameters, (int, str, bool)):
            # NOTE(review): float and None leaves are not recorded - confirm intended.
            self.hyperparameters.update({prefix: parameters})
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class ParserFactory:
    """Look up the parser instance registered for a file suffix."""

    # One parser instance per supported suffix, created at class definition.
    __ParserDict = {
        FileCheckConst.SHELL_SUFFIX: ShellParser(),
        FileCheckConst.YAML_SUFFIX: YamlParser()
    }

    def get_parser(self, file_type: str) -> Parser:
        """
        Return the parser registered for ``file_type``.

        :param file_type: file suffix used as the lookup key
        :raises ValueError: when no parser is registered for ``file_type``
        """
        found = self.__ParserDict.get(file_type)
        if found:
            return found
        raise ValueError(f'Invalid parser type: {file_type}')
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import re
|
|
18
|
+
import hashlib
|
|
19
|
+
|
|
20
|
+
from msprobe.core.common.framework_adapter import FmkAdp
|
|
21
|
+
from msprobe.core.common.log import logger
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def merge_keys(dir_0, dir_1):
    """Return the set of keys that appear in either mapping."""
    return {*dir_0.keys(), *dir_1.keys()}
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def compare_dict(bench_dict, cmp_dict):
    """
    Describe the differences between two flat dicts.

    Keys are reported in a stable order: first the keys of ``bench_dict`` in
    their insertion order, then the keys that exist only in ``cmp_dict``.

    :param bench_dict: the benchmark dict
    :param cmp_dict: the dict compared against the benchmark
    :return: list of strings, one per differing key; keys whose values are
        equal in both dicts produce no entry
    """
    result = []
    # Ordered key union: iterating a set would make the report order vary
    # between runs under string hash randomization.
    ordered_keys = list(bench_dict.keys())
    ordered_keys.extend(key for key in cmp_dict.keys() if key not in bench_dict)
    for key in ordered_keys:
        if key not in cmp_dict:
            result.append(f"{key}: [deleted] -> {bench_dict[key]}")
        elif key not in bench_dict:
            result.append(f"{key}: [added] -> {cmp_dict[key]}")
        elif bench_dict[key] != cmp_dict[key]:
            result.append(f"{key}: {bench_dict[key]} -> {cmp_dict[key]}")
    return result
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def config_checking_print(msg):
    """Log ``msg`` at info level, tagged with the config-checking prefix."""
    tagged_message = f"[config checking log] {msg}"
    logger.info(tagged_message)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def tensor_to_hash(tensor):
    """Compute the hash value of a tensor"""
    # Work on a detached CPU copy so the caller's tensor is left untouched.
    host_copy = tensor.clone().detach().cpu()
    raw_bytes = host_copy.numpy().tobytes()
    return bytes_hash(raw_bytes)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_tensor_features(tensor):
    """Return the max, min, mean and norm of ``tensor`` as computed by ``FmkAdp``."""
    return dict(
        max=FmkAdp.tensor_max(tensor),
        min=FmkAdp.tensor_min(tensor),
        mean=FmkAdp.tensor_mean(tensor),
        norm=FmkAdp.tensor_norm(tensor),
    )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def compare_dicts(dict1, dict2, path=''):
    """
    Recursively diff two nested dicts.

    :param dict1: the reference dict
    :param dict2: the dict compared against ``dict1``
    :param path: '/'-terminated key path of the current nesting level
    :return: ``(deleted, added, changed, result)``: three lists of message
        strings and a nested dict mirroring the differing keys
    """
    deleted, added, changed = [], [], []
    result = {}

    for key, left in dict1.items():
        if key not in dict2:
            deleted.append(f"[Deleted]: {path + key}")
            result[key] = "[deleted]"
            continue

        right = dict2[key]
        if isinstance(left, dict) and isinstance(right, dict):
            nested = compare_dicts(left, right, path + key + '/')
            deleted += nested[0]
            added += nested[1]
            changed += nested[2]
            # Only keep sub-trees that actually differ.
            if nested[3]:
                result[key] = nested[3]
        elif left != right:
            changed.append(f"[Changed]: {path + key} : {left} -> {right}")
            result[key] = f"[changed]: {left} -> {right}"

    for key in dict2:
        if key in dict1:
            continue
        added.append(f"[Added]: {path + key}")
        result[key] = "[added]"

    return deleted, added, changed, result
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def bytes_hash(obj: bytes):
    """Return a 16-bit hash of ``obj``: the SHA-256 digest reduced modulo 2**16."""
    digest = hashlib.sha256(obj).digest()
    # The low 16 bits of the big-endian digest are its last two bytes.
    return int.from_bytes(digest[-2:], "big")
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def update_dict(ori_dict, new_dict):
    """
    Merge ``new_dict`` into ``ori_dict`` in place, recording conflicting values.

    A key whose new value differs from the stored one is replaced by a record
    ``{"description": "duplicate_value", "values": [...]}`` listing every
    distinct value seen so far; later conflicts extend that record.

    :param ori_dict: dict updated in place
    :param new_dict: dict whose items are merged into ``ori_dict``
    :return: None
    """
    for key, value in new_dict.items():
        if key not in ori_dict or ori_dict[key] == value:
            ori_dict[key] = value
            continue

        existing = ori_dict[key]
        # Inspect the stored entry itself: checking the outer dict for a
        # "values" key nests records and breaks on a key named "values".
        is_duplicate_record = (
            isinstance(existing, dict)
            and existing.get("description") == "duplicate_value"
            and isinstance(existing.get("values"), list)
        )
        if is_duplicate_record:
            if value not in existing["values"]:
                existing["values"].append(value)
        else:
            ori_dict[key] = {"description": "duplicate_value", "values": [existing, value]}
|
|
@@ -13,10 +13,12 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import inspect
|
|
16
17
|
from typing import Dict, Any, Optional, Callable, Union, List, Tuple
|
|
17
18
|
|
|
18
19
|
from msprobe.core.common.const import Const
|
|
19
20
|
from msprobe.core.common.file_utils import load_yaml
|
|
21
|
+
from msprobe.core.common.log import logger
|
|
20
22
|
|
|
21
23
|
|
|
22
24
|
def _get_attr(module, attr_name):
|
|
@@ -32,7 +34,8 @@ def _get_attr(module, attr_name):
|
|
|
32
34
|
class ApiWrapper:
|
|
33
35
|
def __init__(
|
|
34
36
|
self, api_types: Dict[str, Dict[str, Any]],
|
|
35
|
-
api_list_paths: Union[str, List[str], Tuple[str]]
|
|
37
|
+
api_list_paths: Union[str, List[str], Tuple[str]],
|
|
38
|
+
backlist: Union[List[str], Tuple[str]] = None
|
|
36
39
|
):
|
|
37
40
|
self.api_types = api_types
|
|
38
41
|
if not isinstance(api_list_paths, (list, tuple)):
|
|
@@ -41,9 +44,42 @@ class ApiWrapper:
|
|
|
41
44
|
raise RuntimeError("The number of api_list_paths must be equal to the number of frameworks in 'api_types', "
|
|
42
45
|
"when api_list_paths is a list or tuple.")
|
|
43
46
|
self.api_list_paths = api_list_paths
|
|
47
|
+
self.backlist = backlist if backlist else []
|
|
44
48
|
self.api_names = self._get_api_names()
|
|
45
49
|
self.wrapped_api_functions = dict()
|
|
46
50
|
|
|
51
|
+
@staticmethod
def deal_with_self_kwargs(api_name, api_func, args, kwargs):
    """
    Move a ``self`` value passed by keyword into the positional arguments.

    Nothing is changed unless ``kwargs`` contains the key ``'self'``.

    :param api_name: name used to look up a fallback callable in ``Const.API_WITH_SELF_ARG``
    :param api_func: the callable whose signature is inspected
    :param args: positional arguments of the call
    :param kwargs: keyword arguments of the call; moved entries are popped from it in place
    :return: ``(ok, args, kwargs)``; ``ok`` is False when no signature could be
        obtained or when ``self`` is a keyword-only parameter, in which case
        ``args`` and ``kwargs`` are returned as received
    """
    if kwargs and 'self' in kwargs:
        func_params = None
        try:
            func_params = inspect.signature(api_func).parameters
        except Exception:
            # No introspectable signature: fall back to the callable registered
            # for this api name, if there is one.
            if api_name in Const.API_WITH_SELF_ARG:
                func_params = inspect.signature(Const.API_WITH_SELF_ARG.get(api_name)).parameters
        if func_params is None:
            return False, args, kwargs

        for name, param in func_params.items():
            # A keyword-only 'self' cannot be passed positionally.
            if name == 'self' and param.kind == inspect.Parameter.KEYWORD_ONLY:
                return False, args, kwargs
        args_ = list(args)
        names_and_values = []
        self_index = 0
        # Collect (name, default) for the parameters up to and including 'self'.
        # NOTE(review): if no parameter is named 'self', self_index stays 0 and
        # every parameter is collected - confirm this case is intended.
        for i, item in enumerate(func_params.items()):
            names_and_values.append((item[0], item[1].default))
            if item[0] == 'self':
                self_index = i
                break
        # Fill the positional slots not covered by args, up to 'self', from
        # kwargs when present, otherwise from the declared default.
        # NOTE(review): a parameter without a default presumably contributes
        # inspect.Parameter.empty here - verify against callers.
        for i in range(len(args), self_index + 1):
            if names_and_values[i][0] in kwargs:
                args_.append(kwargs.pop(names_and_values[i][0]))
            else:
                args_.append(names_and_values[i][1])
        args = tuple(args_)

    return True, args, kwargs
|
|
82
|
+
|
|
47
83
|
def wrap_api(
|
|
48
84
|
self, api_templates, hook_build_func: Optional[Callable]
|
|
49
85
|
):
|
|
@@ -68,6 +104,14 @@ class ApiWrapper:
|
|
|
68
104
|
if callable(ori_api):
|
|
69
105
|
def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template):
|
|
70
106
|
def api_function(*args, **kwargs):
|
|
107
|
+
api_name_with_prefix = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1])
|
|
108
|
+
enable_wrap, args, kwargs = self.deal_with_self_kwargs(api_name_with_prefix,
|
|
109
|
+
api_func, args, kwargs)
|
|
110
|
+
if not enable_wrap:
|
|
111
|
+
logger.warning(f'Cannot collect precision data of {api_name_with_prefix}. '
|
|
112
|
+
'It may be fixed by passing the value of "self" '
|
|
113
|
+
'as a positional argument instead of a keyword argument. ')
|
|
114
|
+
return api_func(*args, **kwargs)
|
|
71
115
|
return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs)
|
|
72
116
|
api_function.__name__ = api_name
|
|
73
117
|
return api_function
|
|
@@ -84,9 +128,12 @@ class ApiWrapper:
|
|
|
84
128
|
api_list = load_yaml(self.api_list_paths[index])
|
|
85
129
|
valid_names = dict()
|
|
86
130
|
for api_type, api_modules in self.api_types.get(framework, {}).items():
|
|
87
|
-
|
|
131
|
+
key_in_file = Const.SUPPORT_API_DICT_KEY_MAP.get(framework, {}).get(api_type)
|
|
132
|
+
api_from_file = api_list.get(key_in_file, [])
|
|
88
133
|
names = set()
|
|
89
134
|
for api_name in api_from_file:
|
|
135
|
+
if f'{key_in_file}.{api_name}' in self.backlist:
|
|
136
|
+
continue
|
|
90
137
|
target_attr = api_name
|
|
91
138
|
target_module = api_modules[0]
|
|
92
139
|
if Const.SEP in api_name:
|
|
@@ -105,7 +152,7 @@ class ApiRegistry:
|
|
|
105
152
|
Base class for api registry.
|
|
106
153
|
"""
|
|
107
154
|
|
|
108
|
-
def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates):
|
|
155
|
+
def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates, backlist=None):
|
|
109
156
|
self.ori_api_attr = dict()
|
|
110
157
|
self.wrapped_api_attr = dict()
|
|
111
158
|
self.inner_used_ori_attr = dict()
|
|
@@ -114,6 +161,8 @@ class ApiRegistry:
|
|
|
114
161
|
self.inner_used_api = inner_used_api
|
|
115
162
|
self.supported_api_list_path = supported_api_list_path
|
|
116
163
|
self.api_templates = api_templates
|
|
164
|
+
self.backlist = backlist if backlist else []
|
|
165
|
+
self.all_api_registered = False
|
|
117
166
|
|
|
118
167
|
@staticmethod
|
|
119
168
|
def store_ori_attr(ori_api_group, api_list, api_ori_attr):
|
|
@@ -131,7 +180,20 @@ class ApiRegistry:
|
|
|
131
180
|
else:
|
|
132
181
|
setattr(api_group, api, api_attr)
|
|
133
182
|
|
|
183
|
+
@staticmethod
|
|
184
|
+
def register_custom_api(module, api_name, api_prefix, hook_build_func, api_template):
|
|
185
|
+
def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template):
|
|
186
|
+
def api_function(*args, **kwargs):
|
|
187
|
+
return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs)
|
|
188
|
+
|
|
189
|
+
api_function.__name__ = api_name
|
|
190
|
+
return api_function
|
|
191
|
+
|
|
192
|
+
setattr(module, api_name,
|
|
193
|
+
wrap_api_func(api_name, getattr(module, api_name), api_prefix, hook_build_func, api_template))
|
|
194
|
+
|
|
134
195
|
def register_all_api(self):
|
|
196
|
+
self.all_api_registered = True
|
|
135
197
|
for framework, api_types in self.api_types.items():
|
|
136
198
|
for api_type, api_modules in api_types.items():
|
|
137
199
|
api_type_with_framework = framework + Const.SEP + api_type
|
|
@@ -143,6 +205,7 @@ class ApiRegistry:
|
|
|
143
205
|
self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_wrapped_attr.get(api_type, {}))
|
|
144
206
|
|
|
145
207
|
def restore_all_api(self):
|
|
208
|
+
self.all_api_registered = False
|
|
146
209
|
for framework, api_types in self.api_types.items():
|
|
147
210
|
for api_type, api_modules in api_types.items():
|
|
148
211
|
api_type_with_framework = framework + Const.SEP + api_type
|
|
@@ -154,7 +217,7 @@ class ApiRegistry:
|
|
|
154
217
|
self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_ori_attr.get(api_type, {}))
|
|
155
218
|
|
|
156
219
|
def initialize_hook(self, hook_build_func):
|
|
157
|
-
api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path)
|
|
220
|
+
api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path, self.backlist)
|
|
158
221
|
wrapped_api_functions = api_wrapper.wrap_api(self.api_templates, hook_build_func)
|
|
159
222
|
|
|
160
223
|
for framework, api_types in self.api_types.items():
|