mindstudio-probe 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
- mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
- mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
- mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
- mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
- mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
- msprobe/README.md +182 -0
- msprobe/__init__.py +0 -0
- msprobe/config/README.md +397 -0
- msprobe/config/config.json +28 -0
- msprobe/config/img/free_benchmark.png +0 -0
- msprobe/core/common/const.py +241 -0
- msprobe/core/common/exceptions.py +88 -0
- msprobe/core/common/file_check.py +265 -0
- msprobe/core/common/log.py +55 -0
- msprobe/core/common/utils.py +516 -0
- msprobe/core/common_config.py +58 -0
- msprobe/core/data_dump/data_collector.py +140 -0
- msprobe/core/data_dump/data_processor/base.py +245 -0
- msprobe/core/data_dump/data_processor/factory.py +61 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
- msprobe/core/data_dump/json_writer.py +116 -0
- msprobe/core/data_dump/scope.py +178 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/debugger/__init__.py +0 -0
- msprobe/mindspore/debugger/debugger_config.py +51 -0
- msprobe/mindspore/debugger/precision_debugger.py +32 -0
- msprobe/mindspore/doc/dump.md +65 -0
- msprobe/mindspore/dump/__init__.py +0 -0
- msprobe/mindspore/dump/api_kbk_dump.py +55 -0
- msprobe/mindspore/dump/dump_tool_factory.py +38 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
- msprobe/mindspore/ms_config.py +78 -0
- msprobe/mindspore/overflow_check/__init__.py +0 -0
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
- msprobe/mindspore/task_handler_factory.py +21 -0
- msprobe/msprobe.py +67 -0
- msprobe/pytorch/__init__.py +4 -0
- msprobe/pytorch/advisor/advisor.py +124 -0
- msprobe/pytorch/advisor/advisor_const.py +59 -0
- msprobe/pytorch/advisor/advisor_result.py +58 -0
- msprobe/pytorch/api_accuracy_checker/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
- msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
- msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
- msprobe/pytorch/common/__init__.py +2 -0
- msprobe/pytorch/common/compare_script.template +14 -0
- msprobe/pytorch/common/log.py +32 -0
- msprobe/pytorch/common/parse_json.py +37 -0
- msprobe/pytorch/common/utils.py +224 -0
- msprobe/pytorch/compare/acc_compare.py +1024 -0
- msprobe/pytorch/compare/distributed_compare.py +111 -0
- msprobe/pytorch/compare/highlight.py +100 -0
- msprobe/pytorch/compare/mapping.yaml +607 -0
- msprobe/pytorch/compare/match.py +36 -0
- msprobe/pytorch/compare/npy_compare.py +244 -0
- msprobe/pytorch/debugger/__init__.py +0 -0
- msprobe/pytorch/debugger/debugger_config.py +86 -0
- msprobe/pytorch/debugger/precision_debugger.py +95 -0
- msprobe/pytorch/doc/FAQ.md +193 -0
- msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
- msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
- msprobe/pytorch/doc/dump.md +207 -0
- msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
- msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
- msprobe/pytorch/doc/img/cpu_info.png +0 -0
- msprobe/pytorch/doc/img/module_compare.png +0 -0
- msprobe/pytorch/doc/parse_tool.md +286 -0
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
- msprobe/pytorch/doc/run_overflow_check.md +25 -0
- msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +90 -0
- msprobe/pytorch/free_benchmark/__init__.py +8 -0
- msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/common/constant.py +67 -0
- msprobe/pytorch/free_benchmark/common/counter.py +72 -0
- msprobe/pytorch/free_benchmark/common/enums.py +37 -0
- msprobe/pytorch/free_benchmark/common/params.py +129 -0
- msprobe/pytorch/free_benchmark/common/utils.py +98 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
- msprobe/pytorch/free_benchmark/main.py +102 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
- msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
- msprobe/pytorch/functional/__init__.py +0 -0
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +39 -0
- msprobe/pytorch/hook_module/__init__.py +1 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -0
- msprobe/pytorch/hook_module/hook_module.py +109 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
- msprobe/pytorch/hook_module/utils.py +29 -0
- msprobe/pytorch/hook_module/wrap_aten.py +100 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
- msprobe/pytorch/hook_module/wrap_functional.py +108 -0
- msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
- msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
- msprobe/pytorch/hook_module/wrap_torch.py +88 -0
- msprobe/pytorch/hook_module/wrap_vf.py +64 -0
- msprobe/pytorch/module_processer.py +98 -0
- msprobe/pytorch/online_dispatch/__init__.py +20 -0
- msprobe/pytorch/online_dispatch/compare.py +236 -0
- msprobe/pytorch/online_dispatch/dispatch.py +274 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
- msprobe/pytorch/online_dispatch/single_compare.py +391 -0
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
- msprobe/pytorch/online_dispatch/utils.py +187 -0
- msprobe/pytorch/parse.py +4 -0
- msprobe/pytorch/parse_tool/__init__.py +0 -0
- msprobe/pytorch/parse_tool/cli.py +32 -0
- msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
- msprobe/pytorch/parse_tool/lib/compare.py +259 -0
- msprobe/pytorch/parse_tool/lib/config.py +51 -0
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
- msprobe/pytorch/parse_tool/lib/utils.py +367 -0
- msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
- msprobe/pytorch/pt_config.py +93 -0
- msprobe/pytorch/service.py +167 -0
- msprobe/test/core_ut/common/test_utils.py +345 -0
- msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
- msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
- msprobe/test/core_ut/data_dump/test_scope.py +151 -0
- msprobe/test/core_ut/test_common_config.py +152 -0
- msprobe/test/core_ut/test_file_check.py +218 -0
- msprobe/test/core_ut/test_log.py +109 -0
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
- msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
- msprobe/test/mindspore_ut/test_ms_config.py +69 -0
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
- msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
- msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
- msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
- msprobe/test/pytorch_ut/test_pt_config.py +69 -0
- msprobe/test/pytorch_ut/test_service.py +59 -0
- msprobe/test/resources/advisor.txt +3 -0
- msprobe/test/resources/compare_result_20230703104808.csv +9 -0
- msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
- msprobe/test/resources/config.yaml +3 -0
- msprobe/test/resources/npu_test.pkl +8 -0
- msprobe/test/run_test.sh +30 -0
- msprobe/test/run_ut.py +58 -0
- msprobe/test/test_module_processer.py +64 -0
|
@@ -0,0 +1,346 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import zlib
|
|
3
|
+
from dataclasses import asdict
|
|
4
|
+
from typing import List
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
from msprobe.core.common.exceptions import MsaccException
|
|
9
|
+
from msprobe.core.common.file_check import path_len_exceeds_limit, change_mode
|
|
10
|
+
from msprobe.core.common.log import logger
|
|
11
|
+
from msprobe.core.common.const import Const, OverflowConst, FileCheckConst
|
|
12
|
+
from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
|
|
13
|
+
ModuleForwardInputsOutputs, TensorStatInfo
|
|
14
|
+
from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
|
|
15
|
+
|
|
16
|
+
try:
|
|
17
|
+
import torch_npu
|
|
18
|
+
except ImportError:
|
|
19
|
+
pass
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class PytorchDataProcessor(BaseDataProcessor):
    """Describe PyTorch call arguments (tensors, devices, dtypes, sizes) as plain dicts."""

    # Leaf types handled by this processor in addition to the base class's special types.
    pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor)

    def __init__(self, config, data_writer):
        """Set up the base processor and the keyword-name -> analyzer dispatch table."""
        super().__init__(config, data_writer)
        # Values found under these keyword names bypass the generic type checks.
        self.torch_object_key = dict(
            device=self.analyze_device_in_kwargs,
            dtype=self.analyze_dtype_in_kwargs,
        )

    @staticmethod
    def get_md5_for_tensor(x):
        """Return the CRC32 checksum of the tensor's raw bytes as 8 lowercase hex digits.

        Despite the name, the checksum is ``zlib.crc32``, not MD5. bfloat16 tensors
        are converted with ``.float()`` before the numpy conversion (presumably
        because numpy cannot represent bfloat16 -- TODO confirm).
        """
        source = x.float() if x.dtype == torch.bfloat16 else x
        raw_bytes = source.cpu().detach().numpy().tobytes()
        return f"{zlib.crc32(raw_bytes):08x}"

    @staticmethod
    def analyze_device_in_kwargs(element):
        """Describe a ``device`` keyword value as ``{"type": "torch.device", "value": ...}``.

        Strings are stored unchanged. Other objects are rendered as
        ``"<type>:<index>"`` when they expose an ``index`` attribute, otherwise
        as ``"<type>"``.
        """
        if isinstance(element, str):
            value = element
        elif hasattr(element, "index"):
            value = element.type + ":" + str(element.index)
        else:
            value = element.type
        return {'type': "torch.device", "value": value}

    @staticmethod
    def analyze_dtype_in_kwargs(element):
        """Describe a ``dtype`` keyword value using its string representation."""
        dtype_name = str(element)
        return {"type": "torch.dtype", "value": dtype_name}

    @staticmethod
    def get_stat_info(data):
        """Compute max/min/mean/norm statistics for a tensor.

        Returns a ``TensorStatInfo``. It is left at its defaults for meta tensors
        and for tensors with no elements. Bool tensors only get ``max`` (any True)
        and ``min`` (no False). Zero-dimensional tensors use their single value
        for all four statistics.
        """
        stats = TensorStatInfo()
        if data.is_meta:
            return stats

        detached = data.detach()
        if detached.numel() == 0:
            return stats

        if detached.dtype == torch.bool:
            stats.max = True in detached
            stats.min = False not in detached
            return stats

        if not detached.shape:
            scalar = detached.item()
            stats.max = stats.min = stats.mean = stats.norm = scalar
            return stats

        # Integer and float64 tensors are converted to float32 before reducing.
        if not (detached.is_floating_point() and detached.dtype != torch.float64):
            detached = detached.float()
        reducers = torch._C._VariableFunctionsClass
        stats.max = reducers.max(detached).item()
        stats.min = reducers.min(detached).item()
        stats.mean = reducers.mean(detached).item()
        stats.norm = reducers.norm(detached).item()
        return stats

    @staticmethod
    def _analyze_torch_size(arg):
        """Describe a ``torch.Size`` as a plain list of its dimensions."""
        dims = list(arg)
        return {"type": "torch.Size", "value": dims}

    @classmethod
    def get_special_types(cls):
        """Return the base class's special types followed by the PyTorch-specific ones."""
        inherited = super().get_special_types()
        return inherited + cls.pytorch_special_type

    def analyze_single_element(self, element, suffix_stack):
        """Describe one leaf value; return ``None`` when its type is not handled.

        ``suffix_stack`` is the path of keys/indices leading to ``element``. When
        its last entry is a key of ``torch_object_key``, that analyzer takes
        precedence over all type-based checks.
        """
        if suffix_stack:
            last_key = suffix_stack[-1]
            if last_key in self.torch_object_key:
                return self.torch_object_key[last_key](element)

        if isinstance(element, torch.Size):
            return self._analyze_torch_size(element)

        # The helper returns the element itself when no numpy conversion applied.
        converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
        if converted_numpy is not element:
            return self._analyze_numpy(converted_numpy, numpy_type)

        if isinstance(element, torch.Tensor):
            return self._analyze_tensor(element, Const.SEP.join(suffix_stack))

        if isinstance(element, (bool, int, float, str, slice)):
            return self._analyze_builtin(element)

        return None

    def analyze_element(self, element):
        """Apply ``analyze_single_element`` over ``element`` via ``recursive_apply_transform``."""
        return self.recursive_apply_transform(element, self.analyze_single_element)

    def _analyze_tensor(self, tensor, suffix):
        """Build the summary dict for a tensor.

        ``suffix`` is unused here; subclasses use it to derive a dump file name.
        An ``md5`` entry is added only when ``config.summary_mode`` is ``"md5"``.
        """
        stats = self.get_stat_info(tensor)
        summary = {
            'type': 'torch.Tensor',
            'dtype': str(tensor.dtype),
            "shape": tensor.shape,
            "Max": stats.max,
            "Min": stats.min,
            "Mean": stats.mean,
            "Norm": stats.norm,
            "requires_grad": tensor.requires_grad,
        }
        if self.config.summary_mode == "md5":
            summary["md5"] = self.get_md5_for_tensor(tensor)
        return summary
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class StatisticsDataProcessor(PytorchDataProcessor):
    """Processor that adds nothing to ``PytorchDataProcessor``; it only provides a distinct name."""
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class TensorDataProcessor(PytorchDataProcessor):
    """Processor that saves every analyzed tensor to disk as well as summarizing it."""

    def _analyze_tensor(self, tensor, suffix):
        """Save ``tensor`` to its dump file and return its summary with a ``data_name`` entry.

        The save is skipped, with a warning, when the target path is too long.
        The summary is produced either way.
        """
        dump_data_name, file_path = self.get_save_file_path(suffix)
        if path_len_exceeds_limit(file_path):
            logger.warning(f'The file path {file_path} length exceeds limit.')
        else:
            torch.save(tensor, file_path)
            change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)

        tensor_info = super()._analyze_tensor(tensor, suffix)
        tensor_info["data_name"] = dump_data_name
        return tensor_info
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class OverflowCheckDataProcessor(PytorchDataProcessor):
    """Processor that reports an API's data only when an overflow was detected.

    Tensors seen while analyzing one forward/backward call are cached together
    with their target file paths. They are written to disk only if that call was
    flagged as overflowing.
    """

    __slots__ = ["cached_tensors_and_file_paths"]

    def __init__(self, config, data_writer):
        """Initialise the tensor cache and the overflow counters.

        ``config.overflow_num`` is the number of overflowing calls after which
        ``inc_and_check_overflow_times`` raises; ``-1`` disables the limit.
        """
        super().__init__(config, data_writer)
        self.cached_tensors_and_file_paths = {}
        self.real_overflow_dump_times = 0
        self.overflow_nums = config.overflow_num
        # Length of the float-status buffer passed to the torch_npu status calls.
        self.bits_for_overflow = 8
        # Set here so the flag exists even before the first analyze_* call.
        self.has_overflow = False

    @staticmethod
    def overflow_debug_mode_enable():
        """Return True when the overflow debug mode environment variable is enabled."""
        overflow_mode = os.getenv(OverflowConst.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE)
        return overflow_mode == Const.ENV_ENABLE

    @staticmethod
    def handle_tensor_extremum_nan_inf(data_clone, operator):
        """Return the max or min of ``data_clone`` with NaN/Inf values set aside.

        Args:
            data_clone: tensor to reduce.
            operator: ``'max'`` selects the maximum; any other value selects the minimum.

        Returns:
            ``nan`` if every element is NaN. Otherwise the extremum of the finite
            elements, or of the non-NaN elements when none are finite.
        """
        funcs = torch._C._VariableFunctionsClass
        reduce = funcs.max if operator == 'max' else funcs.min

        data_nan = funcs.isnan(data_clone)
        if int(funcs.sum(data_nan)) == data_clone.numel():
            return float('nan')

        finite_mask = funcs.isfinite(data_clone)
        if int(funcs.sum(finite_mask)) > 0:
            return reduce(data_clone[finite_mask]).item()
        # No finite values: only +/-Inf remains once NaNs are removed.
        return reduce(data_clone[~data_nan]).item()

    def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        """Analyze a forward call; return its info only if an overflow was detected, else None."""
        self.has_overflow = False
        api_info_struct = super().analyze_forward(name, module, module_input_output)
        self.maybe_save_overflow_data_and_check_overflow_times()
        return api_info_struct if self.has_overflow else None

    def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
        """Analyze a backward call; return its info only if an overflow was detected, else None."""
        self.has_overflow = False
        api_info_struct = super().analyze_backward(name, module, module_input_output)
        self.maybe_save_overflow_data_and_check_overflow_times()
        return api_info_struct if self.has_overflow else None

    def maybe_save_overflow_data_and_check_overflow_times(self):
        """Write the cached tensors if the current call overflowed, then drop the cache.

        Raises:
            MsaccException: via ``inc_and_check_overflow_times`` once the configured
                overflow limit is reached. The cache is not cleared in that case.
        """
        if self.has_overflow:
            for file_path, tensor in self.cached_tensors_and_file_paths.items():
                torch.save(tensor, file_path)
                change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
            self.inc_and_check_overflow_times()
        self.cached_tensors_and_file_paths = {}

    def inc_and_check_overflow_times(self):
        """Count one overflowing call and raise once the configured limit is reached.

        Raises:
            MsaccException: when ``overflow_nums`` is not ``-1`` and the count
                has reached it.
        """
        self.real_overflow_dump_times += 1
        if self.overflow_nums == -1:
            return
        if self.real_overflow_dump_times >= self.overflow_nums:
            raise MsaccException(MsaccException.OVERFLOW_NUMS_ERROR, str(self.real_overflow_dump_times))

    def check_overflow_npu(self):
        """Return whether the NPU currently reports an overflow.

        In overflow debug mode the float status is read through
        ``torch_npu.npu_get_float_status``; otherwise the result of
        ``torch_npu._C._check_overflow_npu()`` is returned as is.
        """
        # Fixed: this previously called the misspelled ``overflow_debug_mode_enalbe``,
        # which does not exist and raised AttributeError on every call.
        if self.overflow_debug_mode_enable():
            float_status = torch.zeros(self.bits_for_overflow).npu()
            result = torch_npu.npu_get_float_status(float_status, OverflowConst.OVERFLOW_DEBUG_MODE)
            return bool(result.cpu()[0] != 0)
        return torch_npu._C._check_overflow_npu()

    def clear_overflow_npu(self):
        """Reset the NPU overflow status, using the call that matches the active mode."""
        if self.overflow_debug_mode_enable():
            float_status = torch.zeros(self.bits_for_overflow).npu()
            torch_npu.npu_clear_float_status(float_status, OverflowConst.OVERFLOW_DEBUG_MODE)
        else:
            torch_npu._C._clear_overflow_npu()

    def _analyze_maybe_overflow_tensor(self, tensor_json, tensor):
        """Update ``has_overflow`` for one tensor and annotate its summary.

        When torch_npu reports Inf/NaN support, overflow is inferred from the
        summary's ``Max``/``Min``. An Inf/NaN extremum sets ``has_overflow`` and
        adds a ``*_except_inf_nan`` entry to ``tensor_json``. Otherwise the NPU
        overflow status is queried, and cleared when it is set.
        """
        data_clone = tensor.detach()
        if hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan():
            if tensor_json['Max'] is None:
                # No statistics were computed for this tensor.
                return
            if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']):
                tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "max")
                self.has_overflow = True
            if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']):
                tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "min")
                self.has_overflow = True
        else:
            self.has_overflow = self.check_overflow_npu()
            if self.has_overflow:
                self.clear_overflow_npu()

    def _analyze_tensor(self, tensor, suffix):
        """Cache ``tensor`` for a possible later dump and return its overflow-annotated summary."""
        dump_data_name, file_path = self.get_save_file_path(suffix)
        if not path_len_exceeds_limit(file_path):
            # Saving is deferred until the whole call is known to have overflowed.
            self.cached_tensors_and_file_paths.update({file_path: tensor})
        else:
            logger.warning(f'The file path {file_path} length exceeds limit.')
        single_arg = super()._analyze_tensor(tensor, suffix)
        self._analyze_maybe_overflow_tensor(single_arg, tensor)
        single_arg.update({"data_name": dump_data_name})
        return single_arg
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
class FreeBenchmarkDataProcessor(PytorchDataProcessor):
    """Processor that passes each hook event to a ``FreeBenchmarkCheck`` instance."""

    def __init__(self, config, data_writer):
        """Create the checker and the fields that hold a replacement forward output."""
        super().__init__(config, data_writer)
        self.checker = FreeBenchmarkCheck(config=config)
        # Both are filled in by analyze_forward when the checker reports if_fix().
        self._return_forward_new_output = None
        self._forward_new_output = None

    def update_iter(self, current_iter):
        """Propagate the current iteration to the base class and to the checker."""
        super().update_iter(current_iter)
        self.checker.update_iter(current_iter)

    def update_unequal_rows(self, unequal_rows: List[UnequalRow]):
        """Write each unequal row to the free-benchmark CSV; do nothing for an empty input."""
        if not unequal_rows:
            return
        for unequal_row in unequal_rows:
            row_fields = asdict(unequal_row)
            self.data_writer.write_data_to_csv(
                row_fields.values(),
                row_fields.keys(),
                self.data_writer.free_benchmark_file_path,
            )

    def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        """Forward the pre-forward event and its inputs to the checker."""
        self.checker.pre_forward(name, module, self, module_input_output.args, module_input_output.kwargs)

    def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        """Run the checker on a forward call, record unequal rows, and keep a fixed output."""
        new_output, unequal_rows = self.checker.forward(
            name,
            module,
            module_input_output.args,
            module_input_output.kwargs,
            module_input_output.output,
        )
        self.update_unequal_rows(unequal_rows)
        if not self.checker.if_fix():
            return
        self._return_forward_new_output = True
        self._forward_new_output = new_output

    def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
        """Forward the backward event and its grad_output to the checker."""
        self.checker.backward(name, module, module_input_output.grad_output)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class KernelDumpDataProcessor(PytorchDataProcessor):
    """Processor that re-runs an API with the torch_npu dump enabled."""

    # Class-wide flag: True while a dump is in progress, so a nested analysis is skipped.
    forward_init_status = False
    # APIs with several outputs for which backward is driven through output[0].
    multi_output_apis = ["_sort_", "npu_flash_attention"]

    def __init__(self, config, data_writer):
        """Initialise through the base processor; no extra state is added."""
        super().__init__(config, data_writer)

    def analyze_forward(self, name, module, module_input_output):
        """Choose the forward or backward dump according to ``config.is_forward_acl_dump``."""
        if self.config.is_forward_acl_dump:
            self.forward_acl_dump(name, module, module_input_output)
            return
        self.dump_mode_backward_acl_dump(name, module, module_input_output)

    def forward_acl_dump(self, name, module, module_input_output):
        """Re-run the module's forward with the dump enabled, unless a dump is already running."""
        if KernelDumpDataProcessor.forward_init_status:
            return
        KernelDumpDataProcessor.forward_init_status = True

        npu = torch_npu.npu
        npu.synchronize()
        npu.init_dump()
        npu.set_dump(self.config.acl_config)
        npu.synchronize()

        needs_trigger = self.op_need_trigger(name)
        forward_result = module.forward(*module_input_output.args, **module_input_output.kwargs)
        if needs_trigger:
            # For these ops the result is additionally moved to the CPU.
            forward_result.cpu()

        npu.synchronize()
        npu.finalize_dump()
        npu.synchronize()

        KernelDumpDataProcessor.forward_init_status = False
        logger.info("Dump %s op file." % name)

    def acl_backward_dump_status(self, output, grad, module_name):
        """Run backward on ``output`` with ``grad``; return False when it cannot be done.

        A tensor output is used directly. For the names listed in
        ``multi_output_apis`` the first output element is used instead.
        """
        if isinstance(output, torch.Tensor):
            output.backward(grad, retain_graph=True)
            return True

        if any(api_name in module_name for api_name in KernelDumpDataProcessor.multi_output_apis):
            output[0].backward(grad, retain_graph=True)
            return True
        return False

    def dump_mode_backward_acl_dump(self, name, module, module_input_output):
        """Re-run forward, then run backward with the dump enabled and a gradient loaded from disk.

        The gradient file path comes from ``config.backward_input[name]``.
        """
        grad_path = self.config.backward_input.get(name)
        if KernelDumpDataProcessor.forward_init_status:
            return
        KernelDumpDataProcessor.forward_init_status = True

        output = module.forward(*module_input_output.args, **module_input_output.kwargs)
        grad = torch.load(grad_path).to("npu").requires_grad_()

        npu = torch_npu.npu
        npu.init_dump()
        npu.set_dump(self.config.acl_config)
        npu.synchronize()

        backward_ran = self.acl_backward_dump_status(output, grad, name)
        if not backward_ran:
            logger.warning("The output of {} is not of tensor type and cannot be automatically derived. "
                           "you can manually construct a single API backward case for ACL dump.".format(
                               name))

        npu.synchronize()
        npu.finalize_dump()
        KernelDumpDataProcessor.forward_init_status = False
        logger.info("Dump %s op file." % name)

    def op_need_trigger(self, module_name):
        """Return True when ``module_name`` contains ``'Tensor.__getitem__.'``."""
        return 'Tensor.__getitem__.' in module_name
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import csv
|
|
3
|
+
import fcntl
|
|
4
|
+
import json
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from msprobe.core.common.file_check import change_mode
|
|
8
|
+
from msprobe.core.common.log import logger
|
|
9
|
+
from msprobe.core.common.const import Const, FileCheckConst
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DataWriter:
|
|
13
|
+
|
|
14
|
+
def __init__(self, init_json=None) -> None:
    """Create a writer with empty caches and no output paths configured."""
    self.dump_count = 0
    self.init_json = init_json

    # Output locations; all stay None until they are assigned later.
    self.dump_file_path = None
    self.stack_file_path = None
    self.construct_file_path = None
    self.free_benchmark_file_path = None
    self.dump_tensor_data_dir = None

    # Entry count at which flush_data_when_buffer_is_full writes the data cache.
    self.buffer_size = 1000

    # In-memory caches for the dump, stack and construct data.
    self.cache_data = {Const.DATA: {}}
    self.cache_stack = {}
    self.cache_construct = {}
