mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +131 -237
- msprobe/__init__.py +16 -1
- msprobe/{config/config.json → config.json} +47 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +58 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +402 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +523 -283
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +86 -69
- msprobe/core/common/utils.py +371 -616
- msprobe/core/common_config.py +78 -71
- msprobe/core/compare/acc_compare.py +472 -298
- msprobe/core/compare/check.py +180 -95
- msprobe/core/compare/compare_cli.py +69 -49
- msprobe/core/compare/highlight.py +259 -222
- msprobe/core/compare/multiprocessing_compute.py +174 -149
- msprobe/core/compare/npy_compare.py +310 -295
- msprobe/core/compare/utils.py +464 -429
- msprobe/core/data_dump/data_collector.py +153 -144
- msprobe/core/data_dump/data_processor/base.py +337 -293
- msprobe/core/data_dump/data_processor/factory.py +76 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
- msprobe/core/data_dump/json_writer.py +117 -116
- msprobe/core/data_dump/scope.py +194 -178
- msprobe/core/grad_probe/constant.py +74 -70
- msprobe/core/grad_probe/grad_compare.py +170 -175
- msprobe/core/grad_probe/utils.py +77 -52
- msprobe/docs/01.installation.md +99 -0
- msprobe/docs/02.config_introduction.md +137 -0
- msprobe/docs/03.config_examples.md +237 -0
- msprobe/docs/04.acl_config_examples.md +78 -0
- msprobe/docs/05.data_dump_PyTorch.md +326 -0
- msprobe/docs/06.data_dump_MindSpore.md +285 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
- msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
- msprobe/docs/FAQ.md +189 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +2 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +58 -34
- msprobe/mindspore/common/const.py +108 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +97 -57
- msprobe/mindspore/compare/distributed_compare.py +62 -75
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +357 -117
- msprobe/mindspore/compare/ms_graph_compare.py +364 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +69 -74
- msprobe/mindspore/debugger/precision_debugger.py +150 -107
- msprobe/mindspore/dump/dump_tool_factory.py +50 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
- msprobe/mindspore/dump/jit_dump.py +96 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
- msprobe/mindspore/free_benchmark/common/config.py +27 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
- msprobe/mindspore/free_benchmark/common/utils.py +85 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
- msprobe/mindspore/grad_probe/global_context.py +100 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +297 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +23 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +30 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
- msprobe/pytorch/bench_functions/linear.py +27 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
- msprobe/pytorch/bench_functions/rms_norm.py +30 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
- msprobe/pytorch/bench_functions/swiglu.py +70 -55
- msprobe/pytorch/common/__init__.py +17 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +33 -32
- msprobe/pytorch/common/parse_json.py +54 -39
- msprobe/pytorch/common/utils.py +310 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +49 -33
- msprobe/pytorch/compare/pt_compare.py +82 -40
- msprobe/pytorch/debugger/debugger_config.py +108 -95
- msprobe/pytorch/debugger/precision_debugger.py +173 -125
- msprobe/pytorch/free_benchmark/__init__.py +23 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +65 -37
- msprobe/pytorch/free_benchmark/common/params.py +144 -129
- msprobe/pytorch/free_benchmark/common/utils.py +118 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
- msprobe/pytorch/free_benchmark/main.py +120 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
- msprobe/pytorch/function_factory.py +91 -75
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +166 -161
- msprobe/pytorch/hook_module/hook_module.py +118 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +28 -29
- msprobe/pytorch/hook_module/wrap_aten.py +111 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
- msprobe/pytorch/hook_module/wrap_functional.py +104 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
- msprobe/pytorch/hook_module/wrap_torch.py +84 -86
- msprobe/pytorch/hook_module/wrap_vf.py +60 -62
- msprobe/pytorch/module_processer.py +153 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +235 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
- msprobe/pytorch/online_dispatch/utils.py +127 -146
- msprobe/pytorch/parse.py +19 -4
- msprobe/pytorch/parse_tool/cli.py +31 -32
- msprobe/pytorch/parse_tool/lib/compare.py +259 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
- msprobe/pytorch/parse_tool/lib/utils.py +320 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +317 -187
- msprobe/pytorch/service.py +311 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -1,75 +1,91 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.pytorch.bench_functions.apply_adam_w import npu_apply_adam_w
|
|
17
|
+
from msprobe.pytorch.bench_functions.confusion_transpose import npu_confusion_transpose, \
|
|
18
|
+
npu_confusion_transpose_backward
|
|
19
|
+
from msprobe.pytorch.bench_functions.fast_gelu import npu_fast_gelu, npu_fast_gelu_backward
|
|
20
|
+
from msprobe.pytorch.bench_functions.layer_norm_eval import npu_layer_norm_eval
|
|
21
|
+
from msprobe.pytorch.bench_functions.linear import npu_linear, npu_linear_backward
|
|
22
|
+
from msprobe.pytorch.bench_functions.matmul_backward import matmul_backward
|
|
23
|
+
from msprobe.pytorch.bench_functions.npu_fusion_attention import npu_fusion_attention, npu_fusion_attention_grad, \
|
|
24
|
+
gpu_fusion_attention
|
|
25
|
+
from msprobe.pytorch.bench_functions.rms_norm import npu_rms_norm, npu_rms_norm_backward
|
|
26
|
+
from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotary_mul_backward
|
|
27
|
+
from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \
|
|
28
|
+
npu_scaled_masked_softmax_backward
|
|
29
|
+
from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward
|
|
30
|
+
from msprobe.pytorch.common.utils import logger
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Register(dict):
    """Registry mapping a callable's ``__name__`` to the callable itself.

    Calling an instance with an iterable of callables registers each one in
    order. Registering a name that is already present logs a warning and
    replaces the previous entry.

    Entries live in the underlying ``dict`` storage, so every standard mapping
    operation (``[]``, ``in``, ``len``, iteration, ``get``, ``update``,
    ``dict(...)``, truthiness) sees the same contents. An earlier version kept
    entries in a private shadow dict while still inheriting from ``dict``,
    which left ``len()``, iteration and ``get()`` working on an always-empty
    base mapping.
    """

    @property
    def _dict(self):
        # Backward-compatible alias for code that reached into the old shadow dict.
        return self

    def __call__(self, target_func_list):
        """Register every callable in ``target_func_list``, preserving order."""
        for target in target_func_list:
            self.register(target)

    def register(self, target):
        """Register ``target`` under ``target.__name__`` and return it unchanged.

        Raises:
            TypeError: if ``target`` is not callable. TypeError subclasses
                Exception, so existing ``except Exception`` handlers still apply.
        """
        if not callable(target):
            raise TypeError(f"The func {target} is not callable.")
        key = target.__name__
        if key in self:
            logger.warning(f"{target.__name__} has been registered before, so we will overriden it.")
        self[key] = target
        return target
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Forward-pass bench implementations of NPU custom operators. Each one is
# registered under its own function name; tuple order is registration order.
_NPU_CUSTOM_BENCH_FUNCS = (
    npu_apply_adam_w,
    npu_confusion_transpose,
    npu_fast_gelu,
    npu_layer_norm_eval,
    npu_linear,
    npu_fusion_attention,
    npu_rms_norm,
    npu_rotary_mul,
    npu_scaled_masked_softmax,
    npu_swiglu,
    gpu_fusion_attention,
)

