mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
- mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +101 -182
- msprobe/__init__.py +1 -0
- msprobe/{config/config.json → config.json} +49 -27
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +124 -124
- msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
- msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +341 -241
- msprobe/core/common/exceptions.py +100 -88
- msprobe/core/common/{file_check.py → file_utils.py} +478 -265
- msprobe/core/common/log.py +76 -55
- msprobe/core/common/utils.py +385 -516
- msprobe/core/common_config.py +85 -58
- msprobe/core/compare/acc_compare.py +300 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +223 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
- msprobe/core/compare/utils.py +430 -0
- msprobe/core/data_dump/data_collector.py +154 -140
- msprobe/core/data_dump/data_processor/base.py +314 -245
- msprobe/core/data_dump/data_processor/factory.py +59 -61
- msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
- msprobe/core/data_dump/json_writer.py +96 -116
- msprobe/core/data_dump/scope.py +178 -178
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +171 -0
- msprobe/core/grad_probe/utils.py +64 -0
- msprobe/docs/01.installation.md +89 -0
- msprobe/docs/02.config_introduction.md +165 -0
- msprobe/docs/03.config_examples.md +247 -0
- msprobe/docs/04.acl_config_examples.md +76 -0
- msprobe/docs/05.data_dump_PyTorch.md +198 -0
- msprobe/docs/06.data_dump_MindSpore.md +243 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
- msprobe/docs/17.grad_probe.md +207 -0
- msprobe/docs/FAQ_PyTorch.md +177 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/grad_probe_image-1.png +0 -0
- msprobe/docs/img/grad_probe_image-2.png +0 -0
- msprobe/docs/img/grad_probe_image-3.png +0 -0
- msprobe/docs/img/grad_probe_image-4.png +0 -0
- msprobe/docs/img/grad_probe_image.png +0 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
- msprobe/mindspore/api_accuracy_checker/main.py +9 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +106 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +81 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +219 -0
- msprobe/mindspore/compare/ms_graph_compare.py +348 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +66 -51
- msprobe/mindspore/debugger/precision_debugger.py +126 -32
- msprobe/mindspore/dump/dump_tool_factory.py +35 -38
- msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
- msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
- msprobe/mindspore/dump/jit_dump.py +72 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +90 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +94 -0
- msprobe/mindspore/grad_probe/utils.py +30 -0
- msprobe/mindspore/ms_config.py +128 -78
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +378 -0
- msprobe/mindspore/task_handler_factory.py +24 -21
- msprobe/msprobe.py +105 -67
- msprobe/pytorch/__init__.py +4 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
- msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/__init__.py +2 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +20 -31
- msprobe/pytorch/common/parse_json.py +39 -37
- msprobe/pytorch/common/utils.py +305 -224
- msprobe/pytorch/compare/distributed_compare.py +66 -111
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +34 -36
- msprobe/pytorch/compare/pt_compare.py +50 -0
- msprobe/pytorch/debugger/debugger_config.py +95 -86
- msprobe/pytorch/debugger/precision_debugger.py +125 -95
- msprobe/pytorch/free_benchmark/__init__.py +8 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -67
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +37 -37
- msprobe/pytorch/free_benchmark/common/params.py +129 -129
- msprobe/pytorch/free_benchmark/common/utils.py +102 -98
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
- msprobe/pytorch/free_benchmark/main.py +105 -102
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
- msprobe/pytorch/function_factory.py +76 -0
- msprobe/pytorch/functional/dump_module.py +39 -39
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -161
- msprobe/pytorch/hook_module/hook_module.py +120 -109
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
- msprobe/pytorch/hook_module/utils.py +30 -29
- msprobe/pytorch/hook_module/wrap_aten.py +110 -100
- msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
- msprobe/pytorch/hook_module/wrap_functional.py +105 -108
- msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
- msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
- msprobe/pytorch/hook_module/wrap_torch.py +86 -88
- msprobe/pytorch/hook_module/wrap_vf.py +62 -64
- msprobe/pytorch/module_processer.py +138 -98
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +236 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -273
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
- msprobe/pytorch/online_dispatch/utils.py +130 -187
- msprobe/pytorch/parse.py +4 -4
- msprobe/pytorch/parse_tool/cli.py +32 -32
- msprobe/pytorch/parse_tool/lib/compare.py +260 -259
- msprobe/pytorch/parse_tool/lib/config.py +52 -51
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
- msprobe/pytorch/parse_tool/lib/utils.py +316 -367
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
- msprobe/pytorch/pt_config.py +188 -93
- msprobe/pytorch/service.py +246 -167
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/config/README.md +0 -397
- msprobe/mindspore/doc/dump.md +0 -65
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
- msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/dump.md +0 -207
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/pytorch/compare/match.py
CHANGED
|
@@ -1,36 +1,34 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import
|
|
3
|
-
from msprobe.core.common.
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
return self.match_op(
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
if matching_op
|
|
30
|
-
return
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
graph_mapping = AtenIrMapping()
|
|
1
|
+
import os
|
|
2
|
+
from msprobe.core.common.utils import CompareException
|
|
3
|
+
from msprobe.core.common.file_utils import load_yaml
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class AtenIrMapping():
|
|
7
|
+
def __init__(self):
|
|
8
|
+
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
9
|
+
yaml_path = os.path.join(cur_path, "mapping.yaml")
|
|
10
|
+
self.aten_mapping = load_yaml(yaml_path)
|
|
11
|
+
|
|
12
|
+
def match(self, op1, op2):
|
|
13
|
+
if "Aten" in op1 and "Aten" not in op2:
|
|
14
|
+
return self.match_op(op1, op2)
|
|
15
|
+
else:
|
|
16
|
+
return self.match_op(op2, op1)
|
|
17
|
+
|
|
18
|
+
def match_op(self, aten_op, torch_op):
|
|
19
|
+
try:
|
|
20
|
+
aten_op_raw_name_overload = '_'.join(aten_op.split("_")[1:-3])
|
|
21
|
+
aten_op_raw_name = aten_op_raw_name_overload.split('.')[0]
|
|
22
|
+
torch_op_raw_name = '_'.join(torch_op.split("_")[1:-3]).lower()
|
|
23
|
+
except IndexError as e:
|
|
24
|
+
err_msg = f"Dump op name format error: {aten_op}, {torch_op}. Your dump data may be corrupted."
|
|
25
|
+
raise CompareException.INVALID_DATA_ERROR(err_msg) from e
|
|
26
|
+
matching_op = self.aten_mapping.get(aten_op_raw_name)
|
|
27
|
+
if matching_op is None:
|
|
28
|
+
return False
|
|
29
|
+
if matching_op.lower() == torch_op_raw_name:
|
|
30
|
+
return True
|
|
31
|
+
return False
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
graph_mapping = AtenIrMapping()
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import os.path
|
|
2
|
+
import torch
|
|
3
|
+
from msprobe.core.common.const import FileCheckConst
|
|
4
|
+
from msprobe.pytorch.common.log import logger
|
|
5
|
+
from msprobe.core.common.exceptions import FileCheckException
|
|
6
|
+
from msprobe.core.compare.acc_compare import Comparator
|
|
7
|
+
from msprobe.core.common.utils import check_configuration_param, task_dumppath_get, check_compare_param, CompareException
|
|
8
|
+
from msprobe.core.common.file_utils import FileChecker, create_directory
|
|
9
|
+
from msprobe.pytorch.common.utils import load_pt
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PTComparator (Comparator):
|
|
13
|
+
def __init__(self):
|
|
14
|
+
self.frame_name = PTComparator.__name__
|
|
15
|
+
|
|
16
|
+
def read_npy_data(self, dir_path, file_name):
|
|
17
|
+
data_path = os.path.join(dir_path, file_name)
|
|
18
|
+
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
19
|
+
FileCheckConst.PT_SUFFIX, False)
|
|
20
|
+
data_path = path_checker.common_check()
|
|
21
|
+
try:
|
|
22
|
+
data_value = load_pt(data_path,
|
|
23
|
+
to_cpu=True).detach() # detach because numpy can not process gradient information
|
|
24
|
+
except RuntimeError as e:
|
|
25
|
+
# 这里捕获 load_pt 中抛出的异常
|
|
26
|
+
logger.error(f"Failed to load the .pt file at {data_path}.")
|
|
27
|
+
raise CompareException(CompareException.INVALID_FILE_ERROR) from e
|
|
28
|
+
except AttributeError as e:
|
|
29
|
+
# 这里捕获 detach 方法抛出的异常
|
|
30
|
+
logger.error(f"Failed to detach the loaded tensor.")
|
|
31
|
+
raise CompareException(CompareException.DETACH_ERROR) from e
|
|
32
|
+
if data_value.dtype == torch.bfloat16:
|
|
33
|
+
data_value = data_value.to(torch.float32)
|
|
34
|
+
data_value = data_value.numpy()
|
|
35
|
+
return data_value
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False):
|
|
39
|
+
try:
|
|
40
|
+
summary_compare, md5_compare = task_dumppath_get(input_param)
|
|
41
|
+
check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
|
|
42
|
+
create_directory(output_path)
|
|
43
|
+
check_compare_param(input_param, output_path, summary_compare, md5_compare)
|
|
44
|
+
except (CompareException, FileCheckException) as error:
|
|
45
|
+
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
46
|
+
raise CompareException(error.code) from error
|
|
47
|
+
pt_comparator = PTComparator()
|
|
48
|
+
pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
|
|
49
|
+
auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
|
|
50
|
+
md5_compare=md5_compare)
|
|
@@ -1,86 +1,95 @@
|
|
|
1
|
-
from msprobe.pytorch.common import seed_all
|
|
2
|
-
from msprobe.pytorch.common.log import logger
|
|
3
|
-
from msprobe.core.common.const import Const
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class DebuggerConfig:
|
|
7
|
-
def __init__(self, common_config, task_config, task, dump_path, level):
|
|
8
|
-
self.dump_path = dump_path if dump_path else common_config.dump_path
|
|
9
|
-
self.task = task or common_config.task or Const.STATISTICS
|
|
10
|
-
self.rank = common_config.rank if common_config.rank else []
|
|
11
|
-
self.step = common_config.step if common_config.step else []
|
|
12
|
-
self.level = level or common_config.level or "L1"
|
|
13
|
-
self.seed = common_config.seed if common_config.seed else 1234
|
|
14
|
-
self.is_deterministic = common_config.is_deterministic
|
|
15
|
-
self.enable_dataloader = common_config.enable_dataloader
|
|
16
|
-
self.scope = task_config.scope if task_config.scope else []
|
|
17
|
-
self.list = task_config.list if task_config.list else []
|
|
18
|
-
self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
|
|
19
|
-
self.backward_input_list = task_config.backward_input if task_config.backward_input else []
|
|
20
|
-
self.backward_input = {}
|
|
21
|
-
self.acl_config = common_config.acl_config if common_config.acl_config else ""
|
|
22
|
-
self.is_forward_acl_dump = True
|
|
23
|
-
self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
|
|
24
|
-
self.
|
|
25
|
-
self.framework = Const.PT_FRAMEWORK
|
|
26
|
-
|
|
27
|
-
if self.task == Const.FREE_BENCHMARK:
|
|
28
|
-
self.fuzz_device = task_config.fuzz_device if task_config.fuzz_device else 'npu'
|
|
29
|
-
self.handler_type = task_config.handler_type if task_config.handler_type else 'check'
|
|
30
|
-
self.pert_mode = task_config.pert_mode if task_config.pert_mode else 'improve_precision'
|
|
31
|
-
self.fuzz_level = task_config.fuzz_level if task_config.fuzz_level else 'L1'
|
|
32
|
-
self.fuzz_stage = task_config.fuzz_stage if task_config.fuzz_stage else 'forward'
|
|
33
|
-
self.preheat_config = {
|
|
34
|
-
"if_preheat": task_config.if_preheat if task_config.if_preheat is not None else True,
|
|
35
|
-
"preheat_step": task_config.preheat_step if task_config.preheat_step else 15,
|
|
36
|
-
"max_sample": task_config.max_sample if task_config.max_sample else 20,
|
|
37
|
-
}
|
|
38
|
-
|
|
39
|
-
self.
|
|
40
|
-
if self.
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
self.
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
1
|
+
from msprobe.pytorch.common import seed_all
|
|
2
|
+
from msprobe.pytorch.common.log import logger
|
|
3
|
+
from msprobe.core.common.const import Const
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DebuggerConfig:
|
|
7
|
+
def __init__(self, common_config, task_config, task, dump_path, level):
|
|
8
|
+
self.dump_path = dump_path if dump_path else common_config.dump_path
|
|
9
|
+
self.task = task or common_config.task or Const.STATISTICS
|
|
10
|
+
self.rank = common_config.rank if common_config.rank else []
|
|
11
|
+
self.step = common_config.step if common_config.step else []
|
|
12
|
+
self.level = level or common_config.level or "L1"
|
|
13
|
+
self.seed = common_config.seed if common_config.seed else 1234
|
|
14
|
+
self.is_deterministic = common_config.is_deterministic
|
|
15
|
+
self.enable_dataloader = common_config.enable_dataloader
|
|
16
|
+
self.scope = task_config.scope if task_config.scope else []
|
|
17
|
+
self.list = task_config.list if task_config.list else []
|
|
18
|
+
self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
|
|
19
|
+
self.backward_input_list = task_config.backward_input if task_config.backward_input else []
|
|
20
|
+
self.backward_input = {}
|
|
21
|
+
self.acl_config = common_config.acl_config if common_config.acl_config else ""
|
|
22
|
+
self.is_forward_acl_dump = True
|
|
23
|
+
self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
|
|
24
|
+
self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
|
|
25
|
+
self.framework = Const.PT_FRAMEWORK
|
|
26
|
+
|
|
27
|
+
if self.task == Const.FREE_BENCHMARK:
|
|
28
|
+
self.fuzz_device = task_config.fuzz_device if task_config.fuzz_device else 'npu'
|
|
29
|
+
self.handler_type = task_config.handler_type if task_config.handler_type else 'check'
|
|
30
|
+
self.pert_mode = task_config.pert_mode if task_config.pert_mode else 'improve_precision'
|
|
31
|
+
self.fuzz_level = task_config.fuzz_level if task_config.fuzz_level else 'L1'
|
|
32
|
+
self.fuzz_stage = task_config.fuzz_stage if task_config.fuzz_stage else 'forward'
|
|
33
|
+
self.preheat_config = {
|
|
34
|
+
"if_preheat": task_config.if_preheat if task_config.if_preheat is not None else True,
|
|
35
|
+
"preheat_step": task_config.preheat_step if task_config.preheat_step else 15,
|
|
36
|
+
"max_sample": task_config.max_sample if task_config.max_sample else 20,
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
self.online_run_ut = False
|
|
40
|
+
if self.task == Const.TENSOR:
|
|
41
|
+
# dump api tensor and collaborate with online run_ut
|
|
42
|
+
self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
|
|
43
|
+
self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
|
|
44
|
+
self.tls_path = task_config.tls_path if task_config.tls_path else ""
|
|
45
|
+
self.host = task_config.host if task_config.host else ""
|
|
46
|
+
self.port = task_config.port if task_config.port else -1
|
|
47
|
+
|
|
48
|
+
self.check()
|
|
49
|
+
if self.step:
|
|
50
|
+
self.step.sort()
|
|
51
|
+
if self.level == "L2":
|
|
52
|
+
if not self.scope or not isinstance(self.scope, list) or len(self.scope) != 1:
|
|
53
|
+
raise ValueError("scope must be configured as a list with one api name")
|
|
54
|
+
if isinstance(self.scope[0], str) and Const.BACKWARD in self.scope[0] and not self.backward_input_list:
|
|
55
|
+
raise ValueError("backward_input must be configured when scope contains 'backward'")
|
|
56
|
+
if Const.BACKWARD in self.scope[0]:
|
|
57
|
+
self.is_forward_acl_dump = False
|
|
58
|
+
for index, scope_spec in enumerate(self.scope):
|
|
59
|
+
self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD)
|
|
60
|
+
self.backward_input[self.scope[index]] = self.backward_input_list[index]
|
|
61
|
+
seed_all(self.seed, self.is_deterministic)
|
|
62
|
+
|
|
63
|
+
def check_kwargs(self):
|
|
64
|
+
if self.task and self.task not in Const.TASK_LIST:
|
|
65
|
+
raise Exception("task is invalid")
|
|
66
|
+
if self.level and self.level not in Const.LEVEL_LIST:
|
|
67
|
+
raise Exception("level is invalid")
|
|
68
|
+
if not self.dump_path:
|
|
69
|
+
raise Exception("Invalid dump path, please check your config")
|
|
70
|
+
|
|
71
|
+
def check(self):
|
|
72
|
+
self.check_kwargs()
|
|
73
|
+
self._check_rank()
|
|
74
|
+
self._check_step()
|
|
75
|
+
return True
|
|
76
|
+
|
|
77
|
+
def check_model(self, model):
    """Require a model argument for the levels that need one.

    Args:
        model: the model handed to PrecisionDebugger; only its truthiness is
            inspected here.

    Raises:
        Exception: if self.level is "L0" or "mix" and model is falsy.
    """
    needs_model = self.level in ("L0", "mix")
    if needs_model and not model:
        raise Exception(
            f"For level {self.level}, PrecisionDebugger must receive a model argument."
        )