|
|
26
|
+
|
|
27
|
+
@staticmethod
def write_data_to_csv(result: list, result_header: tuple, file_path: str):
    """Append one row to a CSV file, writing the header first if the file is new.

    Nothing is written when ``result`` is empty.
    """
    if not result:
        return

    file_existed = os.path.exists(file_path)
    open_mode = "a+" if file_existed else "w+"
    fd = os.open(file_path, Const.WRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY)
    with os.fdopen(fd, open_mode, newline="") as csv_file:
        row_writer = csv.writer(csv_file)
        if not file_existed:
            row_writer.writerow(result_header)
        row_writer.writerow(result)
|
|
40
|
+
|
|
41
|
+
def initialize_json_file(self, **kwargs):
|
|
42
|
+
kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
|
|
43
|
+
with os.fdopen(
|
|
44
|
+
os.open(self.dump_file_path, Const.OVERWRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), 'w'
|
|
45
|
+
) as f:
|
|
46
|
+
json.dump(kwargs, f)
|
|
47
|
+
|
|
48
|
+
if os.path.exists(self.stack_file_path):
|
|
49
|
+
os.remove(self.stack_file_path)
|
|
50
|
+
Path(self.stack_file_path).touch()
|
|
51
|
+
change_mode(self.stack_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
52
|
+
|
|
53
|
+
if os.path.exists(self.construct_file_path):
|
|
54
|
+
os.remove(self.construct_file_path)
|
|
55
|
+
Path(self.construct_file_path).touch()
|
|
56
|
+
change_mode(self.construct_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
57
|
+
|
|
58
|
+
def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir,
|
|
59
|
+
free_benchmark_file_path):
|
|
60
|
+
self.dump_file_path = dump_file_path
|
|
61
|
+
self.stack_file_path = stack_file_path
|
|
62
|
+
self.construct_file_path = construct_file_path
|
|
63
|
+
self.dump_tensor_data_dir = dump_data_dir
|
|
64
|
+
self.free_benchmark_file_path = free_benchmark_file_path
|
|
65
|
+
|
|
66
|
+
def update_data(self, new_data):
|
|
67
|
+
key = next(iter(new_data.keys())) # assert len(new_data.keys()) == 1
|
|
68
|
+
if key in self.cache_data[Const.DATA]:
|
|
69
|
+
self.cache_data[Const.DATA][key].update(new_data[key])
|
|
70
|
+
else:
|
|
71
|
+
self.cache_data[Const.DATA].update(new_data)
|
|
72
|
+
|
|
73
|
+
def flush_data_when_buffer_is_full(self):
|
|
74
|
+
if len(self.cache_data[Const.DATA]) >= self.buffer_size:
|
|
75
|
+
self.write_data_json(self.dump_file_path)
|
|
76
|
+
|
|
77
|
+
def update_stack(self, new_data):
|
|
78
|
+
self.cache_stack.update(new_data)
|
|
79
|
+
|
|
80
|
+
def update_construct(self, new_data):
|
|
81
|
+
self.cache_construct.update(new_data)
|
|
82
|
+
|
|
83
|
+
def write_data_json(self, file_path):
|
|
84
|
+
logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
|
|
85
|
+
if Path(file_path).exists() and os.path.getsize(file_path) > 0:
|
|
86
|
+
with open(file_path, "r+") as f:
|
|
87
|
+
fcntl.flock(f, fcntl.LOCK_EX)
|
|
88
|
+
data_to_write = json.load(f)
|
|
89
|
+
fcntl.flock(f, fcntl.LOCK_UN)
|
|
90
|
+
else:
|
|
91
|
+
self.init_json['data_path'] = self.dump_tensor_data_dir
|
|
92
|
+
data_to_write = self.init_json
|
|
93
|
+
data_to_write[Const.DATA].update(self.cache_data[Const.DATA])
|
|
94
|
+
with open(file_path, 'w+') as f:
|
|
95
|
+
fcntl.flock(f, fcntl.LOCK_EX)
|
|
96
|
+
json.dump(data_to_write, f, indent=1)
|
|
97
|
+
fcntl.flock(f, fcntl.LOCK_UN)
|
|
98
|
+
|
|
99
|
+
self.cache_data[Const.DATA].clear()
|
|
100
|
+
|
|
101
|
+
def write_stack_info_json(self, file_path):
|
|
102
|
+
with open(file_path, 'w+') as f:
|
|
103
|
+
fcntl.flock(f, fcntl.LOCK_EX)
|
|
104
|
+
json.dump(self.cache_stack, f, indent=1)
|
|
105
|
+
fcntl.flock(f, fcntl.LOCK_UN)
|
|
106
|
+
|
|
107
|
+
def write_construct_info_json(self, file_path):
|
|
108
|
+
with open(file_path, 'w+') as f:
|
|
109
|
+
fcntl.flock(f, fcntl.LOCK_EX)
|
|
110
|
+
json.dump(self.cache_construct, f, indent=1)
|
|
111
|
+
fcntl.flock(f, fcntl.LOCK_UN)
|
|
112
|
+
|
|
113
|
+
def write_json(self):
|
|
114
|
+
self.write_data_json(self.dump_file_path)
|
|
115
|
+
self.write_stack_info_json(self.stack_file_path)
|
|
116
|
+
self.write_construct_info_json(self.construct_file_path)
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from msprobe.core.common.exceptions import ScopeException
|
|
3
|
+
from msprobe.core.common.const import Const
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def build_scope(scope_class, scope=None, api_list=None):
    """Create a scope object from ``scope`` and ``api_list``.

    Returns None when both arguments are empty. Otherwise ``scope_class`` is
    instantiated when given; without it the range-scope type is chosen from
    the scope names.
    """
    if not (scope or api_list):
        return None
    # Only None is replaced; other falsy values are passed through unchanged.
    scope = [] if scope is None else scope
    api_list = [] if api_list is None else api_list
    if not scope_class:
        return build_range_scope_according_to_scope_name(scope, api_list)
    return scope_class(scope, api_list)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def build_range_scope_according_to_scope_name(scope, api_list):
    """Return the APIRangeScope or ModuleRangeScope whose validity check accepts ``scope``.

    Raises ScopeException when both or neither of the two candidates is valid.
    """
    api_candidate = APIRangeScope(scope, api_list)
    module_candidate = ModuleRangeScope(scope, api_list)
    if not scope:
        # Without a scope argument either kind of scope behaves the same.
        return api_candidate

    api_ok = api_candidate.is_valid
    module_ok = module_candidate.is_valid
    if api_ok and module_ok:
        raise ScopeException(ScopeException.InvalidScope, f"scope={scope}.")
    if api_ok:
        return api_candidate
    if module_ok:
        return module_candidate
    raise ScopeException(ScopeException.InvalidScope, f"scope={scope}")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class BaseScope(ABC):
    """Common base for scope filters: validates arguments and matches names against ``api_list``."""

    Module_Type_Module = "Module"
    Module_Type_API = "api"

    def __init__(self, scope, api_list):
        self.scope, self.api_list = self.rectify_args(scope, api_list)

    @staticmethod
    def rectify_args(scope, api_list):
        """Validate ``scope`` and ``api_list`` and return them as a ``(scope, api_list)`` pair.

        ``api_list`` must be a list of strings. ``scope`` must be a string,
        which is wrapped in a one-element list, or a list of strings.
        Raises ScopeException on any other type.
        """
        if not isinstance(api_list, list):
            raise ScopeException(ScopeException.InvalidApiStr,
                                 f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
        for entry in api_list:
            if isinstance(entry, str):
                continue
            raise ScopeException(ScopeException.InvalidApiStr,
                                 f"api_list中的元素须配置为字符串,实际类型为{type(entry)}.")

        if isinstance(scope, str):
            return [scope], api_list
        if not isinstance(scope, list):
            raise ScopeException(ScopeException.InvalidScope,
                                 f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
        for item in scope:
            if isinstance(item, str):
                continue
            raise ScopeException(ScopeException.InvalidScope,
                                 f"scope列表元素要求类型为字符串,实际类型为{type(item)}.")
        return scope, api_list

    @abstractmethod
    def check(self, name):
        pass

    def check_api_list(self, api_name):
        """Return True when ``api_list`` is empty or one of its entries is a substring of ``api_name``."""
        if not self.api_list:
            return True
        return any(fragment in api_name for fragment in self.api_list)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class ListScope(BaseScope):
    """Scope filter that accepts names listed in ``scope``; ``scope`` and ``api_list`` are mutually exclusive."""

    @staticmethod
    def rectify_args(scope, api_list):
        """Reject a configuration that sets both arguments, then apply the base validation."""
        both_configured = scope and api_list
        if both_configured:
            raise ScopeException(ScopeException.ArgConflict,
                                 f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
        return super(ListScope, ListScope).rectify_args(scope, api_list)

    def check(self, module_name):
        """Return False for names outside a non-empty ``scope``; otherwise defer to ``check_api_list``."""
        if self.scope and module_name not in self.scope:
            return False
        return self.check_api_list(module_name)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class RangeScope(BaseScope, ABC):
    """Base for scopes described by a start/stop pair of names.

    ``in_scope`` tracks whether execution is currently inside the range and
    ``is_valid`` stores the result of the subclass's ``check_scope_is_valid``.
    """

    def __init__(self, *args):
        super().__init__(*args)
        self.in_scope = False
        self.is_valid = self.check_scope_is_valid()

    @staticmethod
    def rectify_args(scope, api_list):
        """Apply the base validation and normalize ``scope`` to a start/stop pair.

        A one-element scope is expanded so that start and stop are the same
        name. Raises ScopeException when more than two entries are given.
        """
        scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
        if isinstance(scope, list):
            if len(scope) == 1:
                # Build a new list instead of appending, so the caller's list is left untouched.
                scope = [scope[0], scope[0]]
            elif len(scope) > 2:
                raise ScopeException(ScopeException.InvalidScope,
                                     f"scope参数指定区间断点,须传入长度为1或2的列表,实际长度为{len(scope)}.")

        return scope, api_list

    @abstractmethod
    def check_scope_is_valid(self):
        pass

    def begin_module(self, module_name):
        """Hook called when a module starts; a no-op in this base class."""
        pass

    def end_module(self, module_name):
        """Hook called when a module ends; a no-op in this base class."""
        pass
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class APIRangeScope(RangeScope):
    """Range scope whose start and stop names are API names."""

    def check_scope_is_valid(self):
        """Return False when either endpoint's leading name segment is the module type."""
        if not self.scope:
            return True
        module_tag = BaseScope.Module_Type_Module
        for position in (0, 1):
            if self.scope[position].split(Const.SEP)[0] == module_tag:
                return False
        return True

    def check(self, api_name):
        """Return whether ``api_name`` falls inside the range and passes the ``api_list`` filter.

        The range opens on the start name and closes after the stop name, so
        both endpoints are themselves evaluated as inside the range.
        """
        if not self.scope:
            return self.check_api_list(api_name)

        if api_name == self.scope[0]:
            self.in_scope = True

        matched = self.check_api_list(api_name) if self.in_scope else False

        # Close the range only after the stop name itself has been evaluated.
        if api_name == self.scope[1]:
            self.in_scope = False
        return matched
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class ModuleRangeScope(RangeScope):
    """
    Unlike an API, a module has internal sub-structure that also needs to be dumped.
    pre_hook and full_backward_hook are needed to control precisely where a module
    starts and ends; begin_module and end_module are called when those hooks fire
    to control the range.
    """

    def check_scope_is_valid(self):
        """Return True for an empty scope or when both endpoints start with the module type."""
        if not self.scope:
            return True
        start_kind = self.scope[0].split(Const.SEP)[0]
        stop_kind = self.scope[1].split(Const.SEP)[0]
        module_tag = BaseScope.Module_Type_Module
        return start_kind == module_tag and stop_kind == module_tag

    def begin_module(self, module_name):
        """Open the range when ``module_name`` is the configured start module."""
        if self.scope and module_name == self.scope[0]:
            self.in_scope = True

    def end_module(self, module_name):
        """Close the range when ``module_name`` is the configured stop module."""
        if self.scope and module_name == self.scope[1]:
            self.in_scope = False

    def check(self, module_name):
        """Return False outside a configured range; otherwise defer to ``check_api_list``."""
        if self.scope and not self.in_scope:
            return False
        return self.check_api_list(module_name)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
|
|
File without changes
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class DebuggerConfig:
    """Debugger settings assembled from a common config and a task config.

    Missing values are replaced by defaults and the result is validated at
    construction time.
    """

    # Maps the user-facing level name to the internal dump granularity.
    convert_map = {
        "L0": "cell",
        "L1": "api",
        "L2": 'kernel'
    }

    def __init__(self, common_config, task_config):
        """Copy the settings from both configs, apply defaults and run ``check``.

        Note that an unset ``common_config.level`` is overwritten with "L1".
        Raises KeyError when the level is not one of L0, L1 or L2.
        """
        self.dump_path = common_config.dump_path
        self.task = common_config.task
        self.rank = [] if not common_config.rank else common_config.rank
        self.step = [] if not common_config.step else common_config.step
        if not common_config.level:
            common_config.level = "L1"
        try:
            self.level = DebuggerConfig.convert_map[common_config.level]
        except KeyError as err:
            # Keep the KeyError type but say which values are accepted.
            raise KeyError(f"level must be L0, L1 or L2, got {common_config.level!r}") from err
        self.list = [] if not task_config.list else task_config.list
        self.data_mode = [] if not task_config.data_mode else task_config.data_mode
        self.file_format = task_config.file_format
        self.check_mode = task_config.check_mode

        self.check()

    def check(self):
        """Validate the settings and fill in defaults for task, file_format and check_mode.

        Returns True on success. Raises Exception for an empty or relative
        dump path or an empty level, and ValueError for invalid rank or step
        entries.
        """
        if not self.dump_path:
            raise Exception("Dump path is empty.")
        if not os.path.isabs(self.dump_path):
            raise Exception("Dump path must be absolute path.")
        if not self.task:
            self.task = "statistics"
        if not self.level:
            raise Exception("level must be L0, L1 or L2")
        if not self.file_format:
            self.file_format = "npy"
        if not self.check_mode:
            self.check_mode = "all"
        self._check_rank()
        self._check_step()
        return True

    def _check_rank(self):
        """Raise ValueError when a rank entry is not an int or is negative."""
        for rank_id in self.rank:
            # NOTE: 0 passes this check although the message says "positive".
            if not isinstance(rank_id, int) or rank_id < 0:
                raise ValueError(f"rank {self.rank} must be a positive integer.")

    def _check_step(self):
        """Raise ValueError when a step entry is not an int."""
        for s in self.step:
            if not isinstance(s, int):
                raise ValueError(f"step element {s} should be int")
|