mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
msprobe/mindspore/doc/dump.md
CHANGED
|
@@ -12,7 +12,7 @@ msprobe工具主要通过在训练脚本内添加dump接口并启动训练的方
|
|
|
12
12
|
|
|
13
13
|
通过加载dump配置文件的方式来确定dump操作的详细配置。
|
|
14
14
|
|
|
15
|
-
可以在from msprobe.mindspore import PrecisionDebugger
|
|
15
|
+
PrecisionDebugger可以在from msprobe.mindspore import PrecisionDebugger之后的位置添加。详细使用可参考“**示例代码**”。
|
|
16
16
|
|
|
17
17
|
**原型**
|
|
18
18
|
|
|
@@ -24,7 +24,7 @@ PrecisionDebugger(config_path=None)
|
|
|
24
24
|
|
|
25
25
|
| 参数名 | 说明 | 是否必选 |
|
|
26
26
|
| ----------- | ------------------------------------------------------------ | -------- |
|
|
27
|
-
| config_path | 指定dump配置文件路径,String类型。参数示例:"./config.json"。未配置该路径时,默认使用[config.json](../../config)文件的默认配置。config.json文件可以配置更多参数,若需要进行更多场景的精度数据dump,建议配置[config.json](../../config/config.json)文件。 | 否 |
|
|
27
|
+
| config_path | 指定dump配置文件路径,String类型。参数示例:"./config.json"。未配置该路径时,默认使用[config.json](../../config)文件的默认配置。config.json文件可以配置更多参数,若需要进行更多场景的精度数据dump,建议配置[config.json](../../config/config.json)文件。config.json文件的配置可参考《[配置文件说明](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/config/README.md)》。 | 否 |
|
|
28
28
|
|
|
29
29
|
### start函数
|
|
30
30
|
|
|
@@ -32,16 +32,64 @@ PrecisionDebugger(config_path=None)
|
|
|
32
32
|
|
|
33
33
|
启动函数。
|
|
34
34
|
|
|
35
|
+
在模型初始化之后的位置添加。需要与stop函数一起添加在for循环内。
|
|
36
|
+
|
|
35
37
|
**原型**
|
|
36
38
|
|
|
37
39
|
```Python
|
|
38
|
-
debugger.start()
|
|
40
|
+
debugger.start(model = None)
|
|
39
41
|
```
|
|
40
42
|
|
|
41
|
-
该函数为类函数,可以使用debugger.start()也可以使用PrecisionDebugger.start()
|
|
43
|
+
该函数为类函数,可以使用debugger.start(model=None)也可以使用PrecisionDebugger.start(model=None)。
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
**参数说明**
|
|
47
|
+
|
|
48
|
+
| 参数名 | 说明 | 是否必选 |
|
|
49
|
+
| ----------- |---------------------------------------------------------------------------------------| -------- |
|
|
50
|
+
| model | 指具体的mindspore.nn.Cell,默认未配置,L1级别下传入model可以使能对primitive op的dump,否则无法dump primitive op。 | 否 |
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
### stop函数
|
|
54
|
+
|
|
55
|
+
**功能说明**
|
|
56
|
+
|
|
57
|
+
dump停止函数。
|
|
58
|
+
|
|
59
|
+
在**start**函数之后的任意位置添加。需要与start函数一起添加在for循环内。若需要dump反向数据,则需要添加在反向计算代码之后。
|
|
60
|
+
|
|
61
|
+
仅MindSpore动态图场景支持。
|
|
62
|
+
|
|
63
|
+
**原型**
|
|
64
|
+
|
|
65
|
+
```Python
|
|
66
|
+
debugger.stop()
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
该函数为类函数,可以使用debugger.stop()也可以使用PrecisionDebugger.stop()。
|
|
70
|
+
|
|
71
|
+
### step函数
|
|
72
|
+
|
|
73
|
+
**功能说明**
|
|
74
|
+
|
|
75
|
+
结束标识。
|
|
76
|
+
|
|
77
|
+
在最后一个**stop**函数后或一个step结束的位置添加。
|
|
78
|
+
|
|
79
|
+
仅MindSpore动态图场景支持。
|
|
80
|
+
|
|
81
|
+
**原型**
|
|
82
|
+
|
|
83
|
+
```Python
|
|
84
|
+
debugger.step()
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
该函数为类函数,可以使用debugger.step()也可以使用PrecisionDebugger.step()。
|
|
42
88
|
|
|
43
89
|
## 示例代码
|
|
44
90
|
|
|
91
|
+
### MindSpore静态图场景
|
|
92
|
+
|
|
45
93
|
```Python
|
|
46
94
|
from msprobe.mindspore import PrecisionDebugger
|
|
47
95
|
debugger = PrecisionDebugger(config_path="./config.json")
|
|
@@ -51,15 +99,119 @@ debugger.start()
|
|
|
51
99
|
...
|
|
52
100
|
```
|
|
53
101
|
|
|
102
|
+
### MindSpore动态图场景
|
|
103
|
+
|
|
104
|
+
当模型训练使用for循环时,在每个迭代的开始插入debugger.start(),在每个迭代的结束插入debugger.stop()与debugger.step():
|
|
105
|
+
|
|
106
|
+
```Python
|
|
107
|
+
import mindspore as ms
|
|
108
|
+
from msprobe.mindspore import PrecisionDebugger
|
|
109
|
+
|
|
110
|
+
# 请勿将PrecisionDebugger的初始化插入到循环代码中
|
|
111
|
+
debugger = PrecisionDebugger(config_path="./config.json")
|
|
112
|
+
|
|
113
|
+
# 模型、损失函数的定义以及初始化等操作
|
|
114
|
+
# ...
|
|
115
|
+
|
|
116
|
+
# 数据集迭代的地方往往是模型开始训练的地方
|
|
117
|
+
for data, label in data_loader:
|
|
118
|
+
debugger.start() # 开启数据dump
|
|
119
|
+
net = Model()
|
|
120
|
+
# 如下是模型每个step执行的逻辑
|
|
121
|
+
grad_net = ms.grad(net)(data)
|
|
122
|
+
# ...
|
|
123
|
+
debugger.stop() # 关闭数据dump
|
|
124
|
+
debugger.step() # 结束一个step的dump
|
|
125
|
+
```
|
|
126
|
+
|
|
127
|
+
当使用模型的train方法而非for循环时,可以通过在callbacks参数中传入MsprobeStep(debugger):
|
|
128
|
+
|
|
129
|
+
```Python
|
|
130
|
+
from msprobe.mindspore.common.utils import MsprobeStep
|
|
131
|
+
from msprobe.mindspore import PrecisionDebugger
|
|
132
|
+
|
|
133
|
+
# 初始化PrecisionDebugger
|
|
134
|
+
debugger = PrecisionDebugger(config_path="./config.json")
|
|
135
|
+
|
|
136
|
+
# 自动在每个step开始时调用start(),在每个step结束时调用stop()和step()。
|
|
137
|
+
# 这意味着您无需手动在循环内添加start、stop和step函数,框架会自动完成数据的dump操作。
|
|
138
|
+
trainer.train(1, dataset_train, callbacks=[loss_monitor, MsprobeStep(debugger)])
|
|
139
|
+
|
|
140
|
+
```
|
|
141
|
+
|
|
54
142
|
## dump结果文件介绍
|
|
55
143
|
|
|
144
|
+
### MindSpore静态图场景
|
|
145
|
+
|
|
56
146
|
训练结束后,工具将dump的数据保存在dump_path参数指定的目录下。
|
|
57
147
|
|
|
58
|
-
-
|
|
148
|
+
- jit_level为O0/O1时
|
|
59
149
|
|
|
60
150
|
dump结果目录请参见MindSpore官网中的《[同步Dump数据对象目录](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.0rc2/debug/dump.html#%E5%90%8C%E6%AD%A5dump%E6%95%B0%E6%8D%AE%E5%AF%B9%E8%B1%A1%E7%9B%AE%E5%BD%95)》。
|
|
61
151
|
|
|
62
|
-
-
|
|
152
|
+
- jit_level为O2时
|
|
63
153
|
|
|
64
154
|
dump结果目录请参见MindSpore官网中的《[异步Dump数据对象目录](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.0rc2/debug/dump.html#%E5%BC%82%E6%AD%A5dump%E6%95%B0%E6%8D%AE%E5%AF%B9%E8%B1%A1%E7%9B%AE%E5%BD%95)》。
|
|
65
155
|
|
|
156
|
+
jit_level请参见[mindspore.JitConfig](https://www.mindspore.cn/docs/zh-CN/r2.3.0/api_python/mindspore/mindspore.JitConfig.html#mindspore-jitconfig)配置jit_config。
|
|
157
|
+
|
|
158
|
+
### MindSpore动态图场景
|
|
159
|
+
|
|
160
|
+
训练结束后,工具将dump的数据保存在dump_path参数指定的目录下。
|
|
161
|
+
|
|
162
|
+
dump结果目录结构示例如下:
|
|
163
|
+
|
|
164
|
+
```bash
|
|
165
|
+
├── dump_path
|
|
166
|
+
│ ├── step0
|
|
167
|
+
│ | ├── rank0
|
|
168
|
+
│ | │ ├── dump_tensor_data
|
|
169
|
+
| | | | ├── MintFunctional.relu.0.backward.input.0.npy
|
|
170
|
+
| | | | ├── Mint.abs.0.forward.input.0.npy
|
|
171
|
+
| | | | ├── Functional.split.0.forward.input.0.npy
|
|
172
|
+
| | | | ├── Tensor.__add__.0.forward.output.0.npy
|
|
173
|
+
| | | | ...
|
|
174
|
+
| | | | └── Jit.AlexNet.0.forward.input.0.npy
|
|
175
|
+
│ | | ├── dump.json # 保存前反向算子、算子的统计量信息或溢出算子信息。包含dump数据的API名称(命名格式为:`{api_type}.{api_name}.{API调用次数}.{forward/backward}.{input/output}.{参数序号}`)、dtype、shape、各数据的max、min、mean、L2norm统计信息以及当配置summary_mode="md5"时的md5数据。其中,“参数序号”表示该API下的第n个参数,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该API的第1个参数的第1个子参数;L2norm表示L2范数(平方根)
|
|
176
|
+
│ | | ├── stack.json # 算子调用栈信息
|
|
177
|
+
│ | | └── construct.json # 分层分级结构,level为L1时,construct.json内容为空
|
|
178
|
+
│ | ├── rank1
|
|
179
|
+
| | | ├── dump_tensor_data
|
|
180
|
+
| | | | └── ...
|
|
181
|
+
│ | | ├── dump.json
|
|
182
|
+
│ | | ├── stack.json
|
|
183
|
+
| | | └── construct.json
|
|
184
|
+
│ | ├── ...
|
|
185
|
+
│ | |
|
|
186
|
+
| | └── rank7
|
|
187
|
+
│ ├── step1
|
|
188
|
+
│ | ├── ...
|
|
189
|
+
│ ├── step2
|
|
190
|
+
```
|
|
191
|
+
|
|
192
|
+
dump过程中,npy文件在对应算子或者模块执行后即会落盘,而json文件需要在正常执行PrecisionDebugger.stop()后才会写入完整数据。若程序异常终止,终止前已执行算子的npy文件会被保存,但json文件中的数据可能丢失。
|
|
193
|
+
|
|
194
|
+
其中rank为设备上各卡的ID,每张卡上dump的数据会生成对应dump目录。非分布式场景下没有rank ID,目录名称为rank。
|
|
195
|
+
|
|
196
|
+
动态图场景下使能PSJit或PIJit,装饰特定Cell或function,被装饰的部分会全部/部分使能静态图流程。PSJit场景下config.json文件配置level为L1时,被PSJit装饰的部分也作为API被dump到对应目录;若配置level为L2时,则只会dump用户网络中静态图流程下的相关kernel。PIJit场景开启dump工具后,会被还原为动态图,按API粒度进行dump。
|
|
197
|
+
|
|
198
|
+
npy文件保存的前缀和MindSpore对应关系如下:
|
|
199
|
+
|
|
200
|
+
| 前缀 | MindSpore模块 |
|
|
201
|
+
| -------------- | ---------------------------- |
|
|
202
|
+
| Tensor | mindspore.Tensor |
|
|
203
|
+
| Functional | mindspore.ops |
|
|
204
|
+
| Mint | mindspore.mint |
|
|
205
|
+
| MintFunctional | mindspore.mint.nn.functional |
|
|
206
|
+
| Jit | mindspore.jit |
|
|
207
|
+
|
|
208
|
+
## 工具支持的API列表
|
|
209
|
+
|
|
210
|
+
msprobe工具维护固定的API支持列表,若需要删除或增加dump的API,可以在msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml文件内手动修改,如下示例:
|
|
211
|
+
|
|
212
|
+
```yaml
|
|
213
|
+
ops: # ops为算子类别,找到对应的类别,在该类别下按照下列格式删除或添加API
|
|
214
|
+
- adaptive_avg_pool1d
|
|
215
|
+
- adaptive_avg_pool2d
|
|
216
|
+
- adaptive_avg_pool3d
|
|
217
|
+
```
|
|
@@ -1,24 +1,25 @@
|
|
|
1
|
+
from msprobe.mindspore.common.const import Const
|
|
1
2
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
2
|
-
from msprobe.mindspore.dump.
|
|
3
|
+
from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump
|
|
3
4
|
from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
|
|
4
5
|
|
|
5
6
|
|
|
6
7
|
class DumpToolFactory:
|
|
7
8
|
tools = {
    # dump level -> {execution mode -> dump tool class}; a None entry means no tool is
    # registered for that level/mode combination.
    Const.CELL: dict.fromkeys((Const.GRAPH_KBYK_MODE, Const.GRAPH_GE_MODE, Const.PYNATIVE_MODE)),
    Const.API: dict.fromkeys((Const.GRAPH_KBYK_MODE, Const.GRAPH_GE_MODE, Const.PYNATIVE_MODE)),
    Const.KERNEL: {
        Const.GRAPH_KBYK_MODE: KernelKbykDump,
        Const.GRAPH_GE_MODE: KernelGraphDump,
        Const.PYNATIVE_MODE: KernelKbykDump,
    },
}
|
|
24
25
|
|
|
@@ -26,13 +27,9 @@ class DumpToolFactory:
|
|
|
26
27
|
def create(config: DebuggerConfig):
    """Instantiate the dump tool registered for ``config.level`` and ``config.execution_mode``.

    Args:
        config: debugger configuration; its ``level`` selects a row of
            ``DumpToolFactory.tools`` and its ``execution_mode`` selects the tool in that row.

    Returns:
        A new instance of the selected dump tool, constructed with ``config``.

    Raises:
        ValueError: if ``config.level`` has no (non-empty) entry in ``DumpToolFactory.tools``,
            or no tool is registered for ``config.execution_mode`` at that level.
            ValueError subclasses Exception, so callers that caught the bare Exception
            raised previously keep working.
    """
    tools_for_level = DumpToolFactory.tools.get(config.level)
    if not tools_for_level:
        raise ValueError("Valid level is needed.")
    tool = tools_for_level.get(config.execution_mode)
    if not tool:
        # A None entry in the table marks a level/mode combination without a dump tool.
        raise ValueError(f"Data dump is not supported in {config.execution_mode} mode "
                         f"when dump level is {config.level}.")
    return tool(config)
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
import mindspore as ms
|
|
17
|
+
from msprobe.mindspore.dump.hook_cell.wrap_functional import get_functional_ops, setup_hooks, \
|
|
18
|
+
HOOKFunctionalOP, HOOKMintOP, HOOKMintNNFunctionalOP
|
|
19
|
+
from msprobe.mindspore.dump.hook_cell.wrap_tensor import get_tensor_ops, wrap_tensor_ops_and_bind, HOOKTensor
|
|
20
|
+
from msprobe.core.common.utils import Const
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ApiRegistry:
    """Switch MindSpore APIs between their original implementations and hooked wrappers.

    Each API group (``ms.Tensor``, ``ms.ops``, ``ms.mint``, ``ms.mint.nn.functional``) has a
    ``*_ori_attr`` dict holding the original callables and a ``*_hook_attr`` dict holding the
    hooked replacements, both keyed by API name. A name containing ``Const.SEP`` (e.g.
    ``"sub.op"``) refers to attribute ``op`` of the group's ``sub`` attribute.
    """

    def __init__(self):
        # Original callables, filled in by initialize_hook().
        self.tensor_ori_attr = {}
        self.functional_ori_attr = {}
        self.mint_ops_ori_attr = {}
        self.mint_func_ops_ori_attr = {}
        self.norm_inner_ops_ori_attr = {}

        # Hooked replacements, filled in by initialize_hook().
        self.tensor_hook_attr = {}
        self.functional_hook_attr = {}
        self.mint_ops_hook_attr = {}
        self.mint_func_ops_hook_attr = {}
        self.norm_inner_ops_hook_attr = {}

        # ms.ops functions that norm_inner_op_set_hook_func()/norm_inner_op_set_ori_func()
        # switch independently of the other functional ops.
        self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]

    @staticmethod
    def store_ori_attr(ori_api_group, api_list, api_ori_attr):
        """Save the current callable of every API in ``api_list`` into ``api_ori_attr``.

        Raises:
            AttributeError: if an API, or the sub-module of a dotted API name, does not
                exist on ``ori_api_group``.
        """
        for api in api_list:
            if Const.SEP in api:
                sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
                sub_module = getattr(ori_api_group, sub_module_name)
                api_ori_attr[api] = getattr(sub_module, sub_op)
            else:
                api_ori_attr[api] = getattr(ori_api_group, api)

    @staticmethod
    def set_api_attr(api_group, attr_dict):
        """Install every ``{api_name: callable}`` pair of ``attr_dict`` onto ``api_group``.

        Dotted names whose sub-module is missing from ``api_group`` are skipped silently.
        """
        for api, api_attr in attr_dict.items():
            if Const.SEP in api:
                sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
                sub_module = getattr(api_group, sub_module_name, None)
                if sub_module is not None:
                    setattr(sub_module, sub_op, api_attr)
            else:
                setattr(api_group, api, api_attr)

    @staticmethod
    def _collect_hook_attrs(hook_cls):
        """Return ``{api_name: attr}`` for each ``Const.ATTR_NAME_PREFIX``-prefixed attribute of ``hook_cls``."""
        prefix_len = Const.ATTR_NAME_PREFIX_LEN
        return {
            attr_name[prefix_len:]: getattr(hook_cls, attr_name)
            for attr_name in dir(hook_cls)
            if attr_name.startswith(Const.ATTR_NAME_PREFIX)
        }

    def norm_inner_op_set_hook_func(self):
        """Install the hooked versions of the norm-inner ms.ops functions."""
        self.set_api_attr(ms.ops, self.norm_inner_ops_hook_attr)

    def norm_inner_op_set_ori_func(self):
        """Restore the original norm-inner ms.ops functions."""
        self.set_api_attr(ms.ops, self.norm_inner_ops_ori_attr)

    def api_set_hook_func(self):
        """Install the hooked wrappers on all four API groups."""
        self.set_api_attr(ms.Tensor, self.tensor_hook_attr)
        self.set_api_attr(ms.ops, self.functional_hook_attr)
        self.set_api_attr(ms.mint, self.mint_ops_hook_attr)
        self.set_api_attr(ms.mint.nn.functional, self.mint_func_ops_hook_attr)

    def api_set_ori_func(self):
        """Restore the original callables on all four API groups."""
        self.set_api_attr(ms.Tensor, self.tensor_ori_attr)
        self.set_api_attr(ms.ops, self.functional_ori_attr)
        self.set_api_attr(ms.mint, self.mint_ops_ori_attr)
        self.set_api_attr(ms.mint.nn.functional, self.mint_func_ops_ori_attr)

    def initialize_hook(self, hook):
        """Capture the original APIs, build hooked wrappers with ``hook`` and index both by name.

        This method does not itself call set_api_attr(); api_set_hook_func() and
        norm_inner_op_set_hook_func() install the collected wrappers.
        """
        self.store_ori_attr(ms.Tensor, get_tensor_ops(), self.tensor_ori_attr)
        wrap_tensor_ops_and_bind(hook)
        self.tensor_hook_attr.update(self._collect_hook_attrs(HOOKTensor))

        functional_ops, mint_ops, mint_func_ops = get_functional_ops()
        self.store_ori_attr(ms.ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
        self.store_ori_attr(ms.ops, functional_ops, self.functional_ori_attr)
        self.store_ori_attr(ms.mint, mint_ops, self.mint_ops_ori_attr)
        self.store_ori_attr(ms.mint.nn.functional, mint_func_ops, self.mint_func_ops_ori_attr)
        setup_hooks(hook)

        functional_hooks = self._collect_hook_attrs(HOOKFunctionalOP)
        self.functional_hook_attr.update(functional_hooks)
        # The norm-inner ops reuse the functional wrappers so they can be toggled separately.
        self.norm_inner_ops_hook_attr.update(
            {name: func for name, func in functional_hooks.items() if name in self.norm_inner_ops})
        self.mint_ops_hook_attr.update(self._collect_hook_attrs(HOOKMintOP))
        self.mint_func_ops_hook_attr.update(self._collect_hook_attrs(HOOKMintNNFunctionalOP))


api_register = ApiRegistry()
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
from collections import defaultdict
|
|
16
|
+
|
|
17
|
+
from mindspore import nn
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class HOOKCell(nn.Cell):
    """Cell wrapper that registers dump hooks for one API invocation.

    ``g_stop_hook`` is a class-wide flag. The first HOOKCell constructed while it is False
    sets it to True and registers hooks. Any HOOKCell constructed while it is True (e.g. —
    presumably — for APIs invoked from inside the running cell) skips hook registration.
    The owning cell clears the flag when its ``__call__`` finishes.
    """

    # Number of hooked cells created per name prefix; used to number dump names.
    cell_count = defaultdict(int)
    # True while a HOOKCell owns the hooks; HOOKCells built meanwhile are not hooked.
    g_stop_hook = False

    def __init__(self, build_hook) -> None:
        """Register forward/backward hooks unless another HOOKCell currently owns them.

        Args:
            build_hook: callable taking the dump name prefix and returning a
                ``(forward_hook, backward_hook)`` pair.

        NOTE(review): the flag is only released by ``__call__`` — a HOOKCell that is
        constructed but never called keeps hooking disabled for all later cells.
        """
        super(HOOKCell, self).__init__()
        self.changed_status = False
        self.input_kwargs = {}
        self.prefix = ""
        if not HOOKCell.g_stop_hook:
            HOOKCell.g_stop_hook = True
            self.changed_status = True
            try:
                if hasattr(self, "prefix_op_name_"):
                    self.prefix = self.prefix_op_name_

                HOOKCell.cell_count[self.prefix] += 1
                # e.g. prefix "Functional.relu." on its first use becomes "Functional.relu.0."
                self.prefix = self.prefix + str(HOOKCell.cell_count[self.prefix] - 1) + Const.SEP
                forward_hook, backward_hook = build_hook(self.prefix)
                self.register_forward_hook(forward_hook)
                self.register_backward_hook(backward_hook)
            except Exception:
                # __call__ (which normally clears the flag) never runs for a cell whose
                # construction failed; release the flag here so later APIs are still hooked.
                self.changed_status = False
                HOOKCell.g_stop_hook = False
                raise

    def __call__(self, *args, **kwargs):
        """Run the cell, then release the global hook flag if this cell owns it."""
        try:
            self.input_kwargs = kwargs
            return super(HOOKCell, self).__call__(*args, **kwargs)
        finally:
            if self.changed_status:
                self.changed_status = False
                HOOKCell.g_stop_hook = False
|