# Backward (gradient) bench implementations of NPU custom operators.
_NPU_CUSTOM_GRAD_BENCH_FUNCS = (
    npu_confusion_transpose_backward,
    npu_fast_gelu_backward,
    npu_linear_backward,
    matmul_backward,
    npu_fusion_attention_grad,
    npu_rms_norm_backward,
    npu_rotary_mul_backward,
    npu_scaled_masked_softmax_backward,
    npu_swiglu_backward,
)

npu_custom_functions = Register()
npu_custom_functions(_NPU_CUSTOM_BENCH_FUNCS)

npu_custom_grad_functions = Register()
npu_custom_grad_functions(_NPU_CUSTOM_GRAD_BENCH_FUNCS)
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
import torch.nn as nn
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
20
|
+
from msprobe.core.data_dump.scope import BaseScope
|
|
21
|
+
from msprobe.pytorch.common.log import logger
|
|
22
|
+
from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger
|
|
23
|
+
from msprobe.pytorch.hook_module.api_registry import api_register
|
|
24
|
+
from msprobe.pytorch.service import torch_version_above_or_equal_2
|
|
25
|
+
|
|
26
|
+
# Handles returned by the hooks that register_hook() installs; iterated by
# remove_hook() and emptied by module_dump_end().
hook_handle_list = list()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def module_dump(module, dump_name):
    """Begin dumping data for one ``nn.Module`` under the name ``dump_name``.

    After validating the arguments, this calls ``api_register.api_originality()``
    (presumably switching the wrapped framework APIs back to their originals —
    confirm in api_registry) and then attaches module hooks via ``register_hook``.
    Pair every call with ``module_dump_end()``.

    Raises:
        MsprobeException: ``INVALID_PARAM_ERROR`` when ``module`` is not an
            ``nn.Module`` or ``dump_name`` is not a ``str``.
    """
    argument_checks = (
        (isinstance(module, nn.Module), "The parameter module in module_dump must be a Module subclass."),
        (isinstance(dump_name, str), "The parameter dump_name in module_dump must be a str type."),
    )
    for is_valid, error_msg in argument_checks:
        if not is_valid:
            logger.error(error_msg)
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

    api_register.api_originality()
    register_hook(module, dump_name)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def module_dump_end():
    """Close a ``module_dump`` session.

    Calls ``api_register.api_modularity()`` (presumably re-enabling the API
    wrappers that ``module_dump`` switched off — verify). It then detaches every
    recorded hook through ``remove_hook()`` and empties ``hook_handle_list`` in
    place.
    """
    api_register.api_modularity()
    remove_hook()
    del hook_handle_list[:]
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def register_hook(module, dump_name):
    """Attach msprobe's module-level dump hooks to ``module``.

    The hooks come from the global ``PrecisionDebugger`` service and are named
    with the prefix ``<Module_Type_Module><SEP><dump_name><SEP><class name><SEP>``.
    Every handle obtained here is recorded in ``hook_handle_list`` so that
    ``remove_hook()`` can detach it later. Registration order matters: PyTorch
    runs hooks of the same kind in the order they were registered.

    Args:
        module: the ``nn.Module`` to hook (type-checked by ``module_dump``).
        dump_name: user-supplied name embedded in the dump prefix.
    """
    prefix = BaseScope.Module_Type_Module + Const.SEP + dump_name + Const.SEP + module.__class__.__name__ + Const.SEP

    pdg = PrecisionDebugger()
    # build_hook returns four hooks; the first one is not used for module dumping.
    _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = \
        pdg.service.build_hook(BaseScope.Module_Type_Module, prefix)

    if torch_version_above_or_equal_2:
        # torch >= 2: the forward hook also receives the module's keyword arguments.
        forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
        hook_handle_list.append(forward_hook_handle)
    else:
        # torch < 2 has no full_backward_pre_hook, so the BACKWARD STOP node hook is
        # registered here instead of in the torch >= 2 branch at the end.
        pdg.service.check_register_full_backward_hook(module)
        full_backward_hook_handle = module.register_full_backward_hook(
            pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
        forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
        hook_handle_list.extend([full_backward_hook_handle, forward_hook_handle])
    # The data-collecting backward hook is registered for every torch version.
    # NOTE(review): nesting reconstructed from a flattened listing — this call is
    # assumed to sit outside the version branch; confirm against the upstream file.
    pdg.service.check_register_full_backward_hook(module)
    full_backward_hook_handle = module.register_full_backward_hook(backward_hook)

    # Node hooks tagged START (before forward) and STOP (after forward); presumably
    # used by module_processor to track the module scope — confirm.
    forward_pre_hook_handle = module.register_forward_pre_hook(
        pdg.service.module_processor.node_hook(prefix + Const.FORWARD, Const.START))
    forward_hook_handle = module.register_forward_hook(
        pdg.service.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
    hook_handle_list.extend([full_backward_hook_handle, forward_pre_hook_handle, forward_hook_handle])

    if torch_version_above_or_equal_2:
        # Matching START/STOP node hooks around the backward pass (needs full_backward_pre_hook).
        backward_pre_hook_handle = module.register_full_backward_pre_hook(
            pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
        pdg.service.check_register_full_backward_hook(module)
        full_backward_hook_handle = module.register_full_backward_hook(
            pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
        hook_handle_list.extend([backward_pre_hook_handle, full_backward_hook_handle])
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def remove_hook():
    """Detach every ``RemovableHandle`` stored in ``hook_handle_list``.

    Entries of any other type are skipped. The list itself is not modified;
    ``module_dump_end()`` clears it after calling this function.
    """
    removable_handles = (
        handle for handle in hook_handle_list
        if isinstance(handle, torch.utils.hooks.RemovableHandle)
    )
    for handle in removable_handles:
        handle.remove()
|
|
@@ -1,90 +1,91 @@
|
|
|
1
|
-
import os
|
|
2
|
-
from collections import defaultdict
|
|
3
|
-
|
|
4
|
-
import torch
|
|
5
|
-
if int(torch.__version__.split('.')[0]) >= 2:
|
|
6
|
-
from torch.optim.optimizer import register_optimizer_step_pre_hook
|
|
7
|
-
from msprobe.pytorch.grad_probe.grad_stat_csv import GradStatCsv
|
|
8
|
-
from msprobe.core.grad_probe.utils import check_numeral_list_ascend, data_in_list_target
|
|
9
|
-
from msprobe.core.grad_probe.constant import
|
|
10
|
-
from msprobe.
|
|
11
|
-
from msprobe.core.common.
|
|
12
|
-
from msprobe.
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
level
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
self.
|
|
23
|
-
self.
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
self._bounds
|
|
29
|
-
|
|
30
|
-
self._output_path
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
self.
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
self._step
|
|
64
|
-
if not data_in_list_target(self._step, self._target_step):
|
|
65
|
-
return
|
|
66
|
-
output_lines = []
|
|
67
|
-
for param, param_name in self._param2name.items():
|
|
68
|
-
if not data_in_list_target(param_name, self._param_list):
|
|
69
|
-
continue
|
|
70
|
-
grad = param.main_grad if hasattr(param, "main_grad") else param.grad
|
|
71
|
-
if grad is None:
|
|
72
|
-
logger.info(f"grad is None: {param_name}")
|
|
73
|
-
continue
|
|
74
|
-
grad_info = GradStatCsv.generate_csv_line(param_name, self._level_adp, grad, self._bounds)
|
|
75
|
-
output_lines.append(grad_info)
|
|
76
|
-
if self._level_adp["have_grad_direction"]:
|
|
77
|
-
GradientMonitor.save_grad_direction(param_name, grad,
|
|
78
|
-
f'{self._output_path}/rank{self._rank}/step{self._step}')
|
|
79
|
-
output_dirpath = os.path.join(self._output_path, f"rank{getattr(self, '_rank')}")
|
|
80
|
-
if not os.path.isdir(output_dirpath):
|
|
81
|
-
create_directory(output_dirpath)
|
|
82
|
-
output_path = os.path.join(output_dirpath, f"grad_summary_{self._step}.csv")
|
|
83
|
-
if os.path.exists(output_path):
|
|
84
|
-
logger.warning(f"{output_path} will be recoverd")
|
|
85
|
-
remove_path(output_path)
|
|
86
|
-
header_result = GradStatCsv.generate_csv_header(self._level_adp, self._bounds)
|
|
87
|
-
output_lines.insert(0, header_result)
|
|
88
|
-
write_csv(output_lines, output_path)
|
|
89
|
-
|
|
90
|
-
|
|
1
|
+
import os
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
if int(torch.__version__.split('.')[0]) >= 2:
|
|
6
|
+
from torch.optim.optimizer import register_optimizer_step_pre_hook
|
|
7
|
+
from msprobe.pytorch.grad_probe.grad_stat_csv import GradStatCsv
|
|
8
|
+
from msprobe.core.grad_probe.utils import check_numeral_list_ascend, data_in_list_target
|
|
9
|
+
from msprobe.core.grad_probe.constant import level_adp
|
|
10
|
+
from msprobe.pytorch.common.log import logger
|
|
11
|
+
from msprobe.core.common.file_utils import remove_path, save_npy, write_csv, create_directory
|
|
12
|
+
from msprobe.pytorch.common.utils import get_rank_id, print_rank_0
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class GradientMonitor:
    """Collect per-parameter gradient statistics at selected optimizer steps.

    Settings are read from ``common_config`` (``rank``, ``step``, ``dump_path``)
    and ``task_config`` (``grad_level``, ``param_list``, ``bounds``).

    After ``monitor(model)``, an optimizer step pre-hook is registered through
    ``torch.optim.optimizer.register_optimizer_step_pre_hook``. For each targeted
    step it writes ``<dump_path>/rank<rank>/grad_summary_<step>.csv``. When the
    selected level has ``have_grad_direction`` set, it also writes one ``.npy``
    boolean ``grad > 0`` mask per parameter under
    ``<dump_path>/rank<rank>/step<step>/``.
    """

    def __init__(self, common_config, task_config):
        """Validate the grad-probe configuration and prepare the output directory.

        Raises:
            Exception: if ``task_config.grad_level`` is not a key of ``level_adp``.
        """
        level = task_config.grad_level
        if level not in level_adp:
            raise Exception(f"level is invalid, not in {level_adp.keys()}")
        self._level_adp = level_adp[level]
        self._param_list = task_config.param_list
        self._target_ranks = common_config.rank
        logger.info(f"target rank {self._target_ranks}")
        self._target_step = common_config.step
        logger.info(f"target step {self._target_step}")
        self._bounds = task_config.bounds
        check_numeral_list_ascend(self._bounds)
        self._output_path = common_config.dump_path
        if not os.path.exists(self._output_path):
            create_directory(self._output_path)
        else:
            logger.warning(f"the file in {self._output_path} will be recoverd")
        # Incremented before each optimizer step, so the first observed step is 0.
        self._step = -1
        # Maps parameter tensor -> its name from model.named_parameters().
        self._param2name = defaultdict(str)
        # Filled in by monitor(); initialised here so the attribute always exists.
        self._rank = None

    @property
    def output_path(self):
        """Root directory under which per-rank results are written."""
        return self._output_path

    @staticmethod
    def save_grad_direction(param_name, grad, save_path):
        """Save a boolean ``grad > 0`` mask as ``<save_path>/<param_name>.npy``."""
        if not os.path.exists(save_path):
            create_directory(save_path)
        param_grad = grad.clone().detach()
        is_positive = param_grad > 0
        save_filepath = os.path.join(save_path, f"{param_name}.npy")
        save_npy(is_positive.cpu().numpy(), save_filepath)

    def monitor(self, model):
        """Record ``model``'s parameter names and install the optimizer hook.

        When torch.distributed is initialised and this rank is not one of the
        configured target ranks, no hook is installed.
        """
        print_rank_0("> parameter names:")
        for name, param in model.named_parameters():
            self._param2name[param] = name
            print_rank_0(f"\t{name}")
        self._rank = get_rank_id()
        if torch.distributed.is_initialized() and not data_in_list_target(self._rank, self._target_ranks):
            return
        self._hook_optimizer()

    def _hook_optimizer(self):
        """Register the step pre-hook that writes gradient statistics.

        ``register_optimizer_step_pre_hook`` only exists from torch 2.0 on. On
        older versions a warning is logged instead of silently doing nothing.
        """
        if int(torch.__version__.split('.')[0]) < 2:
            logger.warning("grad_probe: register_optimizer_step_pre_hook requires torch>=2.0, "
                           "gradient monitoring is disabled.")
            return

        def optimizer_pre_step_hook(optimizer, args, kargs):
            self._step += 1
            logger.info(f"grad_probe: optimizer step {self._step}")
            if not data_in_list_target(self._step, self._target_step):
                return
            output_dirpath = os.path.join(self._output_path, f"rank{self._rank}")
            direction_dirpath = os.path.join(output_dirpath, f"step{self._step}")
            output_lines = []
            for param, param_name in self._param2name.items():
                if not data_in_list_target(param_name, self._param_list):
                    continue
                # Prefer ``main_grad`` when present (presumably the fp32 gradient kept
                # by mixed-precision frameworks such as Megatron — confirm).
                grad = param.main_grad if hasattr(param, "main_grad") else param.grad
                if grad is None:
                    logger.info(f"grad is None: {param_name}")
                    continue
                grad_info = GradStatCsv.generate_csv_line(param_name, self._level_adp, grad, self._bounds)
                output_lines.append(grad_info)
                if self._level_adp["have_grad_direction"]:
                    GradientMonitor.save_grad_direction(param_name, grad, direction_dirpath)
            if not os.path.isdir(output_dirpath):
                create_directory(output_dirpath)
            output_path = os.path.join(output_dirpath, f"grad_summary_{self._step}.csv")
            if os.path.exists(output_path):
                logger.warning(f"{output_path} will be recoverd")
                remove_path(output_path)
            header_result = GradStatCsv.generate_csv_header(self._level_adp, self._bounds)
            output_lines.insert(0, header_result)
            write_csv(output_lines, output_path)
            logger.info(f"write grad data to {output_path}")

        register_optimizer_step_pre_hook(optimizer_pre_step_hook)
|