mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
msprobe/mindspore/ms_config.py
CHANGED
@@ -1,6 +1,12 @@
 import json
+
 from msprobe.core.common_config import CommonConfig, BaseConfig
 from msprobe.core.common.file_check import FileOpen
+from msprobe.core.common.const import Const
+from msprobe.mindspore.common.const import FreeBenchmarkConst
+from msprobe.mindspore.common.log import logger
+from msprobe.core.grad_probe.constant import level_adp
+from msprobe.core.grad_probe.utils import check_numeral_list_ascend
 
 
 class TensorConfig(BaseConfig):
@@ -31,39 +37,81 @@ class StatisticsConfig(BaseConfig):
         if self.data_mode is not None and len(self.data_mode) > 0:
             if len(self.data_mode) > 1 or self.data_mode[0] not in ["all", "input", "output"]:
                 raise Exception("data_mode must be all, input or output")
+        if self.summary_mode and self.summary_mode not in ["statistics", "md5"]:
+            raise Exception("summary_mode is invalid")
 
 
-class
+class OverflowCheckConfig(BaseConfig):
     def __init__(self, json_config):
         super().__init__(json_config)
-        self.
-        self.check_mode = json_config.get("check_mode")
+        self.data_mode = ["all"]
         self._check_config()
 
     def _check_config(self):
-        if self.
-
-
+        if self.overflow_nums is not None and not isinstance(self.overflow_nums, int):
+            raise Exception("overflow_nums is invalid, it should be an integer")
+        if self.overflow_nums is not None and self.overflow_nums != -1 and self.overflow_nums <= 0:
+            raise Exception("overflow_nums should be -1 or positive integer")
         if self.check_mode and self.check_mode not in ["all", "aicore", "atomic"]:
             raise Exception("check_mode is invalid")
 
 
+class FreeBenchmarkConfig(BaseConfig):
+    def __init__(self, task_config):
+        super().__init__(task_config)
+        self._check_config()
+
+    def _check_config(self):
+        if self.fuzz_device and self.fuzz_device not in FreeBenchmarkConst.DEVICE_LIST:
+            raise Exception("fuzz_device must be npu or empty")
+        if self.pert_mode and self.pert_mode not in FreeBenchmarkConst.PERT_TYPE_LIST:
+            raise Exception("pert_mode must be improve_precision, add_noise, bit_noise, no_change or empty")
+        if self.handler_type and self.handler_type not in FreeBenchmarkConst.HANDLER_TYPE_LIST:
+            raise Exception("handler_type must be check, fix or empty")
+        if self.fuzz_level and self.fuzz_level not in FreeBenchmarkConst.DUMP_LEVEL_LIST:
+            raise Exception("fuzz_level must be L1 or empty")
+        if self.fuzz_stage and self.fuzz_stage not in FreeBenchmarkConst.STAGE_LIST:
+            raise Exception("fuzz_stage must be forward or empty")
+        if self.if_preheat or self.preheat_step or self.max_sample:
+            logger.warning("'if_preheat', 'preheat_step' and 'max_sample' settings "
+                           "are not supported for mindspore free benchmark task.")
+
+
+class GradProbeConfig(BaseConfig):
+    def __init__(self, json_config):
+        super().__init__(json_config)
+        self.grad_level = json_config.get("grad_level", "L1")
+        self.param_list = json_config.get("param_list", [])
+        self.bounds = json_config.get("bounds", [])
+
+    def _check_config(self):
+        if self.grad_level not in level_adp.keys():
+            raise Exception(f"grad_level must be one of {level_adp.keys()}")
+        if not isinstance(self.param_list, list):
+            raise Exception(f"param_list must be a list")
+        check_numeral_list_ascend(self.bounds)
+
+
+TaskDict = {
+    Const.TENSOR: TensorConfig,
+    Const.STATISTICS: StatisticsConfig,
+    Const.OVERFLOW_CHECK: OverflowCheckConfig,
+    Const.FREE_BENCHMARK: FreeBenchmarkConfig,
+    Const.GRAD_PROBE: GradProbeConfig,
+}
+
+
 def parse_common_config(json_config):
     return CommonConfig(json_config)
 
 
 def parse_task_config(task, json_config):
-    task_map = json_config
+    task_map = json_config.get(task)
     if not task_map:
         task_map = dict()
-    if task
-        return TensorConfig(task_map)
-    elif task == "statistics":
-        return StatisticsConfig(task_map)
-    elif task == "overflow_check":
-        return OverflowCheck(task_map)
-    else:
+    if task not in TaskDict:
         raise Exception("task is invalid.")
+    return TaskDict.get(task)(task_map)
 
 
 def parse_json_config(json_file_path):
@@ -73,6 +121,6 @@ def parse_json_config(json_file_path):
         json_config = json.load(file)
     common_config = parse_common_config(json_config)
     if not common_config.task:
-        common_config.task =
+        common_config.task = Const.STATISTICS
     task_config = parse_task_config(common_config.task, json_config)
     return common_config, task_config
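In the rewritten `parse_task_config`, the old if/elif chain over task names is replaced by the `TaskDict` registry keyed on `Const` task identifiers, so adding a task only requires registering its config class. A minimal usage sketch of the new dispatch (the inner config values are illustrative assumptions, not taken from this diff):

```python
# Sketch only: drives the registry-based parse_task_config added in 1.0.3.
# The json_config contents below are illustrative assumptions.
from msprobe.core.common.const import Const
from msprobe.mindspore.ms_config import parse_task_config

json_config = {Const.STATISTICS: {"data_mode": ["all"], "summary_mode": "statistics"}}

# parse_task_config looks the task up in TaskDict and instantiates the matching
# config class; an unknown task raises Exception("task is invalid.").
task_config = parse_task_config(Const.STATISTICS, json_config)
print(type(task_config).__name__)  # expected: StatisticsConfig
```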

msprobe/mindspore/overflow_check/overflow_check_tool_factory.py
CHANGED

@@ -1,23 +1,24 @@
+from msprobe.mindspore.common.const import Const
 from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
 from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck
 
 
 class OverflowCheckToolFactory:
     tools = {
-
-
-
-
+        Const.CELL: {
+            Const.GRAPH_KBYK_MODE: None,
+            Const.GRAPH_GE_MODE: None,
+            Const.PYNATIVE_MODE: None
         },
-
-
-
-
+        Const.API: {
+            Const.GRAPH_KBYK_MODE: None,
+            Const.GRAPH_GE_MODE: None,
+            Const.PYNATIVE_MODE: None
         },
-
-
-
-
+        Const.KERNEL: {
+            Const.GRAPH_KBYK_MODE: None,
+            Const.GRAPH_GE_MODE: KernelGraphOverflowCheck,
+            Const.PYNATIVE_MODE: None
         }
     }
 
@@ -25,8 +26,9 @@ class OverflowCheckToolFactory:
     def create(config: DebuggerConfig):
         tool = OverflowCheckToolFactory.tools.get(config.level)
         if not tool:
-            raise Exception("
-        tool = tool.get(
+            raise Exception("Valid level is needed.")
+        tool = tool.get(config.execution_mode)
         if not tool:
-            raise Exception("Overflow check
+            raise Exception(f"Overflow check is not supported in {config.execution_mode} mode "
+                            f"when level is {config.level}.")
         return tool(config)
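`OverflowCheckToolFactory` now resolves a tool in two steps, dump level first and execution mode second, and only the kernel level under graph GE mode currently maps to a concrete tool. A self-contained sketch of the same two-level dispatch pattern (the string keys and stand-in class are illustrative, not the package's `Const` values):

```python
# Standalone sketch of the level -> execution-mode dispatch used by the factory.
class KernelGraphOverflowCheck:          # stand-in for the real tool class
    def __init__(self, config):
        self.config = config


TOOLS = {
    "kernel": {
        "graph_kbyk": None,
        "graph_ge": KernelGraphOverflowCheck,
        "pynative": None,
    },
}


def create(level, execution_mode, config):
    modes = TOOLS.get(level)
    if not modes:
        raise Exception("Valid level is needed.")
    tool = modes.get(execution_mode)
    if not tool:
        raise Exception(f"Overflow check is not supported in {execution_mode} mode "
                        f"when level is {level}.")
    return tool(config)


print(type(create("kernel", "graph_ge", config={})).__name__)  # KernelGraphOverflowCheck
```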

msprobe/mindspore/service.py
ADDED

@@ -0,0 +1,354 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import os
+import copy
+from pathlib import Path
+import functools
+from collections import defaultdict
+
+import mindspore as ms
+from mindspore.common.tensor import Tensor
+from mindspore import ops
+from mindspore import nn
+try:
+    from mindspore.common._pijit_context import PIJitCaptureContext
+    pijit_label = True
+except ImportError:
+    pijit_label = False
+
+
+from msprobe.core.data_dump.data_collector import build_data_collector
+from msprobe.core.data_dump.scope import BaseScope
+from msprobe.mindspore.common.utils import get_rank_if_initialized
+from msprobe.core.common.file_check import FileChecker, FileCheckConst, check_path_before_create
+from msprobe.mindspore.common.log import logger
+from msprobe.core.common.utils import Const
+from msprobe.core.common.exceptions import DistributedNotInitializedError
+from msprobe.mindspore.dump.hook_cell.api_registry import api_register
+from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
+    ModuleBackwardInputs, ModuleBackwardOutputs
+from msprobe.core.common.exceptions import MsprobeException
+from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
+from msprobe.mindspore.cell_processor import CellProcessor
+from msprobe.mindspore.dump.jit_dump import JitDump
+
+
+class Service:
+    def __init__(self, config):
+        self.model = None
+        self.config = copy.deepcopy(config)
+        self.config.level = self.config.level_ori
+        self.data_collector = build_data_collector(self.config)
+        self.cell_processor = CellProcessor(self.data_collector.scope)
+        self.switch = False
+        self.current_iter = 0
+        self.first_start = True
+        self.current_rank = None
+        self.primitive_counters = {}
+        self.dump_iter_dir = None
+        self.start_call = False
+        self.check_level_valid()
+
+    @staticmethod
+    def check_model_valid(model):
+        if not model or isinstance(model, nn.Cell):
+            return model
+        raise MsprobeException(
+            MsprobeException.INVALID_PARAM_ERROR, "The model parameter must be of type mindspore.nn.Cell."
+        )
+
+    def check_level_valid(self):
+        if self.config.level == "L2":
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR, "L2 level dump function is currently not supported."
+            )
+
+    def build_hook(self, target_type, name):
+        def forward_hook(api_or_cell_name, cell, input, output):
+            if target_type == BaseScope.Module_Type_Module:
+                api_or_cell_name = cell.mindstudio_reserved_name
+            self.data_collector.visit_and_clear_overflow_status(api_or_cell_name)
+            if not self.switch:
+                return None
+            if self.data_collector:
+                if target_type == BaseScope.Module_Type_Module:
+                    module_input_output = ModuleForwardInputsOutputs(args=input, kwargs={}, output=output)
+                else:
+                    module_input_output = ModuleForwardInputsOutputs(args=input, kwargs=cell.input_kwargs, output=output)
+                self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
+                if self.data_collector.if_return_forward_new_output():
+                    return self.data_collector.get_forward_new_output()
+            if target_type == BaseScope.Module_Type_API:
+                del cell.input_kwargs
+            return output
+
+        def backward_hook(api_or_cell_name, cell, grad_input, grad_output):
+            if target_type == BaseScope.Module_Type_Module:
+                api_or_cell_name = cell.mindstudio_reserved_name
+            self.data_collector.visit_and_clear_overflow_status(api_or_cell_name)
+            if not self.switch:
+                return
+            if self.data_collector:
+                # A recent framework API change swapped the meaning of grad_input and grad_output; pass them reversed here to match the torch semantics
+                module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
+                self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
+
+        pid = os.getpid()
+        forward_name_template = name + Const.FORWARD
+        backward_name_template = name + Const.BACKWARD
+        forward_hook = functools.partial(forward_hook, forward_name_template)
+        backward_hook = functools.partial(backward_hook, backward_name_template)
+
+        def wrap_forward_hook(cell, input, output):
+            return forward_hook(cell, input, output)
+
+        def wrap_backward_hook(cell, grad_input, grad_output):
+            return backward_hook(cell, grad_input, grad_output)
+
+        return wrap_forward_hook, wrap_backward_hook
+
+    def wrap_primitive(self, origin_func, primitive_name):
+        service_instance = self
+
+        def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
+            def backward_hook(grad):
+                captured_grads.append(grad)
+                backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+                try:
+                    if len(captured_grads) == num_tensors and hook_type == Const.INPUT:
+                        service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
+                        new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
+                        service_instance.data_collector.backward_output_data_collect(
+                            backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+                        )
+                        captured_grads.clear()
+                    elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
+                        service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
+                        new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
+                        service_instance.data_collector.backward_input_data_collect(
+                            backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+                        )
+                        captured_grads.clear()
+
+                except Exception as exception:
+                    raise Exception(f"This is a primitive op {hook_type}_backward dump error: {exception},"
+                                    f" updated_primitive_name: {updated_primitive_name}") from exception
+
+            return backward_hook
+
+        def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
+            hooked_inputs = []
+            num_tensors = sum(isinstance(arg, Tensor) for arg in args)
+            input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name,
+                                                       Const.INPUT)
+            for _, arg in enumerate(args):
+                if isinstance(arg, Tensor):
+                    arg_hooked = ops.HookBackward(input_backward_hook)(arg)
+                    hooked_inputs.append(arg_hooked)
+                else:
+                    hooked_inputs.append(arg)
+            return hooked_inputs
+
+        def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
+            if isinstance(out, tuple):
+                num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out)
+            else:
+                num_output_tensors = 1
+            output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors,
+                                                        updated_primitive_name, Const.OUTPUT)
+
+            if isinstance(out, Tensor):
+                return ops.HookBackward(output_backward_hook)(out)
+            elif isinstance(out, tuple):
+                hooked_outputs = []
+                for tensor in out:
+                    if isinstance(tensor, Tensor):
+                        hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
+                    else:
+                        hooked_outputs.append(tensor)
+                return tuple(hooked_outputs)
+            return out
+
+        def wrapped_primitive_call(instance_self, *args, **kwargs):
+            service_instance.update_primitive_counters(primitive_name)
+            current_count = service_instance.primitive_counters.get(primitive_name, 0)
+            updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
+
+            if not service_instance.switch:
+                return origin_func(*args, **kwargs)
+
+            captured_grads_input, captured_grads_output = [], []
+
+            try:
+                hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
+            except Exception as exception:
+                raise Exception("This is a primitive op dump error during input hooking: {},"
+                                " primitive_name: {}".format(exception, primitive_name)) from exception
+
+            try:
+                out = origin_func(*hooked_inputs, **kwargs)
+            except Exception as exception:
+                raise Exception("This is a primitive op dump error during function call: {},"
+                                " primitive_name: {}".format(exception, primitive_name)) from exception
+
+            forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
+            service_instance.data_collector.visit_and_clear_overflow_status(forward_primitive_name)
+            if service_instance.data_collector:
+                module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
+                try:
+                    service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
+                                                                         os.getpid(), module_input_output)
+                except Exception as exception:
+                    raise Exception("This is a primitive op dump error during forward data collection: {},"
+                                    " primitive_name: {}".format(exception, primitive_name)) from exception
+
+                if service_instance.data_collector.if_return_forward_new_output():
+                    out = service_instance.data_collector.get_forward_new_output()
+
+            try:
+                out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
+            except Exception as exception:
+                raise Exception("This is a primitive op dump error during output hooking: {},"
+                                " primitive_name: {}".format(exception, primitive_name)) from exception
+
+            return out
+
+        return wrapped_primitive_call
+
+    def update_primitive_counters(self, primitive_name):
+        if primitive_name not in self.primitive_counters:
+            self.primitive_counters[primitive_name] = 0
+        else:
+            self.primitive_counters[primitive_name] += 1
+
+    def register_hooks(self):
+        primitive_set = set()
+        for _, cell in self.model.cells_and_names():
+            for pname, primitive in cell._primitives.items():
+                primitive_set.add((pname, primitive))
+
+        for pname, primitive in primitive_set:
+            NewPrimitive = type('NewPrimitive', (primitive.__class__,),
+                                {'__call__': self.wrap_primitive(primitive.__call__, pname)})
+            primitive.__class__ = NewPrimitive
+
+    def step(self):
+        self.current_iter += 1
+        self.data_collector.update_iter(self.current_iter)
+        HOOKCell.cell_count = defaultdict(int)
+        CellProcessor.cell_count = {}
+        self.primitive_counters.clear()
+
+    def start(self, model=None):
+        self.model = self.check_model_valid(model)
+        self.start_call = True
+        logger.info("msprobe: debugger.start() is set successfully")
+        if self.config.step and self.current_iter > max(self.config.step):
+            self.stop()
+            raise Exception("msprobe: exit after iteration {}".format(max(self.config.step)))
+        if self.config.step and self.current_iter not in self.config.step:
+            return
+        if self.first_start:
+            try:
+                self.current_rank = get_rank_if_initialized()
+            except DistributedNotInitializedError:
+                self.current_rank = None
+
+            if self.config.rank and self.current_rank not in self.config.rank:
+                return
+            self.register_hook_new()
+            self.first_start = False
+        self.switch = True
+        logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
+        self.create_dirs()
+        logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
+        if self.config.level == "L1":
+            JitDump.set_config(self.config)
+            JitDump.set_data_collector(self.data_collector)
+            ms.common.api._MindsporeFunctionExecutor = JitDump
+            ms.common.api._PyNativeExecutor.grad = JitDump.grad
+            if pijit_label:
+                PIJitCaptureContext.__enter__ = self.empty
+                PIJitCaptureContext.__exit__ = self.empty
+
+    def stop(self):
+        logger.info("msprobe: debugger.stop() is set successfully. "
+                    "Please set debugger.start() to turn on the dump switch again. ")
+        if not self.start_call:
+            logger.error("msprobe: debugger.start() is not set in the current scope.")
+            raise Exception("debugger.start() is not set in the current scope.")
+        if self.config.step and self.current_iter not in self.config.step:
+            return
+        if self.config.rank and self.current_rank not in self.config.rank:
+            return
+        self.switch = False
+        self.start_call = False
+        self.data_collector.write_json()
+
+    def create_dirs(self):
+        check_path_before_create(self.config.dump_path)
+        if not os.path.exists(self.config.dump_path):
+            Path(self.config.dump_path).mkdir(mode=0o750, exist_ok=True)
+        file_check = FileChecker(self.config.dump_path, FileCheckConst.DIR)
+        file_check.common_check()
+        self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
+        cur_rank = self.current_rank if self.current_rank is not None else ''
+        dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
+        if not os.path.exists(dump_dir):
+            Path(dump_dir).mkdir(mode=0o750, parents=True, exist_ok=True)
+        if self.config.task in self.data_collector.tasks_need_tensor_data:
+            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
+            Path(dump_data_dir).mkdir(mode=0o750, exist_ok=True)
+        else:
+            dump_data_dir = None
+
+        dump_file_path = os.path.join(dump_dir, "dump.json")
+        stack_file_path = os.path.join(dump_dir, "stack.json")
+        construct_file_path = os.path.join(dump_dir, "construct.json")
+        self.data_collector.update_dump_paths(
+            dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None)
+
+    def empty(self, *args, **kwargs):
+        pass
+
+    def register_hook_new(self):
+        logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
+        if self.config.level == "L1":
+            api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
+            api_register.api_set_hook_func()
+            if self.model:
+                self.register_hooks()
+
+        if self.config.level == "L0":
+            if not self.model:
+                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, "The current level is L0, the model cannot be None")
+            for name, cell in self.model.cells_and_names():
+                if cell == self.model:
+                    continue
+                prefix = 'Cell' + Const.SEP + name + Const.SEP + \
+                    cell.__class__.__name__ + Const.SEP
+                forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix)
+                cell.register_forward_hook(forward_hook)
+                cell.register_backward_hook(backward_hook)
+
+                cell.register_forward_pre_hook(
+                    self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
+                cell.register_forward_hook(
+                    self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
+                cell.register_backward_pre_hook(
+                    self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
+                cell.register_backward_hook(
+                    self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
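The most notable part of the new `Service` is `register_hooks`, which instruments every primitive reachable from the model by swapping each instance's `__class__` for a dynamically created subclass whose `__call__` wraps the original call. A plain-Python sketch of that mechanism, with print statements standing in for the dump-hook logic:

```python
# Plain-Python sketch of the dynamic-subclassing trick used by Service.register_hooks:
# each primitive instance gets its __class__ replaced by a subclass whose __call__
# wraps the original bound call with dump hooks.
class Primitive:                          # stand-in for a mindspore Primitive
    def __call__(self, x):
        return x * 2


def wrap_primitive(origin_func, name):
    def wrapped_primitive_call(instance_self, *args, **kwargs):
        print(f"collect forward data for {name}")   # real code calls the data collector here
        out = origin_func(*args, **kwargs)
        print(f"hook outputs of {name} for backward capture")
        return out
    return wrapped_primitive_call


prim = Primitive()
NewPrimitive = type('NewPrimitive', (prim.__class__,),
                    {'__call__': wrap_primitive(prim.__call__, 'Mul')})
prim.__class__ = NewPrimitive             # only this instance is instrumented
print(prim(3))                            # hooks fire, then prints 6
```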

msprobe/mindspore/task_handler_factory.py
CHANGED

@@ -1,20 +1,23 @@
+from msprobe.core.common.const import Const
 from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
 from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory
 from msprobe.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory
+from msprobe.mindspore.free_benchmark.self_check_tool_factory import SelfCheckToolFactory
 
 
 class TaskHandlerFactory:
     tasks = {
-
-
-
+        Const.TENSOR: DumpToolFactory,
+        Const.STATISTICS: DumpToolFactory,
+        Const.OVERFLOW_CHECK: OverflowCheckToolFactory,
+        Const.FREE_BENCHMARK: SelfCheckToolFactory
     }
 
     @staticmethod
     def create(config: DebuggerConfig):
         task = TaskHandlerFactory.tasks.get(config.task)
         if not task:
-            raise Exception("
+            raise Exception("Valid task is needed.")
         handler = task.create(config)
         if not handler:
             raise Exception("Can not find task handler")