|
|
82
|
+
|
|
83
|
+
def _check_rank(self):
    """Validate the configured rank list and announce rank filtering.

    Does nothing when no rank is configured. Otherwise every entry must be a
    non-negative int. Once all entries pass, a rank-0 warning records that only
    the listed ranks' data will be dumped.

    Raises:
        ValueError: naming the first offending entry (not the whole list), in
            the same form as _check_step's message.
    """
    if not self.rank:
        return
    for rank_id in self.rank:
        if not isinstance(rank_id, int) or rank_id < 0:
            raise ValueError(f"rank element {rank_id} must be an integer and greater than or equal to 0.")
    # Reached only after every entry validated; nothing breaks out of the loop,
    # so this is written as plain post-loop code rather than a for/else.
    logger.warning_on_rank_0(f"Rank argument is provided. Only rank {self.rank} data will be dumpped.")
|
|
90
|
+
|
|
91
|
+
def _check_step(self):
    """Ensure every configured step is a non-negative int.

    A no-op when step is unset or empty.

    Raises:
        ValueError: naming the first offending element.
    """
    if not self.step:
        return
    for step_value in self.step:
        is_valid = isinstance(step_value, int) and step_value >= 0
        if not is_valid:
            raise ValueError(f"step element {step_value} must be an integer and greater than or equal to 0.")
