mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
msprobe/pytorch/module_processer.py:

```diff
@@ -1,15 +1,18 @@
 from functools import wraps
+
 import torch
 from torch.utils.hooks import BackwardHook
+
 from msprobe.core.common.const import Const
 from msprobe.core.data_dump.scope import ModuleRangeScope
+torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
 
 class ModuleProcesser:
+    module_count = {}
     module_stack = []
     api_parent_node = ""
     module_node = {}
-    current_module_name = ""
 
     def __init__(self, scope):
         if isinstance(scope, ModuleRangeScope):
@@ -19,15 +22,22 @@ class ModuleProcesser:
         BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
         BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
         BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook)
-        self.module_count = {}
 
     @staticmethod
     def filter_tensor_and_tuple(func):
         @wraps(func)
         def wrap_by_filter_tensor_and_tuple(*args, **kwargs):
-            # setup_output_hook传入非tensor数据,工具后续dump
+            # setup_output_hook传入非tensor数据,工具后续dump会报错,处理方式是解析非tensor数据的属性,对tensor属性挂hook
             # setup_output_hook定义为setup_output_hook(self, args),因此处理第二个位置参数,即*args[1]
             if not isinstance(args[1], (torch.Tensor, tuple)):
+                for item_str in dir(args[1]):
+                    item = getattr(args[1], item_str)
+                    # 处理tensor或者只包含tensor的元组
+                    if isinstance(item, torch.Tensor) or \
+                            (isinstance(item, tuple) and all(isinstance(x, torch.Tensor) for x in item)):
+                        args_new = (args[0], item)
+                        result = func(*args_new, **kwargs)
+                        setattr(args[1], item_str, result)
                 return args[1]
             return func(*args, **kwargs)
 
@@ -55,11 +65,26 @@ class ModuleProcesser:
         else:
             return result
 
+    @staticmethod
+    def module_count_func(module_name):
+        if module_name not in ModuleProcesser.module_count:
+            ModuleProcesser.module_count[module_name] = 0
+        else:
+            ModuleProcesser.module_count[module_name] += 1
+        return ModuleProcesser.module_count[module_name]
+
+    @classmethod
+    def reset_module_stats(cls):
+        cls.module_count = {}
+        cls.module_stack = []
+        cls.api_parent_node = ""
+        cls.module_node = {}
+
     def node_hook(self, name_prefix, start_or_stop, **kwargs):
 
         def pre_hook(module, input, output=None):
             try:
-                index =
+                index = ModuleProcesser.module_count_func(name_prefix)
             except IndexError as e:
                 index = None
                 pass
@@ -85,14 +110,29 @@ class ModuleProcesser:
             if self.scope:
                 self.scope.end_module(module.mindstudio_reserved_name)
 
-
-
-
-
+        def backward_hook(module, input, output=None):
+            try:
+                index = ModuleProcesser.module_count_func(name_prefix)
+            except IndexError as e:
+                index = None
+                pass
+            module.mindstudio_reserved_name = full_name = name_prefix + Const.SEP + str(index)
+            forward_full_name = full_name.replace(Const.BACKWARD, Const.FORWARD)
+            ModuleProcesser.module_node[full_name] = ModuleProcesser.module_node[forward_full_name].replace(
+                Const.FORWARD, Const.BACKWARD) if ModuleProcesser.module_node[forward_full_name] else None
+            ModuleProcesser.api_parent_node = None
+            if self.scope:
+                self.scope.begin_module(full_name)
 
-
-
-
+        if torch_version_above_or_equal_2:
+            if Const.START in start_or_stop:
+                return pre_hook
+            else:
+                return end_hook
         else:
-
-
+            if Const.FORWARD in name_prefix and Const.START in start_or_stop:
+                return pre_hook
+            elif Const.BACKWARD in name_prefix:
+                return backward_hook
+            else:
+                return end_hook
```
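The new `module_count_func` and `reset_module_stats` class methods keep a per-name call counter in class-level state, so repeated hooks on the same module name receive distinct indices (the index is appended to the dump name via `Const.SEP`). A minimal standalone sketch of that counting scheme, with no torch dependency; the class name `NameCounter` and the sample names are hypothetical, used only to illustrate the same logic:

```python
class NameCounter:
    """Per-name call counter mirroring the logic of ModuleProcesser.module_count_func."""
    counts = {}

    @staticmethod
    def next_index(name):
        # First call for a name yields 0; each later call increments by 1.
        if name not in NameCounter.counts:
            NameCounter.counts[name] = 0
        else:
            NameCounter.counts[name] += 1
        return NameCounter.counts[name]

    @classmethod
    def reset(cls):
        # Analogous to reset_module_stats: clear class-level state between runs.
        cls.counts = {}


if __name__ == "__main__":
    print(NameCounter.next_index("Module.Linear.forward"))  # 0
    print(NameCounter.next_index("Module.Linear.forward"))  # 1
    print(NameCounter.next_index("Module.Conv2d.forward"))  # 0
    NameCounter.reset()
    print(NameCounter.next_index("Module.Linear.forward"))  # 0 again
```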
msprobe/pytorch/online_dispatch/compare.py:

```diff
@@ -6,10 +6,9 @@ import json
 from collections import namedtuple
 from rich.table import Table
 from rich.console import Console
+from msprobe.core.common.const import CompareConst, FileCheckConst
+from msprobe.core.common.file_check import FileOpen, change_mode
 from .single_compare import single_benchmark_compare_wrap
-from .utils import DispatchException
-from msprobe.core.common.const import CompareConst
-from msprobe.core.common.file_check import FileOpen
 from msprobe.pytorch.common.log import logger
 from msprobe.core.common.utils import CompareException
 
@@ -42,6 +41,7 @@ def write_csv(data, filepath):
     with FileOpen(filepath, 'a', encoding='utf-8-sig') as f:
         writer = csv.writer(f)
         writer.writerows(data)
+    change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
 
 
 class Saver:
@@ -228,7 +228,7 @@ class Comparator:
         else:
             is_bwd_success, bwd_compare_alg_results = True, None
         if is_bwd_success and bwd_compare_alg_results is None:
-            self.saver.record_results(ResultInfo(api_name, is_fwd_success, CompareConst.
+            self.saver.record_results(ResultInfo(api_name, is_fwd_success, CompareConst.NAN, fwd_compare_alg_results,
                                                  bwd_compare_alg_results))
         else:
             self.saver.record_results(ResultInfo(api_name, is_fwd_success, is_bwd_success, fwd_compare_alg_results,
```
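The behavioural change in `write_csv` is that the result file's permissions are now tightened right after it is written. A rough standalone equivalent using only the standard library; the `0o640` mode and the function name are assumptions for illustration (msprobe's actual `FileCheckConst.DATA_FILE_AUTHORITY` value may differ):

```python
import csv
import os


def write_csv_with_mode(rows, filepath, mode=0o640):
    """Append rows to a CSV file, then restrict its permissions."""
    with open(filepath, "a", encoding="utf-8-sig", newline="") as f:
        csv.writer(f).writerows(rows)
    # Same idea as change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY).
    os.chmod(filepath, mode)


if __name__ == "__main__":
    write_csv_with_mode([["api", "status"], ["aten.add", "pass"]], "demo_result.csv")
```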
msprobe/pytorch/online_dispatch/dispatch.py:

```diff
@@ -4,7 +4,6 @@ import json
 from pathlib import Path
 from multiprocessing import Manager, Pool
 
-import yaml
 import torch
 
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -16,14 +15,14 @@ except ImportError:
 else:
     is_npu = True
 
+from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create, load_yaml
+from msprobe.core.common.const import Const, CompareConst
+from msprobe.pytorch.common.log import logger
 from .dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \
     DispatchRunParam, DisPatchDataInfo
-from .utils import get_callstack, data_to_cpu,
-    DispatchException
+from .utils import get_callstack, data_to_cpu, get_sys_info, DispatchException, COMPARE_LOGO
 from .compare import Comparator
-
-from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create
-from msprobe.core.common.const import Const, CompareConst
+
 
 current_time = time.strftime("%Y%m%d%H%M%S")
 RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
@@ -33,12 +32,12 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
 class PtdbgDispatch(TorchDispatchMode):
     def __init__(self, dump_mode=Const.OFF, api_list=None, debug=False, dump_path=None, tag=None, process_num=0):
         super(PtdbgDispatch, self).__init__()
-
+        logger.info(COMPARE_LOGO)
         if not is_npu:
-
+            logger.error("Please confirm you run environment installed torch_npu!")
             return
         if dump_path is None:
-
+            logger.error("Please set dump_path when dump_mode is config!")
         check_file_or_directory_path(dump_path, True)
 
         self.device_id = torch_npu._C._npu_getDevice()
@@ -49,7 +48,7 @@ class PtdbgDispatch(TorchDispatchMode):
         self.single_api_index_dict = {}
         self.device_dump_path_cpu = None
         self.device_dump_path_npu = None
-        self.
+        self.all_summary = []
         self.call_stack_list = []
         self.process_num = process_num
         self.filter_dump_api()
@@ -70,13 +69,13 @@ class PtdbgDispatch(TorchDispatchMode):
         self.aten_ops_blacklist = []
         self.npu_adjust_autogard = []
         yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
-        self.
+        self.get_ops(yaml_path)
 
         self.lock = None
         if process_num > 0:
             self.pool = Pool(process_num)
         if debug:
-
+            logger.info(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} '
                         f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], '
                         f'process[{process_num}]')
 
@@ -85,17 +84,17 @@ class PtdbgDispatch(TorchDispatchMode):
 
         if not is_npu:
             return
-
+        logger.info(f'start write compare csv: Rank[{self.device_id}], Pid[{os.getpid()}')
 
         if self.process_num > 0:
             self.pool.close()
             self.pool.join()
-
-        if not os.path.exists(
-
+        summary_path = os.path.join(self.root_cpu_path, f'summary.json')
+        if not os.path.exists(summary_path):
+            logger.error("Please check train log, An exception may have occurred!")
             return
-        check_file_or_directory_path(
-        fp_handle = open(
+        check_file_or_directory_path(summary_path, False)
+        fp_handle = open(summary_path, "r")
         while True:
             json_line_data = fp_handle.readline()
             if json_line_data == '\n':
@@ -103,7 +102,7 @@ class PtdbgDispatch(TorchDispatchMode):
             if len(json_line_data) == 0:
                 break
             msg = json.loads(json_line_data)
-            self.
+            self.all_summary[msg[0]] = msg[1]
         fp_handle.close()
 
         if self.debug_flag:
@@ -111,20 +110,20 @@ class PtdbgDispatch(TorchDispatchMode):
             output_num = 0
             total_num = 0
 
-            for list_data in self.
+            for list_data in self.all_summary:
                 for data in list_data:
-
+                    logger.info(f'summary: Device[{self.device_id}], Pid[{os.getpid()}], Data[{data}]')
                     if "_input" in data[CompareConst.NPU_NAME]:
                         input_num = input_num + 1
                     if "_output" in data[CompareConst.NPU_NAME]:
                         output_num = output_num + 1
                     total_num = total_num + 1
-
+            logger.info(f'Dispatch exit: Device[{self.device_id}], Pid[{os.getpid()} Input[{input_num}] '
                         f'Output[{output_num}] Total[{total_num}] API_Total[{self.api_index}]]')
 
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         if not is_npu:
-
+            logger.error("Please confirm you run environment installed torch_npu!")
             return func(*args, **kwargs)
 
         func_name_split_list = func.__name__.split(".")
@@ -132,7 +131,7 @@ class PtdbgDispatch(TorchDispatchMode):
         try:
             aten_api_overload_name = func_name_split_list[1]
         except IndexError:
-
+            logger.error(f"Please check the func name {func.__name__}!")
             return func(*args, **kwargs)
 
         self.enable_autogard(aten_api)
@@ -151,7 +150,7 @@ class PtdbgDispatch(TorchDispatchMode):
         run_param = self.get_run_param(aten_api, func.__name__, aten_api_overload_name)
 
         if self.debug_flag:
-
+            logger.info(f'Dispatch Info: Rank[{self.device_id}], Pid[{os.getpid()}], Func[{func.__name__}], '
                         f'Name[{run_param.aten_api}_{run_param.single_api_index}], '
                         f'Count[{self.api_index}], Sys[{get_sys_info()}]')
 
@@ -175,21 +174,21 @@ class PtdbgDispatch(TorchDispatchMode):
             cpu_out = cpu_out.float()
 
         if self.process_num == 0:
-            self.
-            data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.
+            self.all_summary.append([])
+            data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, func, npu_out_cpu, cpu_out, self.lock)
             dispatch_workflow(run_param, data_info)
         else:
             self.lock.acquire()
-            self.
+            self.all_summary.append([])
             self.lock.release()
             run_param.process_flag = True
            if self.check_fun(func, run_param):
-                data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.
+                data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out,
                                              self.lock)
                 self.pool.apply_async(func=dispatch_multiprocess, args=(run_param, data_info),
                                       error_callback=error_call)
             else:
-
+                logger.error("can not get correct function please set process_num=0")
         return npu_out
 
     @staticmethod
@@ -208,17 +207,16 @@ class PtdbgDispatch(TorchDispatchMode):
             time.sleep(1)
         time_now = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
         if tag is None or not isinstance(tag, str):
-
+            logger.warning('There is not tag or the type of tag is not string.')
             dir_name = f'msprobe_rank{self.device_id}_{time_now}'
         else:
             dir_name = f'msprobe_{tag}_rank{self.device_id}_{time_now}'
         return dir_name
 
-    def
-
-
-
-        self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard')
+    def get_ops(self, file_path):
+        yaml_file = load_yaml(file_path)
+        self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist')
+        self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard')
 
     def filter_dump_api(self):
         if self.dump_mode != Const.LIST or not self.dump_api_list:
@@ -230,7 +228,7 @@ class PtdbgDispatch(TorchDispatchMode):
             if aten_api in aten_api_list:
                 dump_api_list.append(aten_api)
             else:
-
+                logger.warning(f'{aten_api} is not aten api will not dump, please refer to torch.ops.aten')
         self.dump_api_list = dump_api_list
 
     def get_run_param(self, aten_api, func_name, aten_api_overload_name):
@@ -257,16 +255,16 @@ class PtdbgDispatch(TorchDispatchMode):
 
     def check_param(self):
         if self.dump_mode not in Const.ONLINE_DUMP_MODE:
-
+            logger.error('The parameter "dump mode" can only be one of {}.'.format(Const.ONLINE_DUMP_MODE))
             raise DispatchException(DispatchException.INVALID_PARAMETER)
         if not isinstance(self.dump_api_list, list):
-
+            logger.error('The type of parameter "api_list" can only be list.')
             raise DispatchException(DispatchException.INVALID_PARAMETER)
         if not isinstance(self.debug_flag, bool):
-
+            logger.error('The type of parameter "debug" can only be bool.')
             raise DispatchException(DispatchException.INVALID_PARAMETER)
         if not isinstance(self.process_num, int) or self.process_num < 0:
-
+            logger.error('The type of parameter "process_num" can only be int and it should not be less than 0.')
             raise DispatchException(DispatchException.INVALID_PARAMETER)
 
     def enable_autogard(self, aten_api):
```
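For context, `PtdbgDispatch` is built on PyTorch's `torch.utils._python_dispatch.TorchDispatchMode`, which intercepts every aten call issued while the mode is active; the 1.0.3 changes above route its messages through the shared msprobe logger and load the op blacklist via `load_yaml`. A minimal, generic sketch of the interception mechanism itself, not msprobe code; the class name `LoggingDispatchMode` is illustrative:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class LoggingDispatchMode(TorchDispatchMode):
    """Print every aten op executed under this mode, then run it unchanged."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"dispatched: {func.__name__}")
        return func(*args, **kwargs)


if __name__ == "__main__":
    with LoggingDispatchMode():
        x = torch.ones(2, 2)
        y = x + x  # each underlying aten op is reported before it runs
    print(y.sum())
```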
msprobe/pytorch/online_dispatch/dump_compare.py:

```diff
@@ -5,11 +5,10 @@ from datetime import datetime, timezone
 
 import pandas as pd
 import torch
-from .utils import np_save_data, logger_debug, logger_error, logger_warn, logger_user, COLOR_RED, COLOR_GREEN, \
-    COLOR_RESET, CSV_COLUMN_NAME
-from msprobe.core.common.file_check import FileOpen, change_mode
-from msprobe.core.common.const import CompareConst, FileCheckConst, Const
 from msprobe.pytorch.common.log import logger
+from msprobe.core.common.file_check import FileOpen
+from .utils import np_save_data
+
 
 class DispatchRunParam:
     def __init__(self, debug_flag, device_id, root_npu_path, root_cpu_path, process_num, comparator):
@@ -32,10 +31,10 @@ class DispatchRunParam:
 
 
 class DisPatchDataInfo:
-    def __init__(self, cpu_args, cpu_kwargs,
+    def __init__(self, cpu_args, cpu_kwargs, all_summary, func, npu_out_cpu, cpu_out, lock):
         self.cpu_args = cpu_args
         self.cpu_kwargs = cpu_kwargs
-        self.
+        self.all_summary = all_summary
         self.func = func
         self.npu_out_cpu = npu_out_cpu
         self.cpu_out = cpu_out
@@ -57,7 +56,7 @@ class TimeStatistics:
     def __enter__(self):
         if self.debug:
             self.time = datetime.now(tz=timezone.utc)
-
+            logger.info(f'Time[{self.tag}]-ENTER: Dev[{self.device}], Pid[{os.getpid()}], Fun[{self.fun}], ' \
                         f'Id[{self.index}]')
 
     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -68,9 +67,9 @@ class TimeStatistics:
             hot_time_cost = "Hotspot " + time_cost
 
             if cost_time.total_seconds() > self.timeout:
-
+                logger.info(hot_time_cost)
             else:
-
+                logger.info(time_cost)
 
 
 def support_basic_type(data):
@@ -87,24 +86,24 @@ def dump_data(data, prefix, dump_path):
     elif support_basic_type(data):
         if isinstance(data, torch.Tensor) and data.is_meta:
             return
-        # dump data may greater than
+        # dump data may greater than summary_list collect
         np_save_data(data, prefix, dump_path)
 
 
-def
-
+def save_temp_summary(api_index, single_api_summary, path, lock):
+    summary_path = os.path.join(path, f'summary.json')
     lock.acquire()
-    with FileOpen(
-        json.dump([api_index,
+    with FileOpen(summary_path, "a") as f:
+        json.dump([api_index, single_api_summary], f)
         f.write('\n')
     lock.release()
 
 
 def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo):
     cpu_args, cpu_kwargs = data_info.cpu_args, data_info.cpu_kwargs
-
+    all_summary, func = data_info.all_summary, data_info.func
     npu_out_cpu, cpu_out, lock = data_info.npu_out_cpu, data_info.cpu_out, data_info.lock
-
+    single_api_summary = []
 
     prefix_input = f'{run_param.aten_api}_{run_param.single_api_index}_input'
     prefix_output = f'{run_param.aten_api}_{run_param.single_api_index}_output'
@@ -127,9 +126,9 @@ def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo):
         dump_data(npu_out_cpu, prefix_output, run_param.root_npu_path)
 
     if run_param.process_num == 0:
-
+        all_summary[run_param.api_index - 1] = copy.deepcopy(single_api_summary)
     else:
-
+        save_temp_summary(run_param.api_index - 1, single_api_summary, run_param.root_cpu_path, lock)
 
 
 def get_torch_func(run_param):
@@ -155,32 +154,3 @@ def dispatch_multiprocess(run_param, dispatch_data_info):
 def error_call(err):
     logger.error(f'multiprocess {err}')
 
-
-def save_csv(all_summery, call_stack_list, csv_path):
-    df = pd.DataFrame(columns=CSV_COLUMN_NAME)
-
-    for index, list_data in enumerate(all_summery):
-        for data in list_data:
-            csv_row_data = {CompareConst.NPU_NAME: data[CompareConst.NPU_NAME],
-                            CompareConst.BENCH_NAME: data[CompareConst.BENCH_NAME],
-                            CompareConst.NPU_DTYPE: data[CompareConst.NPU_DTYPE],
-                            CompareConst.BENCH_DTYPE: data[CompareConst.BENCH_DTYPE],
-                            CompareConst.NPU_SHAPE: data[CompareConst.NPU_SHAPE],
-                            CompareConst.BENCH_SHAPE: data[CompareConst.BENCH_SHAPE],
-                            CompareConst.NPU_MAX: data[CompareConst.NPU_MAX],
-                            CompareConst.NPU_MIN: data[CompareConst.NPU_MIN],
-                            CompareConst.NPU_MEAN: data[CompareConst.NPU_MEAN],
-                            CompareConst.BENCH_MAX: data[CompareConst.BENCH_MAX],
-                            CompareConst.BENCH_MIN: data[CompareConst.BENCH_MIN],
-                            CompareConst.BENCH_MEAN: data[CompareConst.BENCH_MEAN],
-                            CompareConst.COSINE: data[CompareConst.COSINE],
-                            CompareConst.MAX_ABS_ERR: data[CompareConst.MAX_ABS_ERR],
-                            CompareConst.MAX_RELATIVE_ERR: data[CompareConst.MAX_RELATIVE_ERR],
-                            CompareConst.ACCURACY: data[CompareConst.ACCURACY],
-                            CompareConst.STACK: call_stack_list[index],
-                            CompareConst.ERROR_MESSAGE: data[CompareConst.ERROR_MESSAGE]}
-            row_df = pd.DataFrame.from_dict(csv_row_data, orient='index').T
-            df = pd.concat([df, row_df])
-
-    df.to_csv(csv_path, index=False)
-    change_mode(csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
```
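The new `save_temp_summary`, together with the reader added in `PtdbgDispatch`, replaces the removed `save_csv` path with a simple JSON-lines protocol: each worker appends one `[api_index, summary]` line under a lock, and the parent later reads the lines back into a list indexed by `api_index`. A self-contained sketch of that round trip; the file name, dummy data, and helper names are illustrative, and the lock is omitted because the sketch runs single-threaded:

```python
import json
import os


def append_summary(path, api_index, summary):
    # One JSON array per line: [index, per-API summary].
    with open(path, "a", encoding="utf-8") as f:
        json.dump([api_index, summary], f)
        f.write("\n")


def load_summaries(path, total):
    all_summary = [[] for _ in range(total)]
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            msg = json.loads(line)
            all_summary[msg[0]] = msg[1]  # same shape as self.all_summary[msg[0]] = msg[1]
    return all_summary


if __name__ == "__main__":
    tmp = "summary_demo.json"
    if os.path.exists(tmp):
        os.remove(tmp)
    append_summary(tmp, 0, [{"op": "aten.add", "max_abs_err": 0.0}])
    append_summary(tmp, 1, [{"op": "aten.mul", "max_abs_err": 1e-6}])
    print(load_summaries(tmp, 2))
```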
msprobe/pytorch/online_dispatch/single_compare.py:

```diff
@@ -3,15 +3,15 @@ from functools import wraps
 import torch
 from prettytable import PrettyTable
 from collections import namedtuple
-from .
+from msprobe.pytorch.common.log import logger
 
 def func_log_wrapper():
     def _out_wrapper(func):
         @wraps(func)
         def _in_wrapper(*kargs, **kwargs):
-
+            logger.info(f"start to run: {func.__name__}")
             x = func(*kargs, **kwargs)
-
+            logger.info(f"end to run: {func.__name__}")
             return x
 
         return _in_wrapper
@@ -165,7 +165,7 @@ class SingleBenchmarkAccuracyCompare:
     def compute_binary_diff(cls, npu_out, bench_out):
         result = torch.equal(npu_out, bench_out)
         if result:
-
+            logger.info("二进制精度比对通过, 无需单标杆比对法验证")
         return SingleBenchmarkAccuracyResult(result=result, max_abs_diff=0, max_rel_diff=0, error_balance=0)
 
     @classmethod
@@ -301,7 +301,7 @@ class SingleBenchSummary:
         table.add_row(["max_rel_diff", self.max_rel_diff, self.error_thd])
         table.add_row(["max_rel_idx", self.max_rel_idx, "-"])
 
-
+        logger.info(table)
 
     def to_column_value(self):
         return [self.bench_dtype, self.npu_dtype, self.shape, self.error_balance,
```
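`func_log_wrapper` is a plain entry/exit logging decorator; the change only routes its output through the shared msprobe logger. A generic sketch of the same pattern using the standard `logging` module; the names `log_calls` and `add` are illustrative, not msprobe APIs:

```python
import logging
from functools import wraps

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def log_calls(func):
    """Log entry and exit of the wrapped function, mirroring func_log_wrapper."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        logger.info("start to run: %s", func.__name__)
        result = func(*args, **kwargs)
        logger.info("end to run: %s", func.__name__)
        return result
    return wrapper


@log_calls
def add(a, b):
    return a + b


if __name__ == "__main__":
    print(add(1, 2))
```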
msprobe/pytorch/online_dispatch/utils.py:

```diff
@@ -1,6 +1,5 @@
 import os
 import inspect
-import logging
 import psutil
 import torch
 import numpy as np
@@ -14,6 +13,7 @@ else:
 
 from msprobe.core.common.const import CompareConst, FileCheckConst
 from msprobe.core.common.file_check import change_mode
+from msprobe.core.common.log import logger
 
 cpu_device = torch._C.device("cpu")
 COLOR_RED = '\033[31m'
@@ -77,7 +77,7 @@ def np_save_data(data, file_name, data_path):
         np.save(dump_path, data)
         change_mode(dump_path, FileCheckConst.DATA_FILE_AUTHORITY)
     except Exception as e:
-
+        logger.error("save numpy failed, error: {}".format(e))
     finally:
         pass
 
@@ -124,47 +124,6 @@ def data_to_cpu(data, deep, data_cpu):
     return data
 
 
-def get_mp_logger():
-    logger = logging.getLogger(__name__)
-    if not logger.handlers:
-        logger.setLevel(logging.INFO)
-        handler = logging.StreamHandler()
-        formatter = logging.Formatter('%(asctime)s %(message)s')
-        logger.propagate = True
-        handler.setFormatter(formatter)
-        logger.addHandler(handler)
-    return logger.info
-
-
-def logger_debug(mesg):
-    logger = get_mp_logger()
-    logger(f'DEBUG ' + mesg)
-
-
-def logger_info(mesg):
-    logger = get_mp_logger()
-    logger(f'INFO ' + mesg)
-
-
-def logger_warn(mesg):
-    logger = get_mp_logger()
-    logger(f'{COLOR_YELLOW}WARNING {mesg} {COLOR_RESET}')
-
-
-def logger_error(mesg):
-    logger = get_mp_logger()
-    logger(f'{COLOR_RED}ERROR {mesg} {COLOR_RESET}')
-
-
-def logger_user(mesg):
-    logger = get_mp_logger()
-    logger(mesg)
-
-
-def logger_logo():
-    logger_user(f'{COLOR_CYAN}{COMPARE_LOGO} {COLOR_RESET}')
-
-
 def get_sys_info():
     mem = psutil.virtual_memory()
     cpu_percent = psutil.cpu_percent(interval=1)
```