mindstudio-probe 1.0.1 (mindstudio_probe-1.0.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
- mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
- mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
- mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
- mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
- mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
- msprobe/README.md +182 -0
- msprobe/__init__.py +0 -0
- msprobe/config/README.md +397 -0
- msprobe/config/config.json +28 -0
- msprobe/config/img/free_benchmark.png +0 -0
- msprobe/core/common/const.py +241 -0
- msprobe/core/common/exceptions.py +88 -0
- msprobe/core/common/file_check.py +265 -0
- msprobe/core/common/log.py +55 -0
- msprobe/core/common/utils.py +516 -0
- msprobe/core/common_config.py +58 -0
- msprobe/core/data_dump/data_collector.py +140 -0
- msprobe/core/data_dump/data_processor/base.py +245 -0
- msprobe/core/data_dump/data_processor/factory.py +61 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
- msprobe/core/data_dump/json_writer.py +116 -0
- msprobe/core/data_dump/scope.py +178 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/debugger/__init__.py +0 -0
- msprobe/mindspore/debugger/debugger_config.py +51 -0
- msprobe/mindspore/debugger/precision_debugger.py +32 -0
- msprobe/mindspore/doc/dump.md +65 -0
- msprobe/mindspore/dump/__init__.py +0 -0
- msprobe/mindspore/dump/api_kbk_dump.py +55 -0
- msprobe/mindspore/dump/dump_tool_factory.py +38 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
- msprobe/mindspore/ms_config.py +78 -0
- msprobe/mindspore/overflow_check/__init__.py +0 -0
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
- msprobe/mindspore/task_handler_factory.py +21 -0
- msprobe/msprobe.py +67 -0
- msprobe/pytorch/__init__.py +4 -0
- msprobe/pytorch/advisor/advisor.py +124 -0
- msprobe/pytorch/advisor/advisor_const.py +59 -0
- msprobe/pytorch/advisor/advisor_result.py +58 -0
- msprobe/pytorch/api_accuracy_checker/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
- msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
- msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
- msprobe/pytorch/common/__init__.py +2 -0
- msprobe/pytorch/common/compare_script.template +14 -0
- msprobe/pytorch/common/log.py +32 -0
- msprobe/pytorch/common/parse_json.py +37 -0
- msprobe/pytorch/common/utils.py +224 -0
- msprobe/pytorch/compare/acc_compare.py +1024 -0
- msprobe/pytorch/compare/distributed_compare.py +111 -0
- msprobe/pytorch/compare/highlight.py +100 -0
- msprobe/pytorch/compare/mapping.yaml +607 -0
- msprobe/pytorch/compare/match.py +36 -0
- msprobe/pytorch/compare/npy_compare.py +244 -0
- msprobe/pytorch/debugger/__init__.py +0 -0
- msprobe/pytorch/debugger/debugger_config.py +86 -0
- msprobe/pytorch/debugger/precision_debugger.py +95 -0
- msprobe/pytorch/doc/FAQ.md +193 -0
- msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
- msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
- msprobe/pytorch/doc/dump.md +207 -0
- msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
- msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
- msprobe/pytorch/doc/img/cpu_info.png +0 -0
- msprobe/pytorch/doc/img/module_compare.png +0 -0
- msprobe/pytorch/doc/parse_tool.md +286 -0
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
- msprobe/pytorch/doc/run_overflow_check.md +25 -0
- msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +90 -0
- msprobe/pytorch/free_benchmark/__init__.py +8 -0
- msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/common/constant.py +67 -0
- msprobe/pytorch/free_benchmark/common/counter.py +72 -0
- msprobe/pytorch/free_benchmark/common/enums.py +37 -0
- msprobe/pytorch/free_benchmark/common/params.py +129 -0
- msprobe/pytorch/free_benchmark/common/utils.py +98 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
- msprobe/pytorch/free_benchmark/main.py +102 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
- msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
- msprobe/pytorch/functional/__init__.py +0 -0
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +39 -0
- msprobe/pytorch/hook_module/__init__.py +1 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -0
- msprobe/pytorch/hook_module/hook_module.py +109 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
- msprobe/pytorch/hook_module/utils.py +29 -0
- msprobe/pytorch/hook_module/wrap_aten.py +100 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
- msprobe/pytorch/hook_module/wrap_functional.py +108 -0
- msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
- msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
- msprobe/pytorch/hook_module/wrap_torch.py +88 -0
- msprobe/pytorch/hook_module/wrap_vf.py +64 -0
- msprobe/pytorch/module_processer.py +98 -0
- msprobe/pytorch/online_dispatch/__init__.py +20 -0
- msprobe/pytorch/online_dispatch/compare.py +236 -0
- msprobe/pytorch/online_dispatch/dispatch.py +274 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
- msprobe/pytorch/online_dispatch/single_compare.py +391 -0
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
- msprobe/pytorch/online_dispatch/utils.py +187 -0
- msprobe/pytorch/parse.py +4 -0
- msprobe/pytorch/parse_tool/__init__.py +0 -0
- msprobe/pytorch/parse_tool/cli.py +32 -0
- msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
- msprobe/pytorch/parse_tool/lib/compare.py +259 -0
- msprobe/pytorch/parse_tool/lib/config.py +51 -0
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
- msprobe/pytorch/parse_tool/lib/utils.py +367 -0
- msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
- msprobe/pytorch/pt_config.py +93 -0
- msprobe/pytorch/service.py +167 -0
- msprobe/test/core_ut/common/test_utils.py +345 -0
- msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
- msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
- msprobe/test/core_ut/data_dump/test_scope.py +151 -0
- msprobe/test/core_ut/test_common_config.py +152 -0
- msprobe/test/core_ut/test_file_check.py +218 -0
- msprobe/test/core_ut/test_log.py +109 -0
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
- msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
- msprobe/test/mindspore_ut/test_ms_config.py +69 -0
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
- msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
- msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
- msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
- msprobe/test/pytorch_ut/test_pt_config.py +69 -0
- msprobe/test/pytorch_ut/test_service.py +59 -0
- msprobe/test/resources/advisor.txt +3 -0
- msprobe/test/resources/compare_result_20230703104808.csv +9 -0
- msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
- msprobe/test/resources/config.yaml +3 -0
- msprobe/test/resources/npu_test.pkl +8 -0
- msprobe/test/run_test.sh +30 -0
- msprobe/test/run_ut.py +58 -0
- msprobe/test/test_module_processer.py +64 -0
msprobe/pytorch/module_processer.py
@@ -0,0 +1,98 @@
+from functools import wraps
+import torch
+from torch.utils.hooks import BackwardHook
+from msprobe.core.common.const import Const
+from msprobe.core.data_dump.scope import ModuleRangeScope
+
+
+class ModuleProcesser:
+    module_stack = []
+    api_parent_node = ""
+    module_node = {}
+    current_module_name = ""
+
+    def __init__(self, scope):
+        if isinstance(scope, ModuleRangeScope):
+            self.scope = scope
+        else:
+            self.scope = None
+        BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
+        BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
+        BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook)
+        self.module_count = {}
+
+    @staticmethod
+    def filter_tensor_and_tuple(func):
+        @wraps(func)
+        def wrap_by_filter_tensor_and_tuple(*args, **kwargs):
+            # If setup_output_hook receives non-tensor data, the tool's later dump step fails, so non-tensor data is not passed through
+            # setup_output_hook is defined as setup_output_hook(self, args), so the second positional argument, i.e. args[1], is handled here
+            if not isinstance(args[1], (torch.Tensor, tuple)):
+                return args[1]
+            return func(*args, **kwargs)
+
+        return wrap_by_filter_tensor_and_tuple
+
+    @staticmethod
+    def clone_return_value(func):
+        @wraps(func)
+        def clone_return_value_func(*args, **kwargs):
+            result = func(*args, **kwargs)
+            return ModuleProcesser.clone_if_tensor(result)
+
+        return clone_return_value_func
+
+    @staticmethod
+    def clone_if_tensor(result):
+        if isinstance(result, torch.Tensor):
+            return result.clone()
+        elif isinstance(result, tuple):
+            return tuple(ModuleProcesser.clone_if_tensor(x) for x in result)
+        elif isinstance(result, list):
+            return list(ModuleProcesser.clone_if_tensor(x) for x in result)
+        elif isinstance(result, dict):
+            return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()}
+        else:
+            return result
+
+    def node_hook(self, name_prefix, start_or_stop, **kwargs):
+
+        def pre_hook(module, input, output=None):
+            try:
+                index = self.module_count_func(name_prefix)
+            except IndexError as e:
+                index = None
+                pass
+            module.mindstudio_reserved_name = full_name = name_prefix + Const.SEP + str(index)
+            if self.module_stack:
+                ModuleProcesser.module_node[full_name] = self.module_stack[-1]
+            else:
+                ModuleProcesser.module_node[full_name] = None
+
+            ModuleProcesser.module_stack.append(full_name)
+            if self.module_stack:
+                ModuleProcesser.api_parent_node = self.module_stack[-1]
+            if self.scope:
+                self.scope.begin_module(full_name)
+
+        def end_hook(module, input, output=None):
+            if self.module_stack:
+                ModuleProcesser.module_stack.pop()
+            if self.module_stack:
+                ModuleProcesser.api_parent_node = self.module_stack[-1]
+            else:
+                ModuleProcesser.api_parent_node = None
+            if self.scope:
+                self.scope.end_module(module.mindstudio_reserved_name)
+
+        if Const.START in start_or_stop:
+            return pre_hook
+        else:
+            return end_hook
+
+    def module_count_func(self, module_name):
+        if module_name not in self.module_count:
+            self.module_count[module_name] = 0
+        else:
+            self.module_count[module_name] += 1
+        return self.module_count[module_name]
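The class above only defines the hooks; the registration logic lives elsewhere in the package (e.g. msprobe/pytorch/service.py, which this excerpt does not show). Below is a minimal sketch of how the hooks could be wired onto a module, assuming `Const.SEP` is `"."` and that the strings `"start"`/`"stop"` satisfy the `Const.START in start_or_stop` check; both are assumptions, not confirmed by this diff.

```python
# Hypothetical wiring of ModuleProcesser's node hooks onto an nn.Module.
# The name prefix and the "start"/"stop" strings are assumptions; the real
# registration code is part of msprobe's service layer, not shown here.
import torch
from msprobe.pytorch.module_processer import ModuleProcesser

processer = ModuleProcesser(scope=None)  # a ModuleRangeScope instance may be passed instead
layer = torch.nn.Linear(4, 4)
prefix = "Module.Linear.forward"

# pre_hook pushes the module onto the stack, end_hook pops it again
layer.register_forward_pre_hook(processer.node_hook(prefix, "start"))
layer.register_forward_hook(processer.node_hook(prefix, "stop"))

layer(torch.randn(2, 4))
print(ModuleProcesser.module_node)       # e.g. {'Module.Linear.forward.0': None} if Const.SEP is "."
print(ModuleProcesser.api_parent_node)   # None again once the module call has finished
```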
msprobe/pytorch/online_dispatch/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2024-2024 Huawei Technologies Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from signal import signal, SIGPIPE, SIG_DFL
+from .dispatch import PtdbgDispatch
+signal(SIGPIPE, SIG_DFL)
+
+
+__all__ = ["PtdbgDispatch"]
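The `signal(SIGPIPE, SIG_DFL)` call restores the default SIGPIPE disposition: with Python's usual handling, a write to a closed pipe raises `BrokenPipeError`, while SIG_DFL lets the process exit quietly when, for example, its output is piped into `head`. A generic POSIX illustration of the effect, not specific to msprobe:

```python
# With SIG_DFL, `python demo.py | head -n 1` terminates this script silently
# once `head` closes the read end of the pipe, instead of printing a
# BrokenPipeError traceback.
from signal import SIG_DFL, SIGPIPE, signal

signal(SIGPIPE, SIG_DFL)

for i in range(100_000):
    print(i)
```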
msprobe/pytorch/online_dispatch/compare.py
@@ -0,0 +1,236 @@
+# Perform the comparison and present the results
+import os
+import sys
+import csv
+import json
+from collections import namedtuple
+from rich.table import Table
+from rich.console import Console
+from .single_compare import single_benchmark_compare_wrap
+from .utils import DispatchException
+from msprobe.core.common.const import CompareConst
+from msprobe.core.common.file_check import FileOpen
+from msprobe.pytorch.common.log import logger
+from msprobe.core.common.utils import CompareException
+
+ELEMENT_NUM_THRESHOLD = 100
+ZERO_NUM_THRESHOLD = 0.1
+FLOAT_PRECISION = 14
+
+ResultInfo = namedtuple('ResultInfo', ['api_name', 'is_fwd_success', 'is_bwd_success',
+                                       'fwd_compare_alg_results', 'bwd_compare_alg_results'])
+
+def get_file_content_bytes(file):
+    with FileOpen(file, 'rb') as file_handle:
+        return file_handle.read()
+
+
+def get_json_contents(file_path):
+    ops = get_file_content_bytes(file_path)
+    try:
+        json_obj = json.loads(ops)
+    except ValueError as error:
+        logger.error('Failed to load "%s". %s' % (file_path, str(error)))
+        raise CompareException(CompareException.INVALID_FILE_ERROR) from error
+    if not isinstance(json_obj, dict):
+        logger.error('Json file %s, content is not a dictionary!' % file_path)
+        raise CompareException(CompareException.INVALID_FILE_ERROR)
+    return json_obj
+
+
+def write_csv(data, filepath):
+    with FileOpen(filepath, 'a', encoding='utf-8-sig') as f:
+        writer = csv.writer(f)
+        writer.writerows(data)
+
+
+class Saver:
+    # consts for result csv
+    COLUMN_API_NAME = "API name"
+    COLUMN_FORWARD_SUCCESS = "Forward Test Success"
+    COLUMN_BACKWARD_SUCCESS = "Backward Test Success"
+    COLUMN_STACK_INFO = "Traceback callstack info"
+
+    def __init__(self, save_path, detail_save_path, stack_info):
+        self.save_path = save_path
+        self.detail_save_path = detail_save_path
+        self.stack_info = stack_info
+
+        self.test_result_cnt = {
+            "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0,
+            "total_num": 0, "forward_or_backward_fail_num": 0
+        }
+
+    def write_csv_title(self):
+        summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS, "Message"]]
+        write_csv(summary_test_rows, self.save_path)
+
+        detail_test_rows = [[
+            "Npu Name", "Bench Dtype", "NPU Dtype", "Shape",
+            "error_balance", "max_abs_diff", "max_abs_idx",
+            "max_rel_diff", "max_rel_idx", "eb_thd",
+            "error_thd", "Status", "Message"
+        ]]
+        write_csv(detail_test_rows, self.detail_save_path)
+
+    def print_pretest_result(self):
+        self.get_statistics_from_result_csv()
+        if self.test_result_cnt.get("total_num") != 0:
+            passing_rate = str(self.test_result_cnt.get("success_num") /
+                               (self.test_result_cnt.get("total_num") + sys.float_info.epsilon))
+        else:
+            passing_rate = "0"
+
+        console = Console()
+        table_total = Table(
+            show_header=True, title="Overall Statistics", show_lines=True, width=75
+        )
+        table_total.add_column("Result")
+        table_total.add_column("Statistics")
+        table_total.add_row("[green]Pass[/green]", str(self.test_result_cnt.get("success_num")))
+        table_total.add_row("[red]Fail[/red]", str(self.test_result_cnt.get("forward_and_backward_fail_num") +
+                                                   self.test_result_cnt.get("forward_or_backward_fail_num")))
+        table_total.add_row("Passing Rate", passing_rate)
+
+        table_detail = Table(
+            show_header=True, title="Detail Statistics", show_lines=True, width=75
+        )
+        table_detail.add_column("Result")
+        table_detail.add_column("Statistics")
+        table_detail.add_row("Only Forward Fail", str(self.test_result_cnt.get("forward_fail_num")))
+        table_detail.add_row("Only Backward Fail", str(self.test_result_cnt.get("backward_fail_num")))
+        table_detail.add_row(
+            "Both Forward & Backward Fail", str(self.test_result_cnt.get("forward_and_backward_fail_num")))
+
+        console.print(table_total)
+        console.print(table_detail)
+
+    def get_statistics_from_result_csv(self):
+        checklist = [CompareConst.TRUE, CompareConst.FALSE, CompareConst.NA, CompareConst.SKIP]
+        with FileOpen(self.save_path, 'r') as file:
+            reader = csv.reader(file)
+            result_csv_rows = [row for row in reader]
+        result_csv_name = os.path.basename(self.save_path)
+        for item in result_csv_rows[1:]:
+            if not isinstance(item, list) or len(item) < 3:
+                raise ValueError("The number of columns in %s is incorrect" % result_csv_name)
+            if not all(item[i] and item[i].upper() in checklist for i in (1, 2)):
+                raise ValueError(
+                    "The value in the 2nd or 3rd column of %s is wrong, it must be TRUE, FALSE, SKIP or N/A"
+                    % result_csv_name)
+            column1 = item[1].upper()
+            column2 = item[2].upper()
+            if column1 == CompareConst.SKIP:
+                continue
+            self.test_result_cnt["total_num"] += 1
+            if column1 == CompareConst.TRUE and column2 in [CompareConst.TRUE, 'N/A']:
+                self.test_result_cnt['success_num'] += 1
+            elif column1 == CompareConst.FALSE and column2 == CompareConst.FALSE:
+                self.test_result_cnt['forward_and_backward_fail_num'] += 1
+            elif column1 == CompareConst.FALSE:
+                self.test_result_cnt['forward_fail_num'] += 1
+                self.test_result_cnt['forward_or_backward_fail_num'] += 1
+            else:
+                self.test_result_cnt['backward_fail_num'] += 1
+                self.test_result_cnt['forward_or_backward_fail_num'] += 1
+
+    def write_summary_csv(self, test_result):
+        test_rows = []
+        if self.stack_info:
+            test_rows[0].append(self.COLUMN_STACK_INFO)
+
+        name = test_result.api_name
+        df_row = [test_result.api_name, test_result.is_fwd_success, test_result.is_bwd_success]
+        if test_result.is_fwd_success == "SKIP" or test_result.is_bwd_success == "SKIP":
+            df_row.append(test_result.fwd_compare_alg_results)
+        if self.stack_info:
+            stack_info = "\n".join(self.stack_info[name])
+            df_row.append(stack_info)
+        test_rows.append(df_row)
+        write_csv(test_rows, self.save_path)
+
+    def write_detail_csv(self, test_result):
+        def get_rows_from_list(result, name, sub_prefix):
+            rows = []
+            if isinstance(result, list):
+                for i, test_subject in enumerate(result):
+                    subject = sub_prefix + "." + name + ".output." + str(i)
+                    test_subject = ["{:.{}f}".format(item, FLOAT_PRECISION) if isinstance(item, float) else item for
+                                    item in test_subject]
+                    rows.append([subject] + list(test_subject))
+            return rows
+
+        test_rows = []
+        subject_prefix = test_result.api_name
+        fwd_result = test_result.fwd_compare_alg_results
+        bwd_result = test_result.bwd_compare_alg_results
+
+        test_rows.extend(get_rows_from_list(fwd_result, "forward", subject_prefix))
+        test_rows.extend(get_rows_from_list(bwd_result, "backward", subject_prefix))
+
+        write_csv(test_rows, self.detail_save_path)
+
+    def record_results(self, result_info):
+        self.write_summary_csv(result_info)
+        self.write_detail_csv(result_info)
+
+
+class Comparator:
+
+    def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None):
+        self.save_path = result_csv_path
+        self.detail_save_path = details_csv_path
+        if stack_info_json_path:
+            self.stack_info = get_json_contents(stack_info_json_path)
+        else:
+            self.stack_info = None
+        self.saver = Saver(result_csv_path, details_csv_path, self.stack_info)
+
+        if is_continue_run_ut and not os.path.exists(self.save_path) and not os.path.exists(self.detail_save_path):
+            self.saver.write_csv_title()
+
+    @staticmethod
+    def _compare_core_wrapper(bench_out, npu_out):
+        detailed_result_total = []
+        test_final_success = True
+        status, details = single_benchmark_compare_wrap(npu_out, bench_out)
+        if not isinstance(status, list):
+            detailed_result_total.append(details)
+            test_final_success = status
+        else:
+            for item, item_status in enumerate(status):
+                detailed_result_total.append(details.get(item, 'key does not exist'))
+                if not item_status:
+                    test_final_success = False
+        return test_final_success, detailed_result_total
+
+    @staticmethod
+    def _compare_dropout(bench_out, npu_out):
+        tensor_num = bench_out.numel()
+        if tensor_num >= ELEMENT_NUM_THRESHOLD:
+            if abs((bench_out == 0).sum() - (npu_out == 0).cpu().sum()) / tensor_num < ZERO_NUM_THRESHOLD:
+                return True, 1
+            else:
+                return False, 0
+        else:
+            return True, 1
+
+    def compare_output(self, api_name, bench_out, npu_out, bench_grad=None, npu_grad=None):
+        if "dropout" in api_name:
+            is_fwd_success, fwd_compare_alg_results = self._compare_dropout(bench_out, npu_out)
+        else:
+            is_fwd_success, fwd_compare_alg_results = self._compare_core_wrapper(bench_out, npu_out)
+        if bench_grad and npu_grad:
+            if "dropout" in api_name:
+                is_bwd_success, bwd_compare_alg_results = self._compare_dropout(bench_grad[0], npu_grad[0])
+            else:
+                is_bwd_success, bwd_compare_alg_results = self._compare_core_wrapper(bench_grad, npu_grad)
+        else:
+            is_bwd_success, bwd_compare_alg_results = True, None
+        if is_bwd_success and bwd_compare_alg_results is None:
+            self.saver.record_results(ResultInfo(api_name, is_fwd_success, CompareConst.NA, fwd_compare_alg_results,
+                                                 bwd_compare_alg_results))
+        else:
+            self.saver.record_results(ResultInfo(api_name, is_fwd_success, is_bwd_success, fwd_compare_alg_results,
+                                                 bwd_compare_alg_results))
+        return is_fwd_success, is_bwd_success
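Most of `Comparator` depends on `single_benchmark_compare_wrap` and the CSV writers above, but the rule in `_compare_dropout` is self-contained: for outputs with at least `ELEMENT_NUM_THRESHOLD` elements, the result is accepted when the zero-element ratios of the benchmark and NPU outputs differ by less than `ZERO_NUM_THRESHOLD` (10%). A small sketch of that rule, assuming the published wheel is installed so the module imports cleanly:

```python
# Sketch of the dropout comparison rule defined above; dropout outputs are
# random, so they are compared by zero-ratio rather than element-wise.
import torch
from msprobe.pytorch.online_dispatch.compare import Comparator

bench_out = torch.nn.functional.dropout(torch.ones(1000), p=0.5)
npu_out = torch.nn.functional.dropout(torch.ones(1000), p=0.5)  # stand-in for the NPU result

is_ok, score = Comparator._compare_dropout(bench_out, npu_out)
print(is_ok, score)  # True, 1 whenever the share of zeros agrees within 10%
```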
msprobe/pytorch/online_dispatch/dispatch.py
@@ -0,0 +1,274 @@
+import os
+import time
+import json
+from pathlib import Path
+from multiprocessing import Manager, Pool
+
+import yaml
+import torch
+
+from torch.utils._python_dispatch import TorchDispatchMode
+
+try:
+    import torch_npu
+except ImportError:
+    is_npu = False
+else:
+    is_npu = True
+
+from .dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \
+    DispatchRunParam, DisPatchDataInfo
+from .utils import get_callstack, data_to_cpu, logger_debug, logger_error, logger_warn, logger_logo, get_sys_info, \
+    DispatchException
+from .compare import Comparator
+from msprobe.core.common.file_check import FileOpen
+from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create
+from msprobe.core.common.const import Const, CompareConst
+
+current_time = time.strftime("%Y%m%d%H%M%S")
+RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
+DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
+
+
+class PtdbgDispatch(TorchDispatchMode):
+    def __init__(self, dump_mode=Const.OFF, api_list=None, debug=False, dump_path=None, tag=None, process_num=0):
+        super(PtdbgDispatch, self).__init__()
+        logger_logo()
+        if not is_npu:
+            logger_error("Please confirm you run environment installed torch_npu!")
+            return
+        if dump_path is None:
+            logger_error("Please set dump_path when dump_mode is config!")
+        check_file_or_directory_path(dump_path, True)
+
+        self.device_id = torch_npu._C._npu_getDevice()
+        self.dump_mode = dump_mode
+        self.dump_api_list = api_list
+        self.debug_flag = debug
+        self.api_index = 0
+        self.single_api_index_dict = {}
+        self.device_dump_path_cpu = None
+        self.device_dump_path_npu = None
+        self.all_summery = []
+        self.call_stack_list = []
+        self.process_num = process_num
+        self.filter_dump_api()
+        self.check_param()
+        dir_name = self.get_dir_name(tag)
+        self.root_path = os.path.join(os.path.realpath(dump_path), dir_name)
+        self.root_cpu_path = os.path.join(self.root_path, f'cpu')
+        self.root_npu_path = os.path.join(self.root_path, f'npu')
+        check_path_before_create(self.root_cpu_path)
+        check_path_before_create(self.root_npu_path)
+        Path(self.root_cpu_path).mkdir(mode=0o750, parents=True, exist_ok=True)
+        Path(self.root_npu_path).mkdir(mode=0o750, parents=True, exist_ok=True)
+
+        self.result_csv_path = os.path.join(self.root_path, RESULT_FILE_NAME)
+        self.detail_csv_path = os.path.join(self.root_path, DETAILS_FILE_NAME)
+        self.comparator = Comparator(self.result_csv_path, self.detail_csv_path, False)
+
+        self.aten_ops_blacklist = []
+        self.npu_adjust_autogard = []
+        yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
+        self.load_yaml_file(yaml_path)
+
+        self.lock = None
+        if process_num > 0:
+            self.pool = Pool(process_num)
+        if debug:
+            logger_debug(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} '
+                         f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], '
+                         f'process[{process_num}]')
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        super().__exit__(exc_type, exc_val, exc_tb)
+
+        if not is_npu:
+            return
+        logger_debug(f'start write compare csv: Rank[{self.device_id}], Pid[{os.getpid()}')
+
+        if self.process_num > 0:
+            self.pool.close()
+            self.pool.join()
+            summery_path = os.path.join(self.root_cpu_path, f'summary.json')
+            if not os.path.exists(summery_path):
+                logger_error("Please check train log, An exception may have occurred!")
+                return
+            check_file_or_directory_path(summery_path, False)
+            fp_handle = open(summery_path, "r")
+            while True:
+                json_line_data = fp_handle.readline()
+                if json_line_data == '\n':
+                    continue
+                if len(json_line_data) == 0:
+                    break
+                msg = json.loads(json_line_data)
+                self.all_summery[msg[0]] = msg[1]
+            fp_handle.close()
+
+        if self.debug_flag:
+            input_num = 0
+            output_num = 0
+            total_num = 0
+
+            for list_data in self.all_summery:
+                for data in list_data:
+                    logger_debug(f'summery: Device[{self.device_id}], Pid[{os.getpid()}], Data[{data}]')
+                    if "_input" in data[CompareConst.NPU_NAME]:
+                        input_num = input_num + 1
+                    if "_output" in data[CompareConst.NPU_NAME]:
+                        output_num = output_num + 1
+                    total_num = total_num + 1
+            logger_debug(f'Dispatch exit: Device[{self.device_id}], Pid[{os.getpid()} Input[{input_num}] '
+                         f'Output[{output_num}] Total[{total_num}] API_Total[{self.api_index}]]')
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        if not is_npu:
+            logger_error("Please confirm you run environment installed torch_npu!")
+            return func(*args, **kwargs)
+
+        func_name_split_list = func.__name__.split(".")
+        aten_api = func_name_split_list[0]
+        try:
+            aten_api_overload_name = func_name_split_list[1]
+        except IndexError:
+            logger_error(f"Please check the func name {func.__name__}!")
+            return func(*args, **kwargs)
+
+        self.enable_autogard(aten_api)
+        if aten_api in self.aten_ops_blacklist:
+            npu_out = func(*args, **kwargs)
+            return npu_out
+
+        call_stack = get_callstack()
+        self.call_stack_list.append(call_stack)
+        self.api_index += 1
+        if aten_api not in self.single_api_index_dict:
+            self.single_api_index_dict[aten_api] = 1
+        else:
+            self.single_api_index_dict[aten_api] += 1
+
+        run_param = self.get_run_param(aten_api, func.__name__, aten_api_overload_name)
+
+        if self.debug_flag:
+            logger_debug(f'Dispatch Info: Rank[{self.device_id}], Pid[{os.getpid()}], Func[{func.__name__}], '
+                         f'Name[{run_param.aten_api}_{run_param.single_api_index}], '
+                         f'Count[{self.api_index}], Sys[{get_sys_info()}]')
+
+        cpu_args = []
+        cpu_kwargs = []
+        data_to_cpu(args, 0, cpu_args)
+        data_to_cpu(kwargs, 0, cpu_kwargs)
+        cpu_args = cpu_args[0]
+        cpu_kwargs = cpu_kwargs[0]
+
+        with TimeStatistics("NPU RUN", run_param):
+            npu_out = func(*args, **kwargs)
+        npu_out_cpu = []
+        data_to_cpu(npu_out, 0, npu_out_cpu)
+        npu_out_cpu = npu_out_cpu[0]
+
+        with TimeStatistics("CPU RUN", run_param):
+            cpu_out = func(*cpu_args, **cpu_kwargs)
+
+        if isinstance(cpu_out, torch.Tensor) and cpu_out.dtype in [torch.bfloat16, torch.float16, torch.half]:
+            cpu_out = cpu_out.float()
+
+        if self.process_num == 0:
+            self.all_summery.append([])
+            data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summery, func, npu_out_cpu, cpu_out, self.lock)
+            dispatch_workflow(run_param, data_info)
+        else:
+            self.lock.acquire()
+            self.all_summery.append([])
+            self.lock.release()
+            run_param.process_flag = True
+            if self.check_fun(func, run_param):
+                data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summery, None, npu_out_cpu, cpu_out,
+                                             self.lock)
+                self.pool.apply_async(func=dispatch_multiprocess, args=(run_param, data_info),
+                                      error_callback=error_call)
+            else:
+                logger_error("can not get correct function please set process_num=0")
+        return npu_out
+
+    @staticmethod
+    def check_fun(func, run_param):
+        if hasattr(torch.ops.aten, run_param.aten_api):
+            aten_func = getattr(torch.ops.aten, run_param.aten_api)
+            if hasattr(aten_func, run_param.aten_api_overload_name):
+                aten_overload_func = getattr(aten_func, run_param.aten_api_overload_name)
+                if id(aten_overload_func) == id(func):
+                    run_param.func_namespace = "aten"
+                    return True
+        return False
+
+    def get_dir_name(self, tag):
+        # guarantee file uniqueness
+        time.sleep(1)
+        time_now = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
+        if tag is None or not isinstance(tag, str):
+            logger_warn('There is not tag or the type of tag is not string.')
+            dir_name = f'msprobe_rank{self.device_id}_{time_now}'
+        else:
+            dir_name = f'msprobe_{tag}_rank{self.device_id}_{time_now}'
+        return dir_name
+
+    def load_yaml_file(self, file_path):
+        with FileOpen(file_path, 'r') as f:
+            yaml_file = yaml.safe_load(f)
+            self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist')
+            self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard')
+
+    def filter_dump_api(self):
+        if self.dump_mode != Const.LIST or not self.dump_api_list:
+            self.dump_api_list = []
+            return
+        aten_api_list = dir(torch.ops.aten)
+        dump_api_list = []
+        for aten_api in self.dump_api_list:
+            if aten_api in aten_api_list:
+                dump_api_list.append(aten_api)
+            else:
+                logger_warn(f'{aten_api} is not aten api will not dump, please refer to torch.ops.aten')
+        self.dump_api_list = dump_api_list
+
+    def get_run_param(self, aten_api, func_name, aten_api_overload_name):
+        run_param = DispatchRunParam(self.debug_flag, self.device_id, self.root_npu_path, self.root_cpu_path,
+                                     self.process_num, self.comparator)
+        run_param.dump_flag, run_param.auto_dump_flag = self.get_dump_flag(aten_api)
+        run_param.func_name = func_name
+        run_param.aten_api = aten_api
+        run_param.aten_api_overload_name = aten_api_overload_name
+        run_param.single_api_index = self.single_api_index_dict[aten_api]
+        run_param.api_index = self.api_index
+        return run_param
+
+    def get_dump_flag(self, aten_api):
+        dump_flag = False
+        auto_dump_flag = False
+        if self.dump_mode == Const.ALL:
+            dump_flag = True
+        if self.dump_mode == Const.LIST and aten_api in self.dump_api_list:
+            dump_flag = True
+        if self.dump_mode == Const.AUTO:
+            auto_dump_flag = True
+        return dump_flag, auto_dump_flag
+
+    def check_param(self):
+        if self.dump_mode not in Const.ONLINE_DUMP_MODE:
+            logger_error('The parameter "dump mode" can only be one of {}.'.format(Const.ONLINE_DUMP_MODE))
+            raise DispatchException(DispatchException.INVALID_PARAMETER)
+        if not isinstance(self.dump_api_list, list):
+            logger_error('The type of parameter "api_list" can only be list.')
+            raise DispatchException(DispatchException.INVALID_PARAMETER)
+        if not isinstance(self.debug_flag, bool):
+            logger_error('The type of parameter "debug" can only be bool.')
+            raise DispatchException(DispatchException.INVALID_PARAMETER)
+        if not isinstance(self.process_num, int) or self.process_num < 0:
+            logger_error('The type of parameter "process_num" can only be int and it should not be less than 0.')
+            raise DispatchException(DispatchException.INVALID_PARAMETER)
+
+    def enable_autogard(self, aten_api):
+        if aten_api in self.npu_adjust_autogard:
+            torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False)
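Taken together, `PtdbgDispatch` is used as a `TorchDispatchMode` context manager around NPU code: each aten call is replayed on the CPU and both results are handed to the `Comparator` above. A hedged usage sketch follows; it assumes an Ascend host with `torch_npu` installed and that `Const.ALL` resolves to the string `"all"`, neither of which this excerpt confirms.

```python
# Hypothetical end-to-end usage of PtdbgDispatch. Requires torch_npu; the
# "all" dump-mode string is an assumption based on the Const.ALL reference above.
import torch
import torch_npu  # noqa: F401  (assumed to be available on an Ascend host)
from msprobe.pytorch.online_dispatch import PtdbgDispatch

model = torch.nn.Linear(8, 8).npu()
data = torch.randn(4, 8).npu()

with PtdbgDispatch(dump_mode="all", dump_path="./msprobe_dump", tag="demo", process_num=0):
    loss = model(data).sum()
    loss.backward()

# Results land under ./msprobe_dump/msprobe_demo_rank<id>_<timestamp>/ as
# accuracy_checking_result_<time>.csv and accuracy_checking_details_<time>.csv.
```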