|
|
@@ -1,95 +1,125 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from torch.utils.data import dataloader
|
|
3
|
-
from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
|
|
4
|
-
from msprobe.pytorch.service import Service
|
|
5
|
-
from msprobe.pytorch.common.log import logger
|
|
6
|
-
from msprobe.pytorch.pt_config import parse_json_config
|
|
7
|
-
from msprobe.core.common.exceptions import
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
self.
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
self.
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
if
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
@
|
|
58
|
-
def
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
if not instance:
|
|
71
|
-
raise Exception("
|
|
72
|
-
if instance.enable_dataloader:
|
|
73
|
-
logger.warning_on_rank_0("DataLoader is enabled,
|
|
74
|
-
else:
|
|
75
|
-
instance.service.
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
cls._instance
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
if
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
1
|
+
import torch
|
|
2
|
+
from torch.utils.data import dataloader
|
|
3
|
+
from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
|
|
4
|
+
from msprobe.pytorch.service import Service
|
|
5
|
+
from msprobe.pytorch.common.log import logger
|
|
6
|
+
from msprobe.pytorch.pt_config import parse_json_config
|
|
7
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
8
|
+
from msprobe.core.common.const import Const
|
|
9
|
+
from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PrecisionDebugger:
    """Singleton front end for msprobe's PyTorch precision debugging.

    The first construction parses the JSON config and builds one of two setups:
    a GradientMonitor for the grad-probe task, or a DebuggerConfig plus Service
    pair for every other task. The classmethods below then drive that single
    shared instance.
    """

    _instance = None
    # Tasks that do not go through the dump Service. For these tasks, start(),
    # stop(), step() and forward_backward_dump_end() return immediately.
    tasks_not_need_debugger = [Const.GRAD_PROBE]

    def __new__(cls, *args, **kwargs):
        # Always hand back the one shared instance. Seed the attributes that the
        # classmethods read before __init__ has had a chance to run.
        if cls._instance is None:
            cls._instance = super(PrecisionDebugger, cls).__new__(cls)
            cls._instance.config = None
            cls._instance.enable_dataloader = False
        return cls._instance

    def __init__(
        self,
        config_path=None,
        task=None,
        dump_path=None,
        level=None,
        model=None,
        step=None,
    ):
        """Configure the shared instance on its first construction only.

        Args:
            config_path: JSON config path passed to parse_json_config.
            task: optional task override, passed to parse_json_config and
                DebuggerConfig.
            dump_path: optional dump path override for DebuggerConfig.
            level: optional level override for DebuggerConfig.
            model: None or a torch.nn.Module (see check_model_valid).
            step: when truthy, replaces common_config.step.

        Raises:
            MsprobeException: if model is neither None nor a torch.nn.Module.
        """
        if hasattr(self, "initialized"):
            # __init__ runs again on every PrecisionDebugger(...) call; the
            # singleton keeps the configuration from the first call.
            return
        self.api_origin = False
        self.initialized = True
        self.model = self.check_model_valid(model)
        common_config, task_config = parse_json_config(config_path, task)
        self.task = common_config.task
        if self.task == Const.GRAD_PROBE:
            # Gradient probing runs through GradientMonitor alone; no Service is built.
            self.gm = GradientMonitor(common_config, task_config)
            return
        if step:
            common_config.step = step
        self.config = DebuggerConfig(
            common_config, task_config, task, dump_path, level
        )
        self.config.check_model(self.model)
        self.service = Service(self.config)
        self.enable_dataloader = self.config.enable_dataloader
        if self.enable_dataloader:
            logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
            # Wrap DataLoader iteration so each batch fetch runs stop/step/start via iter_tracer.
            dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)

    @property
    def instance(self):
        """Return the shared singleton.

        This is a property, so it only works when read from an instance. Read
        on the class, it yields the property object itself.
        """
        return self._instance

    @staticmethod
    def check_model_valid(model):
        """Return model unchanged if it is None or a torch.nn.Module.

        `is None` is used instead of truthiness so that falsy non-module values
        (e.g. [] or 0) are rejected rather than silently accepted.

        Raises:
            MsprobeException: with INVALID_PARAM_ERROR for any other value.
        """
        if model is None or isinstance(model, torch.nn.Module):
            return model
        raise MsprobeException(
            MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。"
        )

    @classmethod
    def _require_instance(cls, error_message):
        """Return the singleton, raising Exception(error_message) if it was never created."""
        instance = cls._instance
        if instance is None:
            raise Exception(error_message)
        return instance

    @classmethod
    def start(cls):
        """Start data collection on the shared instance.

        Skipped with a warning while enable_dataloader is set; iter_tracer
        clears that flag around its own start/stop calls. api_origin is passed
        to Service.start and then reset to False.

        Raises:
            Exception: if no PrecisionDebugger has been constructed.
        """
        # Check for the instance before touching .task. In the other order, a
        # missing instance surfaces as AttributeError on None.
        instance = cls._require_instance("No instance of PrecisionDebugger found.")
        if instance.task in cls.tasks_not_need_debugger:
            return
        if instance.enable_dataloader:
            logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
            return
        instance.service.start(instance.model, instance.api_origin)
        instance.api_origin = False

    # End marker for the forward/backward pass of the code region being dumped;
    # computation after this point is ignored and cannot be dumped.
    @classmethod
    def forward_backward_dump_end(cls):
        """Tell the Service the dumped region's forward/backward has ended.

        Sets api_origin to True so the next start() passes it to Service.start.
        Tasks without a Service (tasks_not_need_debugger) return immediately.

        Raises:
            Exception: if no PrecisionDebugger has been constructed.
        """
        instance = cls._require_instance("PrecisionDebugger instance is not created.")
        if instance.task in cls.tasks_not_need_debugger:
            return
        instance.service.forward_backward_dump_end()
        instance.api_origin = True

    @classmethod
    def stop(cls):
        """Stop data collection on the shared instance.

        Skipped with a warning while enable_dataloader is set.

        Raises:
            Exception: if no PrecisionDebugger has been constructed.
        """
        instance = cls._require_instance("PrecisionDebugger instance is not created.")
        if instance.task in cls.tasks_not_need_debugger:
            return
        if instance.enable_dataloader:
            logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
            return
        instance.service.stop()

    @classmethod
    def step(cls):
        """Advance the Service to the next step.

        Raises:
            Exception: if no PrecisionDebugger has been constructed.
        """
        instance = cls._require_instance("PrecisionDebugger instance is not created.")
        if instance.task in cls.tasks_not_need_debugger:
            return
        instance.service.step()

    @classmethod
    def monitor(cls, model):
        """Attach the GradientMonitor to model; a no-op unless the task is grad probe.

        Raises:
            Exception: if no PrecisionDebugger has been constructed.
        """
        instance = cls._require_instance("PrecisionDebugger instance is not created.")
        if instance.task != Const.GRAD_PROBE:
            return
        instance.gm.monitor(model)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def iter_tracer(func):
    """Wrap a DataLoader iterator's __next__ so each batch fetch advances a step.

    Around every call to func, the wrapper does the following, in order:
      1. Clears enable_dataloader so the stop()/start() guards let calls through.
      2. Calls stop(), unless service.first_start is truthy.
      3. Calls step().
      4. Fetches the batch with func.
      5. Calls start().
      6. Sets enable_dataloader back to True.

    Args:
        func: the original __next__ being wrapped.

    Returns:
        The wrapping function; it returns whatever func returns.
    """
    import functools

    @functools.wraps(func)
    def func_wrapper(*args, **kwargs):
        # PrecisionDebugger.instance is a property. Read on the class, it yields
        # the property object rather than the singleton, so use _instance.
        debugger_instance = PrecisionDebugger._instance
        debugger_instance.enable_dataloader = False
        if not debugger_instance.service.first_start:
            debugger_instance.stop()
        debugger_instance.step()
        # If func raises (e.g. StopIteration when the iterator is exhausted),
        # start() is not called again and enable_dataloader stays False.
        result = func(*args, **kwargs)
        debugger_instance.start()
        debugger_instance.enable_dataloader = True
        return result

    return func_wrapper
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
from msprobe.
|
|
2
|
-
from msprobe.core.common.exceptions import FreeBenchmarkException
|
|
3
|
-
from msprobe.core.common.const import Const
|
|
4
|
-
|
|
5
|
-
from .main import FreeBenchmarkCheck
|
|
6
|
-
from .common.params import UnequalRow
|
|
7
|
-
|
|
8
|
-
__all__ = [FreeBenchmarkCheck, UnequalRow]
|
|
1
|
+
from msprobe.pytorch.common.log import logger
|
|
2
|
+
from msprobe.core.common.exceptions import FreeBenchmarkException
|
|
3
|
+
from msprobe.core.common.const import Const
|
|
4
|
+
|
|
5
|
+
from .main import FreeBenchmarkCheck
|
|
6
|
+
from .common.params import UnequalRow
|
|
7
|
+
|
|
8
|
+
# Public API of the free_benchmark package. Entries must be name strings:
# `from ... import *` raises TypeError if __all__ holds the objects themselves.
__all__ = ["FreeBenchmarkCheck", "UnequalRow"]
|