mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
|
@@ -1,7 +1,6 @@
|
|
|
1
|
-
|
|
2
1
|
import os
|
|
3
2
|
|
|
4
|
-
from msprobe.core.data_dump.scope import
|
|
3
|
+
from msprobe.core.data_dump.scope import build_scope, ListScope
|
|
5
4
|
from msprobe.core.data_dump.json_writer import DataWriter
|
|
6
5
|
from msprobe.core.common.log import logger
|
|
7
6
|
from msprobe.core.common.const import Const
|
|
@@ -21,7 +20,8 @@ class DataCollector:
|
|
|
21
20
|
self.config = config
|
|
22
21
|
self.data_writer = DataWriter()
|
|
23
22
|
self.data_processor = DataProcessorFactory.create_processor(self.config, self.data_writer)
|
|
24
|
-
self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework)
|
|
23
|
+
self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework) \
|
|
24
|
+
if self.config.framework == Const.PT_FRAMEWORK else None
|
|
25
25
|
self.module_count = {}
|
|
26
26
|
if self.config.task == Const.FREE_BENCHMARK:
|
|
27
27
|
self.scope = build_scope(ListScope, self.config.scope, self.config.list)
|
|
@@ -35,7 +35,7 @@ class DataCollector:
|
|
|
35
35
|
@property
|
|
36
36
|
def dump_file_path(self):
|
|
37
37
|
return self.data_writer.dump_file_path
|
|
38
|
-
|
|
38
|
+
|
|
39
39
|
@staticmethod
|
|
40
40
|
def check_scope_and_pid(scope, name, pid):
|
|
41
41
|
return (not scope or scope.check(name)) and pid == os.getpid()
|
|
@@ -43,10 +43,10 @@ class DataCollector:
|
|
|
43
43
|
@staticmethod
|
|
44
44
|
def is_inplace(module):
|
|
45
45
|
return getattr(module, "op_is_inplace", False)
|
|
46
|
-
|
|
46
|
+
|
|
47
47
|
def if_return_forward_new_output(self):
|
|
48
48
|
return self.data_processor.if_return_forward_new_output()
|
|
49
|
-
|
|
49
|
+
|
|
50
50
|
def get_forward_new_output(self):
|
|
51
51
|
return self.data_processor.get_forward_new_output()
|
|
52
52
|
|
|
@@ -71,12 +71,11 @@ class DataCollector:
|
|
|
71
71
|
backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
|
|
72
72
|
if self.check_scope_and_pid(self.scope, backward_name, pid):
|
|
73
73
|
self.data_processor.analyze_pre_forward(backward_name, module, module_input_output)
|
|
74
|
-
if not self.is_inplace(module):
|
|
74
|
+
if not self.is_inplace(module) or not self.check_scope_and_pid(self.scope, name, pid):
|
|
75
75
|
return
|
|
76
76
|
logger.info(f"API {name} is inplace.")
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
self.update_data(data_info)
|
|
77
|
+
data_info = self.data_processor.analyze_pre_forward_inplace(name, module_input_output)
|
|
78
|
+
self.handle_data(name, data_info)
|
|
80
79
|
|
|
81
80
|
def forward_data_collect(self, name, module, pid, module_input_output):
|
|
82
81
|
self.update_construct(name)
|
|
@@ -88,8 +87,11 @@ class DataCollector:
|
|
|
88
87
|
else:
|
|
89
88
|
data_info = self.data_processor.analyze_forward_inplace(name, module_input_output)
|
|
90
89
|
if self.config.level == "L2":
|
|
91
|
-
return
|
|
90
|
+
return
|
|
92
91
|
self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
|
|
92
|
+
if self.data_processor.is_terminated:
|
|
93
|
+
self.handle_data(name, data_info, use_buffer=False)
|
|
94
|
+
raise Exception("[msprobe] exit")
|
|
93
95
|
self.handle_data(name, data_info)
|
|
94
96
|
|
|
95
97
|
def backward_data_collect(self, name, module, pid, module_input_output):
|
|
@@ -98,43 +100,45 @@ class DataCollector:
|
|
|
98
100
|
return
|
|
99
101
|
|
|
100
102
|
data_info = self.data_processor.analyze_backward(name, module, module_input_output)
|
|
103
|
+
if self.data_processor.is_terminated:
|
|
104
|
+
self.handle_data(name, data_info, use_buffer=False)
|
|
105
|
+
raise Exception("[msprobe] exit")
|
|
106
|
+
self.handle_data(name, data_info)
|
|
107
|
+
|
|
108
|
+
def backward_input_data_collect(self, name, module, pid, module_input_output):
|
|
109
|
+
self.update_construct(name)
|
|
110
|
+
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
111
|
+
return
|
|
112
|
+
|
|
113
|
+
data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
|
|
114
|
+
self.handle_data(name, data_info)
|
|
115
|
+
|
|
116
|
+
def backward_output_data_collect(self, name, module, pid, module_input_output):
|
|
117
|
+
self.update_construct(name)
|
|
118
|
+
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
119
|
+
return
|
|
120
|
+
|
|
121
|
+
data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
|
|
101
122
|
self.handle_data(name, data_info)
|
|
102
123
|
|
|
103
124
|
def update_construct(self, name):
|
|
104
|
-
if self.config.level not in DataCollector.level_without_construct:
|
|
125
|
+
if self.config.framework == Const.PT_FRAMEWORK and self.config.level not in DataCollector.level_without_construct:
|
|
105
126
|
self.data_writer.update_construct({name: self.module_processor.api_parent_node})
|
|
106
127
|
self.data_writer.update_construct(self.module_processor.module_node)
|
|
107
128
|
|
|
108
|
-
def handle_data(self, name, data_info):
|
|
109
|
-
msg = f"msProbe is collecting data on {name}. "
|
|
129
|
+
def handle_data(self, name, data_info, use_buffer=True):
|
|
110
130
|
if data_info:
|
|
131
|
+
msg = f"msprobe is collecting data on {name}. "
|
|
111
132
|
msg = self.update_data(data_info, msg)
|
|
112
133
|
logger.info(msg)
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
def module_count_func(self, name, name_template):
|
|
116
|
-
module_name = name.split(Const.SEP)[-3]
|
|
117
|
-
if "forward" in name_template:
|
|
118
|
-
if module_name not in self.module_count:
|
|
119
|
-
self.module_count[module_name] = [0, [0]]
|
|
120
|
-
else:
|
|
121
|
-
if self.module_count[module_name][-1] and \
|
|
122
|
-
self.module_count[module_name][0] != self.module_count[module_name][-1][-1]:
|
|
123
|
-
self.module_count[module_name][-1].pop()
|
|
124
|
-
self.module_count[module_name][0] += 1
|
|
125
|
-
self.module_count[module_name][-1].append(self.module_count[module_name][0])
|
|
126
|
-
index = self.module_count[module_name][0]
|
|
134
|
+
if use_buffer:
|
|
135
|
+
self.data_writer.flush_data_when_buffer_is_full()
|
|
127
136
|
else:
|
|
128
|
-
|
|
129
|
-
if not backward_stack:
|
|
130
|
-
index = "abnormal"
|
|
131
|
-
else:
|
|
132
|
-
index = backward_stack.pop()
|
|
133
|
-
return index
|
|
137
|
+
self.write_json()
|
|
134
138
|
|
|
135
139
|
def update_dump_paths(self, *args):
|
|
136
140
|
self.data_writer.update_dump_paths(*args)
|
|
137
141
|
self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level)
|
|
138
|
-
|
|
142
|
+
|
|
139
143
|
def update_iter(self, current_iter):
|
|
140
144
|
self.data_processor.update_iter(current_iter)
|
|
@@ -35,11 +35,29 @@ class ModuleBackwardInputsOutputs:
|
|
|
35
35
|
@property
|
|
36
36
|
def grad_input_tuple(self):
|
|
37
37
|
return convert_tuple(self.grad_input)
|
|
38
|
-
|
|
38
|
+
|
|
39
39
|
@property
|
|
40
40
|
def grad_output_tuple(self):
|
|
41
|
-
return convert_tuple(self.grad_output)
|
|
42
|
-
|
|
41
|
+
return convert_tuple(self.grad_output)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
|
|
45
|
+
class ModuleBackwardInputs:
|
|
46
|
+
grad_input: Optional[Tuple]
|
|
47
|
+
|
|
48
|
+
@property
|
|
49
|
+
def grad_input_tuple(self):
|
|
50
|
+
return convert_tuple(self.grad_input)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
|
|
54
|
+
class ModuleBackwardOutputs:
|
|
55
|
+
grad_output: Optional[Tuple]
|
|
56
|
+
|
|
57
|
+
@property
|
|
58
|
+
def grad_output_tuple(self):
|
|
59
|
+
return convert_tuple(self.grad_output)
|
|
60
|
+
|
|
43
61
|
|
|
44
62
|
class TensorStatInfo:
|
|
45
63
|
def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
|
|
@@ -53,7 +71,7 @@ class BaseDataProcessor:
|
|
|
53
71
|
_recursive_key_stack = []
|
|
54
72
|
special_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
|
|
55
73
|
bool, int, float, str, slice)
|
|
56
|
-
|
|
74
|
+
|
|
57
75
|
def __init__(self, config, data_writer):
|
|
58
76
|
self.data_writer = data_writer
|
|
59
77
|
self.config = config
|
|
@@ -65,11 +83,15 @@ class BaseDataProcessor:
|
|
|
65
83
|
self.current_iter = 0
|
|
66
84
|
self._return_forward_new_output = False
|
|
67
85
|
self._forward_new_output = None
|
|
68
|
-
|
|
86
|
+
|
|
69
87
|
@property
|
|
70
88
|
def data_path(self):
|
|
71
89
|
return self.data_writer.dump_tensor_data_dir
|
|
72
|
-
|
|
90
|
+
|
|
91
|
+
@property
|
|
92
|
+
def is_terminated(self):
|
|
93
|
+
return False
|
|
94
|
+
|
|
73
95
|
@staticmethod
|
|
74
96
|
def analyze_api_call_stack(name):
|
|
75
97
|
stack_str = []
|
|
@@ -87,7 +109,17 @@ class BaseDataProcessor:
|
|
|
87
109
|
stack_str.append(stack_line)
|
|
88
110
|
stack_info_struct = {name: stack_str}
|
|
89
111
|
return stack_info_struct
|
|
90
|
-
|
|
112
|
+
|
|
113
|
+
@staticmethod
|
|
114
|
+
def transfer_type(data):
|
|
115
|
+
dtype = str(type(data))
|
|
116
|
+
if 'int' in dtype:
|
|
117
|
+
return int(data)
|
|
118
|
+
elif 'float' in dtype:
|
|
119
|
+
return float(data)
|
|
120
|
+
else:
|
|
121
|
+
return data
|
|
122
|
+
|
|
91
123
|
@staticmethod
|
|
92
124
|
def _convert_numpy_to_builtin(arg):
|
|
93
125
|
type_mapping = {
|
|
@@ -103,26 +135,15 @@ class BaseDataProcessor:
|
|
|
103
135
|
if isinstance(arg, numpy_type):
|
|
104
136
|
return builtin_type(arg), type(arg).__name__
|
|
105
137
|
return arg, ''
|
|
106
|
-
|
|
138
|
+
|
|
107
139
|
@staticmethod
|
|
108
140
|
def _analyze_numpy(value, numpy_type):
|
|
109
141
|
return {"type": numpy_type, "value": value}
|
|
110
|
-
|
|
111
|
-
@staticmethod
|
|
112
|
-
def _analyze_builtin(arg):
|
|
113
|
-
single_arg = {}
|
|
114
|
-
if isinstance(arg, slice):
|
|
115
|
-
single_arg.update({"type": "slice"})
|
|
116
|
-
single_arg.update({"value": [arg.start, arg.stop, arg.step]})
|
|
117
|
-
else:
|
|
118
|
-
single_arg.update({"type": type(arg).__name__})
|
|
119
|
-
single_arg.update({"value": arg})
|
|
120
|
-
return single_arg
|
|
121
|
-
|
|
142
|
+
|
|
122
143
|
@classmethod
|
|
123
144
|
def get_special_types(cls):
|
|
124
145
|
return cls.special_type
|
|
125
|
-
|
|
146
|
+
|
|
126
147
|
@classmethod
|
|
127
148
|
def recursive_apply_transform(cls, args, transform):
|
|
128
149
|
if isinstance(args, cls.get_special_types()):
|
|
@@ -177,13 +198,17 @@ class BaseDataProcessor:
|
|
|
177
198
|
return (Const.ALL in self.config.data_mode or
|
|
178
199
|
forward_backward in self.config.data_mode or
|
|
179
200
|
input_output in self.config.data_mode)
|
|
180
|
-
|
|
181
|
-
def analyze_pre_forward(self, name, module,module_input_output: ModuleForwardInputsOutputs):
|
|
201
|
+
|
|
202
|
+
def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
182
203
|
pass
|
|
183
204
|
|
|
205
|
+
def analyze_element(self, element):
|
|
206
|
+
return self.recursive_apply_transform(element, self.analyze_single_element)
|
|
207
|
+
|
|
184
208
|
def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
|
|
185
209
|
api_info_struct = {}
|
|
186
|
-
|
|
210
|
+
# check whether data_mode contains forward or input
|
|
211
|
+
if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
|
|
187
212
|
api_info_struct[name] = {}
|
|
188
213
|
self.api_data_category = Const.INPUT
|
|
189
214
|
args_info_list = self.analyze_element(module_input_output.args_tuple)
|
|
@@ -192,13 +217,14 @@ class BaseDataProcessor:
|
|
|
192
217
|
kwargs_info_list = self.analyze_element(module_input_output.kwargs)
|
|
193
218
|
api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
|
|
194
219
|
|
|
195
|
-
|
|
220
|
+
# check whether data_mode contains forward or output
|
|
221
|
+
if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
|
|
196
222
|
api_info_struct[name] = api_info_struct.get(name, {})
|
|
197
223
|
self.api_data_category = Const.OUTPUT
|
|
198
224
|
output_info_list = self.analyze_element(module_input_output.output_tuple)
|
|
199
225
|
api_info_struct[name][Const.OUTPUT] = output_info_list
|
|
200
226
|
return api_info_struct
|
|
201
|
-
|
|
227
|
+
|
|
202
228
|
def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
|
|
203
229
|
api_info_struct = {}
|
|
204
230
|
if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
|
|
@@ -210,7 +236,7 @@ class BaseDataProcessor:
|
|
|
210
236
|
kwargs_info_list = self.analyze_element(module_input_output.kwargs)
|
|
211
237
|
api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
|
|
212
238
|
return api_info_struct
|
|
213
|
-
|
|
239
|
+
|
|
214
240
|
def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
|
|
215
241
|
concat_args = module_input_output.concat_args_and_kwargs()
|
|
216
242
|
api_info_struct = {}
|
|
@@ -220,26 +246,48 @@ class BaseDataProcessor:
|
|
|
220
246
|
output_info_list = self.analyze_element(concat_args)
|
|
221
247
|
api_info_struct[name][Const.OUTPUT] = output_info_list
|
|
222
248
|
return api_info_struct
|
|
223
|
-
|
|
249
|
+
|
|
224
250
|
def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
|
|
225
251
|
api_info_struct = {}
|
|
226
|
-
if self.is_dump_for_data_mode(Const.BACKWARD, Const.
|
|
252
|
+
if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
|
|
227
253
|
api_info_struct[name] = {}
|
|
228
|
-
self.api_data_category = Const.
|
|
254
|
+
self.api_data_category = Const.INPUT
|
|
229
255
|
input_info_list = self.analyze_element(module_input_output.grad_input_tuple)
|
|
230
|
-
api_info_struct[name][Const.
|
|
256
|
+
api_info_struct[name][Const.INPUT] = input_info_list
|
|
231
257
|
|
|
232
|
-
if self.is_dump_for_data_mode(Const.BACKWARD, Const.
|
|
258
|
+
if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
|
|
233
259
|
api_info_struct[name] = api_info_struct.get(name, {})
|
|
234
|
-
self.api_data_category = Const.
|
|
260
|
+
self.api_data_category = Const.OUTPUT
|
|
235
261
|
output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
|
|
236
|
-
api_info_struct[name][Const.
|
|
262
|
+
api_info_struct[name][Const.OUTPUT] = output_info_list
|
|
263
|
+
|
|
264
|
+
return api_info_struct
|
|
265
|
+
|
|
266
|
+
def analyze_backward_input(self, name, module,
|
|
267
|
+
module_input_output: ModuleBackwardInputs):
|
|
268
|
+
api_info_struct = {}
|
|
269
|
+
if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
|
|
270
|
+
api_info_struct[name] = {}
|
|
271
|
+
self.api_data_category = Const.INPUT
|
|
272
|
+
|
|
273
|
+
input_info_list = self.analyze_element(module_input_output.grad_input_tuple)
|
|
274
|
+
api_info_struct[name][Const.INPUT] = input_info_list
|
|
275
|
+
return api_info_struct
|
|
237
276
|
|
|
277
|
+
def analyze_backward_output(self, name, module,
|
|
278
|
+
module_input_output: ModuleBackwardOutputs):
|
|
279
|
+
api_info_struct = {}
|
|
280
|
+
if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
|
|
281
|
+
api_info_struct[name] = {}
|
|
282
|
+
self.api_data_category = Const.OUTPUT
|
|
283
|
+
|
|
284
|
+
output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
|
|
285
|
+
api_info_struct[name][Const.OUTPUT] = output_info_list
|
|
238
286
|
return api_info_struct
|
|
239
287
|
|
|
240
288
|
def get_save_file_path(self, suffix):
|
|
241
|
-
file_format =
|
|
289
|
+
file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
|
|
242
290
|
dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
|
|
243
|
-
suffix +
|
|
291
|
+
suffix + file_format)
|
|
244
292
|
file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
|
|
245
|
-
return dump_data_name, file_path
|
|
293
|
+
return dump_data_name, file_path
|
|
@@ -4,7 +4,7 @@ from msprobe.core.common.const import Const
|
|
|
4
4
|
class DataProcessorFactory:
|
|
5
5
|
_data_processor = {}
|
|
6
6
|
_module_processor = {}
|
|
7
|
-
|
|
7
|
+
|
|
8
8
|
@classmethod
|
|
9
9
|
def register_processor(cls, framework, task, processor_class):
|
|
10
10
|
key = (framework, task)
|
|
@@ -13,7 +13,7 @@ class DataProcessorFactory:
|
|
|
13
13
|
@classmethod
|
|
14
14
|
def register_module_processor(cls, framework, processor_class):
|
|
15
15
|
cls._module_processor[framework] = processor_class
|
|
16
|
-
|
|
16
|
+
|
|
17
17
|
@classmethod
|
|
18
18
|
def get_module_processor(cls, framework):
|
|
19
19
|
processor_class = cls._module_processor.get(framework)
|
|
@@ -39,7 +39,7 @@ class DataProcessorFactory:
|
|
|
39
39
|
TensorDataProcessor as PytorchTensorDataProcessor,
|
|
40
40
|
OverflowCheckDataProcessor as PytorchOverflowCheckDataProcessor,
|
|
41
41
|
FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor,
|
|
42
|
-
KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
|
|
42
|
+
KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
|
|
43
43
|
)
|
|
44
44
|
from ....pytorch.module_processer import ModuleProcesser
|
|
45
45
|
cls.register_processor(Const.PT_FRAMEWORK, Const.STATISTICS, PytorchStatisticsDataProcessor)
|
|
@@ -47,15 +47,13 @@ class DataProcessorFactory:
|
|
|
47
47
|
cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
|
|
48
48
|
cls.register_processor(Const.PT_FRAMEWORK, Const.FREE_BENCHMARK, PytorchFreeBenchmarkDataProcessor)
|
|
49
49
|
cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor)
|
|
50
|
-
cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
|
|
50
|
+
cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
|
|
51
51
|
elif framework == Const.MS_FRAMEWORK:
|
|
52
52
|
from .mindspore_processor import (
|
|
53
53
|
StatisticsDataProcessor as MindsporeStatisticsDataProcessor,
|
|
54
54
|
TensorDataProcessor as MindsporeTensorDataProcessor,
|
|
55
|
-
OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor
|
|
56
|
-
FreeBenchmarkDataProcessor as MindsporeFreeBenchmarkDataProcessor
|
|
55
|
+
OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor
|
|
57
56
|
)
|
|
58
57
|
cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor)
|
|
59
58
|
cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
|
|
60
59
|
cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
|
|
61
|
-
cls.register_processor(Const.MS_FRAMEWORK, Const.FREE_BENCHMARK, MindsporeFreeBenchmarkDataProcessor)
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
import zlib
|
|
17
|
+
|
|
18
|
+
import mindspore as ms
|
|
19
|
+
from mindspore import ops
|
|
20
|
+
import numpy as np
|
|
21
|
+
|
|
22
|
+
from msprobe.core.common.const import Const
|
|
23
|
+
from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo,
|
|
24
|
+
ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs)
|
|
25
|
+
from msprobe.core.common.file_check import path_len_exceeds_limit
|
|
26
|
+
from msprobe.mindspore.dump.hook_cell.wrap_functional import load_ops_functions
|
|
27
|
+
from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy
|
|
28
|
+
from msprobe.mindspore.common.log import logger
|
|
29
|
+
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class MindsporeDataProcessor(BaseDataProcessor):
    """Data processor that turns MindSpore call arguments into JSON-friendly summaries.

    Tensors are described by dtype, shape and Max/Min/Mean/Norm statistics (plus
    a checksum when ``config.summary_mode`` is ``Const.MD5``). Builtins, slices
    and ``dtype`` keyword arguments are described by their type and value.
    """

    # Types this processor handles in addition to the base class's special types.
    mindspore_special_type = tuple([ms.Tensor])
    # Function tables from load_ops_functions(); get_stat_info looks up
    # "max"/"min"/"mean" in mint_ops_func and "norm" in ops_func.
    ops_func, mint_ops_func, _ = load_ops_functions()

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)
        # Maps the last suffix component to a dedicated analyser; consulted first
        # in analyze_single_element.
        self.mindspore_object_key = {
            "dtype": self.analyze_dtype_in_kwargs
        }

    @staticmethod
    def get_md5_for_tensor(x):
        """Return an 8-digit lowercase hex checksum of the tensor's raw bytes.

        Despite the name, the value is a CRC32 (``zlib.crc32``), not an MD5 digest.
        The tensor is passed through ``convert_bf16_to_fp32`` before its bytes
        are read.
        """
        x = convert_bf16_to_fp32(x)
        tensor_bytes = x.asnumpy().tobytes()
        crc32_hash = zlib.crc32(tensor_bytes)
        return f"{crc32_hash:08x}"

    @staticmethod
    def analyze_dtype_in_kwargs(element):
        """Describe a ``dtype`` argument by its string representation."""
        return {"type": "mindspore.dtype", "value": str(element)}

    @staticmethod
    def _analyze_builtin(arg):
        """Describe a builtin value or a slice as ``{"type": ..., "value": ...}``."""
        if isinstance(arg, slice):
            # Slice bounds may be tensors; convert them to Python scalars so the
            # result can be serialised to JSON.
            values = [
                value if not isinstance(value, ms.Tensor) else value.item()
                for value in [arg.start, arg.stop, arg.step]
            ]
            return {"type": "slice", "value": values}
        return {"type": type(arg).__name__, "value": arg}

    @classmethod
    def get_special_types(cls):
        """Return the base class's special types extended with ``ms.Tensor``."""
        return super().get_special_types() + cls.mindspore_special_type

    def get_stat_info(self, data):
        """Compute max/min/mean/norm statistics for a tensor.

        Args:
            data: The tensor to summarise.

        Returns:
            A ``TensorStatInfo``. It is left at its defaults for empty tensors;
            only max/min are filled for bool tensors; a zero-dimensional tensor
            reports its single value for all four statistics; complex tensors
            are summarised over their absolute values.
        """
        tensor_stat = TensorStatInfo()
        if data.numel() == 0:
            return tensor_stat
        elif data.dtype == ms.bool_:
            data_np = data.asnumpy()
            tensor_stat.max = np.max(data_np).item()
            tensor_stat.min = np.min(data_np).item()
        elif not data.shape:
            tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
        elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
            data_abs = np.abs(data.asnumpy())
            tensor_stat.max = np.max(data_abs).item()
            tensor_stat.min = np.min(data_abs).item()
            tensor_stat.mean = np.mean(data_abs).item()
            tensor_stat.norm = np.linalg.norm(data_abs).item()
        else:
            # bfloat16 and non-floating-point tensors are cast to float32 before
            # the statistic ops run.
            if data.dtype == ms.bfloat16 or not ops.is_floating_point(data):
                data = data.to(ms.float32)
            # NOTE(review): presumably switches the ops used below back to their
            # un-hooked originals so the statistics are not themselves dumped;
            # confirm against api_registry.
            api_register.norm_inner_op_set_ori_func()
            try:
                tensor_stat.max = self.mint_ops_func["max"](data).item()
                tensor_stat.min = self.mint_ops_func["min"](data).item()
                tensor_stat.mean = self.mint_ops_func["mean"](data).item()
                tensor_stat.norm = self.ops_func["norm"](data).item()
            finally:
                # Restore the hook functions even when a statistic op raises, so
                # a single failure does not leave the registry in the "original"
                # state for the rest of the run.
                api_register.norm_inner_op_set_hook_func()
        return tensor_stat

    def analyze_single_element(self, element, suffix_stack):
        """Dispatch one element to the matching analyser.

        Args:
            element: The value to describe.
            suffix_stack: Name components for the element; the last one selects
                a dedicated analyser when present in ``mindspore_object_key``,
                and the joined stack names the tensor.

        Returns:
            A dict describing the element, or ``{}`` for unsupported types.
        """
        if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
            return self.mindspore_object_key[suffix_stack[-1]](element)

        # The converter returns a different object when it performed a
        # conversion, hence the identity check.
        converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
        if converted_numpy is not element:
            return self._analyze_numpy(converted_numpy, numpy_type)
        if isinstance(element, ms.Tensor):
            return self._analyze_tensor(element, Const.SEP.join(suffix_stack))

        if isinstance(element, (bool, int, float, str, slice)):
            return self._analyze_builtin(element)
        return {}

    def _analyze_tensor(self, tensor, suffix):
        """Return the JSON summary (type, dtype, shape, statistics) of a tensor.

        ``suffix`` is unused here; subclasses use it to derive dump file names.
        """
        tensor_stat = self.get_stat_info(tensor)
        tensor_json = {
            'type': 'mindspore.Tensor',
            'dtype': str(tensor.dtype),
            'shape': tensor.shape,
            'Max': self.transfer_type(tensor_stat.max),
            'Min': self.transfer_type(tensor_stat.min),
            'Mean': self.transfer_type(tensor_stat.mean),
            'Norm': self.transfer_type(tensor_stat.norm),
        }
        if self.config.summary_mode == Const.MD5:
            tensor_md5 = self.get_md5_for_tensor(tensor)
            tensor_json.update({Const.MD5: tensor_md5})
        return tensor_json
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class StatisticsDataProcessor(MindsporeDataProcessor):
    """Statistics-only processor: inherits MindsporeDataProcessor without changes."""
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class TensorDataProcessor(MindsporeDataProcessor):
    """Processor that records tensor summaries and also saves each tensor via save_tensor_as_npy."""

    def _analyze_tensor(self, tensor, suffix):
        """Summarise ``tensor``, record its dump file name and save it to disk.

        Returns the parent's summary dict extended with a ``"data_name"`` entry.
        """
        data_name, target_path = self.get_save_file_path(suffix)
        tensor_info = super()._analyze_tensor(tensor, suffix)
        tensor_info["data_name"] = data_name
        # The tensor is written only after its summary has been computed.
        save_tensor_as_npy(tensor, target_path)
        return tensor_info
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class OverflowCheckDataProcessor(MindsporeDataProcessor):
    """Processor that keeps a call's data only when one of its tensors overflowed.

    Tensors seen during one forward/backward analysis are cached together with
    their target file paths. They are written to disk only if some tensor's Max
    or Min turned out to be inf or NaN; otherwise the cache is dropped.
    """

    # NOTE(review): other attributes are assigned on instances too
    # (real_overflow_nums, overflow_nums, has_overflow), so this declaration
    # presumably relies on a parent class providing __dict__ — confirm.
    __slots__ = ["cached_tensors_and_file_paths"]

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)
        # file_path -> tensor, filled while analysing one call.
        self.cached_tensors_and_file_paths = {}
        # Number of analysed calls in which an overflow was detected so far.
        self.real_overflow_nums = 0
        # Configured overflow limit; -1 disables termination (see is_terminated).
        self.overflow_nums = config.overflow_nums

    @property
    def is_terminated(self):
        """Whether the configured number of overflows has been reached.

        Always ``False`` when ``overflow_nums`` is -1. Logs a message when
        returning ``True``.
        """
        if self.overflow_nums == -1:
            return False
        if self.real_overflow_nums >= self.overflow_nums:
            logger.info(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_nums}")
            return True
        return False

    def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        """Analyse a forward call; return its summary only if an overflow was found, else ``None``."""
        self.has_overflow = False
        api_info_struct = super().analyze_forward(name, module, module_input_output)
        self.maybe_save_overflow_data()
        return api_info_struct if self.has_overflow else None

    def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
        """Analyse a backward call; return its summary only if an overflow was found, else ``None``."""
        self.has_overflow = False
        api_info_struct = super().analyze_backward(name, module, module_input_output)
        self.maybe_save_overflow_data()
        return api_info_struct if self.has_overflow else None

    def maybe_save_overflow_data(self):
        """Write the cached tensors if an overflow was flagged, then clear the cache."""
        if self.has_overflow:
            for file_path, tensor in self.cached_tensors_and_file_paths.items():
                save_tensor_as_npy(tensor, file_path)
            self.real_overflow_nums += 1
        # The cache is per call: drop it whether or not anything was saved.
        self.cached_tensors_and_file_paths = {}

    def _analyze_maybe_overflow_tensor(self, tensor_json):
        """Set ``has_overflow`` when the tensor summary's Max or Min is inf or NaN.

        Summaries whose ``'Max'`` is ``None`` are ignored entirely.
        """
        max_value = tensor_json['Max']
        if max_value is None:
            return
        min_value = tensor_json['Min']
        for value in (max_value, min_value):
            # Guard Min the same way as Max: np.isinf/np.isnan raise TypeError on None.
            if value is not None and (np.isinf(value) or np.isnan(value)):
                self.has_overflow = True

    def _analyze_tensor(self, tensor, suffix):
        """Summarise ``tensor``, cache it for a possible dump and check it for overflow.

        Returns the parent's summary dict extended with a ``"data_name"`` entry.
        """
        dump_data_name, file_path = self.get_save_file_path(suffix)
        if not path_len_exceeds_limit(file_path):
            self.cached_tensors_and_file_paths.update({file_path: tensor})
        else:
            # Over-long paths are not cached, so such tensors are never saved;
            # their summary is still produced below.
            logger.warning(f'The file path {file_path} length exceeds limit.')
        single_arg = super()._analyze_tensor(tensor, suffix)
        self._analyze_maybe_overflow_tensor(single_arg)
        single_arg.update({"data_name": dump_data_name})
        return single_arg
|