mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
- mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +101 -182
- msprobe/__init__.py +1 -0
- msprobe/{config/config.json → config.json} +49 -27
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +124 -124
- msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
- msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +341 -241
- msprobe/core/common/exceptions.py +100 -88
- msprobe/core/common/{file_check.py → file_utils.py} +478 -265
- msprobe/core/common/log.py +76 -55
- msprobe/core/common/utils.py +385 -516
- msprobe/core/common_config.py +85 -58
- msprobe/core/compare/acc_compare.py +300 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +223 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
- msprobe/core/compare/utils.py +430 -0
- msprobe/core/data_dump/data_collector.py +154 -140
- msprobe/core/data_dump/data_processor/base.py +314 -245
- msprobe/core/data_dump/data_processor/factory.py +59 -61
- msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
- msprobe/core/data_dump/json_writer.py +96 -116
- msprobe/core/data_dump/scope.py +178 -178
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +171 -0
- msprobe/core/grad_probe/utils.py +64 -0
- msprobe/docs/01.installation.md +89 -0
- msprobe/docs/02.config_introduction.md +165 -0
- msprobe/docs/03.config_examples.md +247 -0
- msprobe/docs/04.acl_config_examples.md +76 -0
- msprobe/docs/05.data_dump_PyTorch.md +198 -0
- msprobe/docs/06.data_dump_MindSpore.md +243 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
- msprobe/docs/17.grad_probe.md +207 -0
- msprobe/docs/FAQ_PyTorch.md +177 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/grad_probe_image-1.png +0 -0
- msprobe/docs/img/grad_probe_image-2.png +0 -0
- msprobe/docs/img/grad_probe_image-3.png +0 -0
- msprobe/docs/img/grad_probe_image-4.png +0 -0
- msprobe/docs/img/grad_probe_image.png +0 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
- msprobe/mindspore/api_accuracy_checker/main.py +9 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +106 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +81 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +219 -0
- msprobe/mindspore/compare/ms_graph_compare.py +348 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +66 -51
- msprobe/mindspore/debugger/precision_debugger.py +126 -32
- msprobe/mindspore/dump/dump_tool_factory.py +35 -38
- msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
- msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
- msprobe/mindspore/dump/jit_dump.py +72 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +90 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +94 -0
- msprobe/mindspore/grad_probe/utils.py +30 -0
- msprobe/mindspore/ms_config.py +128 -78
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +378 -0
- msprobe/mindspore/task_handler_factory.py +24 -21
- msprobe/msprobe.py +105 -67
- msprobe/pytorch/__init__.py +4 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
- msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/__init__.py +2 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +20 -31
- msprobe/pytorch/common/parse_json.py +39 -37
- msprobe/pytorch/common/utils.py +305 -224
- msprobe/pytorch/compare/distributed_compare.py +66 -111
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +34 -36
- msprobe/pytorch/compare/pt_compare.py +50 -0
- msprobe/pytorch/debugger/debugger_config.py +95 -86
- msprobe/pytorch/debugger/precision_debugger.py +125 -95
- msprobe/pytorch/free_benchmark/__init__.py +8 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -67
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +37 -37
- msprobe/pytorch/free_benchmark/common/params.py +129 -129
- msprobe/pytorch/free_benchmark/common/utils.py +102 -98
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
- msprobe/pytorch/free_benchmark/main.py +105 -102
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
- msprobe/pytorch/function_factory.py +76 -0
- msprobe/pytorch/functional/dump_module.py +39 -39
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -161
- msprobe/pytorch/hook_module/hook_module.py +120 -109
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
- msprobe/pytorch/hook_module/utils.py +30 -29
- msprobe/pytorch/hook_module/wrap_aten.py +110 -100
- msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
- msprobe/pytorch/hook_module/wrap_functional.py +105 -108
- msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
- msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
- msprobe/pytorch/hook_module/wrap_torch.py +86 -88
- msprobe/pytorch/hook_module/wrap_vf.py +62 -64
- msprobe/pytorch/module_processer.py +138 -98
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +236 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -273
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
- msprobe/pytorch/online_dispatch/utils.py +130 -187
- msprobe/pytorch/parse.py +4 -4
- msprobe/pytorch/parse_tool/cli.py +32 -32
- msprobe/pytorch/parse_tool/lib/compare.py +260 -259
- msprobe/pytorch/parse_tool/lib/config.py +52 -51
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
- msprobe/pytorch/parse_tool/lib/utils.py +316 -367
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
- msprobe/pytorch/pt_config.py +188 -93
- msprobe/pytorch/service.py +246 -167
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/config/README.md +0 -397
- msprobe/mindspore/doc/dump.md +0 -65
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
- msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +0 -182
- msprobe/pytorch/doc/dump.md +0 -207
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc/在线精度比对.md +0 -90
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
# 无标杆工具场景验证和性能基线报告
|
|
2
|
+
|
|
3
|
+
## 1 环境信息
|
|
4
|
+
|
|
5
|
+
NPU:Atlas A2 训练系列产品
|
|
6
|
+
|
|
7
|
+
CPU:
|
|
8
|
+
|
|
9
|
+

|
|
10
|
+
|
|
11
|
+
Torch:2.1.0
|
|
12
|
+
|
|
13
|
+
CANN:8.0.T5
|
|
14
|
+
|
|
15
|
+
除上述环境信息影响性能外,API 的数量、种类以及 Shape 都会对性能产生影响,因此本次选取不同场景网络和不同算子进行测试。
|
|
16
|
+
|
|
17
|
+
## 2 模型信息和性能基线
|
|
18
|
+
|
|
19
|
+
大模型在使用 msprobe 采集数据时,建议先简化模型层数,减少采集数据量。
|
|
20
|
+
|
|
21
|
+
以下场景的性能基线测试数据均为多次测试后取平均值,实际运行时性能数据可能会根据环境状态稍有浮动。
|
|
22
|
+
|
|
23
|
+
### [2.1 ModelLink 模型](https://gitee.com/ascend/ModelLink)
|
|
24
|
+
|
|
25
|
+
NUM_LAYER:1
|
|
26
|
+
|
|
27
|
+
NPU 卡数:1
|
|
28
|
+
|
|
29
|
+
主要数据类型:FLOAT16
|
|
30
|
+
|
|
31
|
+
#### 2.1.1 LLaMA2-7B
|
|
32
|
+
|
|
33
|
+
softmax 算子为 FLOAT32,输入输出大小均为 2GB,是该模型中显存开销最大的 API。在该模型下,对无标杆工具的处理模式、插装范围、扰动方式组合下的性能和显存基线进行验证。
|
|
34
|
+
|
|
35
|
+
| 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
|
|
36
|
+
|-------|--------|------|-----|-------|---------|--------|-------|--------|
|
|
37
|
+
| / | / | / | / | 0.24 | 13.69 | 1 | 1 | 混合精度模式基线 |
|
|
38
|
+
| check | 前 | ["softmax"] | improve_precision | 0.26 | 13.69 | 1.08 | 1 | softmax 本身为高精度,跳过 |
|
|
39
|
+
| check | 前 | ["softmax"] | add_noise | 0.54 | 19.17 | 2.25 | 1.40 | |
|
|
40
|
+
| check | 前 | ["softmax"] | bit_noise | 0.56 | 19.17 | 2.33 | 1.40 | |
|
|
41
|
+
| check | 前 | ["softmax"] | change_value | 0.48 | 14.9 | 2 | 1.09 | |
|
|
42
|
+
| check | 前 | ["softmax"] | no_change | 0.47 | 14.9 | 1.96 | 1.09 | |
|
|
43
|
+
| check | 前 | ["softmax"] | to_cpu | 26.45 | 22.67 | 110.21 | 1.66 | 不建议整网 |
|
|
44
|
+
| check | 前 | ["matmul"] | improve_precision | 0.57 | 13.69 | 2.38 | 1 | |
|
|
45
|
+
| check | 前 | ["matmul"] | change_value | 0.48 | 13.69 | 2 | 1 | |
|
|
46
|
+
| check | 前 | ["matmul"] | to_cpu | 78.43 | 19.20 | 326.79 | 1.40 | 不建议整网 |
|
|
47
|
+
| check | 前 | [] | improve_precision | 3.45 | 18.79 | 14.37 | 1.37 | |
|
|
48
|
+
| check | 前 | [] | add_noise | 4.67 | 19.17 | 19.46 | 1.40 | |
|
|
49
|
+
| check | 前 | [] | bit_noise | 16.99 | 19.17 | 70.79 | 1.40 | |
|
|
50
|
+
| check | 前 | [] | no_change | 3.22 | 14.90 | 13.42 | 1.09 | |
|
|
51
|
+
| check | 反 | ["softmax"] | improve_precision | 6.23 | 25.69 | 25.96 | 1.88 | 不建议整网 |
|
|
52
|
+
| check | 反 | ["softmax"] | change_value | 22.76 | 25.69 | 94.83 | 1.88 | 不建议整网 |
|
|
53
|
+
| check | 反 | ["softmax"] | to_cpu | 141.71 | 26.19 | 590.46 | 1.91 | 不建议整网 |
|
|
54
|
+
| fix | 前 | ["softmax"] | to_cpu | 9.70 | 16.67 | 40.42 | 1.22 | 不支持整网、不支持反向 |
|
|
55
|
+
| fix | 前 | ["softmax"] | improve_precision | 0.26 | 14.67 | 1.08 | 1.07 | 不支持整网、不支持反向 |
|
|
56
|
+
| 预热 | 前 | [] | improve_precision | 155.07 | 24.79 | 646.13 | 1.81 | 低精度模型基线、只测预热的迭代 |
|
|
57
|
+
| 预热 | 反 | [] | improve_precision | 72.29 | 22.01 | 301.21 | 1.61 | 低精度模型基线、只测预热的迭代,grad_output 为高精度的算子跳过 |
|
|
58
|
+
|
|
59
|
+
#### 2.1.2 Aquila2-7B
|
|
60
|
+
|
|
61
|
+
| 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
|
|
62
|
+
|----------|------|-----|---|----|-----|-------|------|-------------|
|
|
63
|
+
| / | / | / | / | 0.17 | 13.66 | 1 | 1 | 混合精度模式基线 |
|
|
64
|
+
| check | 前 | [] | improve_precision | 1.57 | 14.24 | 9.24 | 1.04 | |
|
|
65
|
+
| check | 反 | [] | add_noise | 21.05 | 14.19 | 123.82 | 1.04 | |
|
|
66
|
+
| fix | 前 | [] | improve_precision | 0.95 | 15.55 | 5.59 | 1.14 | |
|
|
67
|
+
|
|
68
|
+
#### 2.1.3 Baichuan2-7B
|
|
69
|
+
|
|
70
|
+
| 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s)| 显存峰值(GB)| 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
|
|
71
|
+
|----|-----|---|--|----|----|------|-------|---------|
|
|
72
|
+
| / | / | / | / | 0.26 | 12.12 | 1 | 1 | 混合精度模式基线 |
|
|
73
|
+
| check | 前 | [] | improve_precision | 1.02 | 12.27 | 3.92 | 1.01 | |
|
|
74
|
+
| check | 反 | [] | add_noise | 11.15 | 12.67 | 42.88 | 1.05 | |
|
|
75
|
+
| fix | 前 | [] | improve_precision | 0.95 | 12.82 | 3.65 | 1.06 | |
|
|
76
|
+
|
|
77
|
+
#### 2.1.4 Bloom-7B
|
|
78
|
+
|
|
79
|
+
| 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s)| 显存峰值(GB)| 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
|
|
80
|
+
|-----|------|------|------|----|-----|-----|-------|----|
|
|
81
|
+
| / | / | / | / | 0.14 | 9.51 | 1 | 1 | 混合精度模式基线 |
|
|
82
|
+
| check | 前 | [] | improve_precision | 1.64 | 11.58 | 11.71 | 1.22 | |
|
|
83
|
+
| check | 反 | [] | add_noise | 17.15 | 9.51 | 122.5 | 1 | |
|
|
84
|
+
| fix | 前 | [] | improve_precision | 0.87 | 10.62 | 6.21 | 1.12 | |
|
|
85
|
+
|
|
86
|
+
#### 2.1.5 InternLM-7B
|
|
87
|
+
|
|
88
|
+
| 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
|
|
89
|
+
|-------------|--------|-------|----|------|-----|------|-------|----|
|
|
90
|
+
| / | / | / | / | 0.13 | 10.76 | 1 | 1 | 混合精度模式基线 |
|
|
91
|
+
| check | 前 | [] | improve_precision | 1.19 | 11.68 | 9.15 | 1.09 | |
|
|
92
|
+
| check | 反 | [] | add_noise | 11.69 | 10.89 | 89.92 | 1.01 | |
|
|
93
|
+
| fix | 前 | [] | improve_precision | 0.75 | 11.68 | 5.77 | 1.09 | |
|
|
94
|
+
|
|
95
|
+
#### 2.1.6 Qwen-7B
|
|
96
|
+
|
|
97
|
+
| 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
|
|
98
|
+
|--------|-------|-----|-----|----|------|-----|------|------|
|
|
99
|
+
| / | / | / | / | 0.28 | 18.41 | 1 | 1 | 混合精度模式基线 |
|
|
100
|
+
| check | 前 | [] | improve_precision | 2.34 | 23.18 | 8.36 | 1.26 | |
|
|
101
|
+
| check | 反 | [] | add_noise | 22.07 | 19.47 | 78.82 | 1.06 | |
|
|
102
|
+
| fix | 前 | [] | improve_precision | 1.31 | 21.11 | 4.68 | 1.15 | |
|
|
103
|
+
|
|
104
|
+
#### 2.1.7 Gemma-7B
|
|
105
|
+
|
|
106
|
+
| 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
|
|
107
|
+
|--------|-------|------|---|----|-----|-----|-----|---------|
|
|
108
|
+
| / | / | / | / | 0.15 | 11.06 | 1 | 1 | 混合精度模式基线 |
|
|
109
|
+
| check | 前 | [] | improve_precision | 1.49 | 13.17 | 9.93 | 1.19 | |
|
|
110
|
+
| check | 反 | [] | add_noise | 16.69 | 11.06 | 111.27 | 1 | |
|
|
111
|
+
| fix | 前 | [] | improve_precision | 0.87 | 12.25 | 5.8 | 1.11 | |
|
|
112
|
+
|
|
113
|
+
### [2.2 ModelZoo-PyTorch 模型](https://gitee.com/ascend/ModelZoo-PyTorch)
|
|
114
|
+
|
|
115
|
+
#### 2.2.1 ResNet50-Cifar
|
|
116
|
+
|
|
117
|
+
NPU 卡数:1
|
|
118
|
+
|
|
119
|
+
主要数据类型:FLOAT16
|
|
120
|
+
|
|
121
|
+
主要算子为 conv2d,每个 step 有 51 个,因此对 conv2d 进行检测。该模型为 CV 模型,依赖 mmcv 实现(如果不修改 mmcv 代码,工具无法获取 step 信息和反向信息)。
|
|
122
|
+
|
|
123
|
+
| 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
|
|
124
|
+
|------------|---------|--------|-----|------|---|--------|-------|----|
|
|
125
|
+
| / | / | / | / | 0.09 | 7.63 | 1 | 1 | 基线 |
|
|
126
|
+
| check | 前 | ["conv2d"] | improve_precision | 0.889 | 7.94 | 9.81 | 1.04 | |
|
|
127
|
+
| fix | 前 | ["conv2d"] | improve_precision | 0.328 | 7.47 | 3.64 | 0.91 | |
|
|
128
|
+
| fix | 前 | ["conv2d"] | to_cpu | 12.23 | 7.47 | 135.88 | 0.91 | |
|
|
129
|
+
|
|
130
|
+
#### 2.2.2 OpenSora1.0
|
|
131
|
+
|
|
132
|
+
NPU 卡数:4
|
|
133
|
+
|
|
134
|
+
主要数据类型:FLOAT16
|
|
135
|
+
|
|
136
|
+
每张卡每个 step 中 linear 算子个数为 257 个,FA 算子个数为 83 个(FA 算子反向无效)。
|
|
137
|
+
|
|
138
|
+
| 处理模式 | 前/反向 | 算子范围 | 扰动方式 | 耗时(s) | 显存峰值(GB) | 耗时膨胀倍数 | 显存峰值膨胀倍数 | 备注 |
|
|
139
|
+
|------------|------|-------|----|----|-----|-----|------|-----|
|
|
140
|
+
| / | / | / | / | 0.99 | 17.61 | 1 | 1 | 混合精度模式基线 |
|
|
141
|
+
| check | 前 | ["linear","npu_fusion_attention"] | improve_precision | 3.88 | 17.61 | 3.92 | 1 | |
|
|
142
|
+
| check | 前 | ["linear","npu_fusion_attention"] | add_noise | 3.46 | 17.61 | 3.49 | 1 | |
|
|
143
|
+
| check | 反 | ["linear"] | improve_precision | 12.61 | 17.61 | 12.74 | 1 | |
|
|
144
|
+
| check | 反 | ["linear"] | add_noise | 9.8 | 17.61 | 9.90 | 1 | |
|
|
145
|
+
| fix | 前 | ["linear"] | to_cpu | 18.83 | 17.61 | 19.02 | 1 | |
|
|
146
|
+
| fix | 前 | ["linear"] | improve_precision | 2.83 | 17.61 | 2.86 | 1 | |
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
msprobe/mindspore/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
|
|
1
|
+
from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
|
|
File without changes
|
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv
|
|
5
|
+
from msprobe.core.common.utils import add_time_as_suffix
|
|
6
|
+
from msprobe.core.common.const import Const, CompareConst, MsCompareConst
|
|
7
|
+
from msprobe.mindspore.common.log import logger
|
|
8
|
+
from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
|
|
9
|
+
from msprobe.mindspore.api_accuracy_checker.api_runner import api_runner, ApiInputAggregation
|
|
10
|
+
from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
|
|
11
|
+
from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context,
|
|
12
|
+
trim_output_compute_element_list)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class BasicInfoAndStatus:
    """Identity, dtypes, shape and verdict of one compared output.

    The constructor stores every argument verbatim as an attribute of the
    same name; no validation or conversion is performed.
    """

    def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
        # Tuple-target assignment binds left to right, so attributes are
        # created in the same order as the parameters.
        (self.api_name, self.bench_dtype, self.tested_dtype,
         self.shape, self.status, self.err_msg) = (
            api_name, bench_dtype, tested_dtype, shape, status, err_msg)
|
|
23
|
+
|
|
24
|
+
class ResultCsvEntry:
    """Forward/backward pass statuses and error messages for one API.

    Statuses and the overall message start as None; the per-direction error
    messages start as empty strings.
    """

    def __init__(self) -> None:
        # Chained assignment binds targets left to right, preserving the
        # original attribute creation order.
        self.forward_pass_status = self.backward_pass_status = None
        self.forward_err_msg = self.backward_err_msg = ""
        self.overall_err_msg = None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ApiAccuracyChecker:
|
|
34
|
+
def __init__(self):
    """Create a checker with no parsed APIs and no results.

    ``api_infos`` maps an API name (with its forward/backward suffix removed)
    to its ApiInfo and is filled by ``parse``; ``results`` starts empty.
    """
    self.api_infos = {}
    self.results = {}
|
|
37
|
+
|
|
38
|
+
@staticmethod
def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
    '''
    Args:
        api_info: ApiInfo
        api_name_str: str
        api_input_aggregation: ApiInputAggregation
        forward_or_backward: str, one of "forward" / "backward"

    Return:
        output_list: List[Tuple[str, str, BasicInfoAndStatus, Dict[str, CompareResult]]]

    Description:
        Get the tested (mindspore) outputs: in constructed mode the mindspore api is re-run on the
        constructed inputs, otherwise the outputs recorded in api_info are used.
        Run the bench (torch) api to get the bench outputs.
        Trim both output lists, then compare every (bench, tested) pair with each algorithm in
        compare_algorithms. A pair passes only if both the cosine and the max-abs-err results pass;
        otherwise its status is ERROR and err_msg is the concatenation of those two err_msgs.
    '''
    # get output
    if global_context.get_is_constructed():
        # constructed situation, need use constructed input to run mindspore api getting tested_output
        tested_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.MS_FRAMEWORK)
    else:
        tested_outputs = api_info.get_compute_element_list(forward_or_backward, Const.OUTPUT)
    bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)
    tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward)
    bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward)
    if len(tested_outputs) != len(bench_outputs):
        # zip() below stops at the shorter list, so surplus outputs are not compared.
        logger.warning(f"ApiAccuracyChecker.run_and_compare_helper: api: {api_name_str}.{forward_or_backward}, "
                       "number of bench outputs and tested outputs is different, comparing result can be wrong. "
                       f"tested outputs: {len(tested_outputs)}, bench outputs: {len(bench_outputs)}")

    # compare output
    output_list = []
    for i, (bench_out, tested_out) in enumerate(zip(bench_outputs, tested_outputs)):
        api_name_with_slot = Const.SEP.join([api_name_str, forward_or_backward, Const.OUTPUT, str(i)])
        bench_dtype = bench_out.get_dtype()
        tested_dtype = tested_out.get_dtype()
        shape = bench_out.get_shape()

        compare_result_dict = {
            compare_algorithm_name: compare_algorithm(bench_out, tested_out)
            for compare_algorithm_name, compare_algorithm in compare_algorithms.items()
        }

        # Look each decisive result up once instead of repeating dict.get().
        cosine_result = compare_result_dict.get(CompareConst.COSINE)
        max_abs_err_result = compare_result_dict.get(CompareConst.MAX_ABS_ERR)
        if cosine_result.pass_status == CompareConst.PASS and \
                max_abs_err_result.pass_status == CompareConst.PASS:
            status = CompareConst.PASS
            err_msg = ""
        else:
            status = CompareConst.ERROR
            err_msg = cosine_result.err_msg + max_abs_err_result.err_msg
        basic_info_status = \
            BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
        # Tuple literal instead of tuple([...]): no throwaway list.
        output_list.append((api_name_str, forward_or_backward, basic_info_status, compare_result_dict))
    return output_list
|
|
94
|
+
|
|
95
|
+
def parse(self, api_info_path):
    """Load api_info.json and register every mint / mint.nn.functional api it describes.

    Initialises ``global_context`` from the ``task`` field (and, unless the task is the
    statistics task, from ``dump_data_dir``). It then walks the ``data`` field and stores
    each entry's info in ``self.api_infos``. The key is the api name with its trailing
    ``forward``/``backward`` segment removed. Entries whose last segment is neither
    forward nor backward are skipped with a warning.

    Args:
        api_info_path: str, path of the api_info.json file.
    """
    with FileOpen(api_info_path, "r") as f:
        api_info_dict = json.load(f)

    # init global context: is_constructed is True only for the statistics task;
    # every other task must provide dump_data_dir.
    task = check_and_get_from_json_dict(api_info_dict, MsCompareConst.TASK_FIELD,
                                        "task field in api_info.json", accepted_type=str,
                                        accepted_value=(MsCompareConst.STATISTICS_TASK,
                                                        MsCompareConst.TENSOR_TASK))
    is_constructed = task == MsCompareConst.STATISTICS_TASK
    if not is_constructed:
        dump_data_dir = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DUMP_DATA_DIR_FIELD,
                                                     "dump_data_dir field in api_info.json", accepted_type=str)
    else:
        dump_data_dir = ""
    global_context.init(is_constructed, dump_data_dir)

    api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
                                                 "data field in api_info.json", accepted_type=dict)
    for api_name, api_info in api_info_data.items():
        name_parts = api_name.split(Const.SEP)
        # only mint and mint.nn.functional apis are checked
        if name_parts[0] not in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL):
            continue
        forbackward_str = name_parts[-1]
        if forbackward_str not in (Const.FORWARD, Const.BACKWARD):
            logger.warning(f"api: {api_name} is not recognized as forward api or backward api, skip this.")
            # actually skip: without this the entry fell through and was loaded as backward info
            continue
        api_name = Const.SEP.join(name_parts[:-1])  # www.xxx.yyy.zzz --> www.xxx.yyy
        if api_name not in self.api_infos:
            self.api_infos[api_name] = ApiInfo(api_name)

        if forbackward_str == Const.FORWARD:
            self.api_infos[api_name].load_forward_info(api_info)
        else:
            self.api_infos[api_name].load_backward_info(api_info)
|
|
130
|
+
|
|
131
|
+
def run_and_compare(self):
    """Run and compare every registered api, forward first and then backward.

    - An api without forward info is skipped entirely.
    - An api without backward info only gets its forward check.
    - An exception while running/comparing one direction is logged as a warning and
      does not stop the remaining checks.
    - Outputs are accumulated through ``self.record``.
    """
    for api_name_str, api_info in self.api_infos.items():
        if not api_info.check_forward_info():
            logger.warning(f"api: {api_name_str} is lack of forward infomation, skip forward and backward check")
            continue
        forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
        kwargs = api_info.get_kwargs()
        self._run_compare_and_record(api_info, api_name_str,
                                     ApiInputAggregation(forward_inputs, kwargs, None), Const.FORWARD)

        if not api_info.check_backward_info():
            logger.warning(f"api: {api_name_str} is lack of backward infomation, skip backward check")
            continue
        # the backward run re-executes the forward inputs and feeds the recorded gradients
        gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
        self._run_compare_and_record(api_info, api_name_str,
                                     ApiInputAggregation(forward_inputs, kwargs, gradient_inputs),
                                     Const.BACKWARD)

def _run_compare_and_record(self, api_info, api_name_str, api_input_aggregation, forward_or_backward):
    """Run one direction of one api via run_and_compare_helper and record its outputs; failures are only logged."""
    output_list = None
    try:
        output_list = self.run_and_compare_helper(api_info, api_name_str, api_input_aggregation,
                                                  forward_or_backward)
    except Exception as e:
        # best effort: one failing api must not abort the whole check
        logger.warning(f"exception occurs when running and comparing {api_name_str} {forward_or_backward} api, "
                       f"detailed exception information: {e}")
    # record() ignores None, i.e. a failed run contributes no rows
    self.record(output_list)
|
|
161
|
+
|
|
162
|
+
def record(self, output_list):
    """Group compare outputs into ``self.results`` keyed by (api name, forward_or_backward).

    Args:
        output_list: None, or an iterable of 4-tuples
            (api_real_name, forward_or_backward, basic_info, compare_result_dict).
            None means nothing was produced and is ignored.
    """
    if output_list is not None:
        for api_real_name, forward_or_backward, basic_info, compare_result_dict in output_list:
            bucket = self.results.setdefault((api_real_name, forward_or_backward), [])
            bucket.append((basic_info, compare_result_dict))
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def to_detail_csv(self, csv_dir):
    """Write one row per compared output slot into the detail csv under ``csv_dir``.

    Columns:
    - api name, bench dtype, tested dtype, shape;
    - the ``compare_value`` of every algorithm in ``compare_algorithms``, in its key order;
    - pass status and message.

    The file name is MsCompareConst.DETAIL_CSV_FILE_NAME passed through add_time_as_suffix.
    """
    algorithm_names = list(compare_algorithms.keys())
    rows = [[
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
        MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
        MsCompareConst.DETAIL_CSV_SHAPE,
        *algorithm_names,
        MsCompareConst.DETAIL_CSV_PASS_STATUS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    ]]

    for results in self.results.values():
        for basic_info, compare_result_dict in results:
            basic_columns = [basic_info.api_name, basic_info.bench_dtype,
                             basic_info.tested_dtype, basic_info.shape]
            metric_columns = [compare_result_dict.get(name).compare_value for name in algorithm_names]
            status_columns = [basic_info.status, basic_info.err_msg]
            rows.append(basic_columns + metric_columns + status_columns)

    file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.DETAIL_CSV_FILE_NAME))
    create_directory(csv_dir)
    write_csv(rows, file_name, mode="w")
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def to_result_csv(self, csv_dir):
    """Summarise ``self.results`` per api into the result csv under ``csv_dir``.

    - A direction (forward/backward) passes only if every one of its output slots passed.
    - Its message is the concatenation of the failing slots' messages.
    - An api's overall message is empty when both directions pass; otherwise it is the
      forward message followed by the backward message.
    """
    entries = dict()
    for (api_real_name, forward_or_backward), results in self.results.items():
        failed_msgs = [basic_info.err_msg for basic_info, _ in results
                       if basic_info.status != CompareConst.PASS]
        pass_status = CompareConst.ERROR if failed_msgs else CompareConst.PASS
        overall_err_msg = "".join(failed_msgs)

        entry = entries.get(api_real_name)
        if entry is None:
            entry = ResultCsvEntry()
            entries[api_real_name] = entry
        if forward_or_backward == Const.FORWARD:
            entry.forward_pass_status = pass_status
            entry.forward_err_msg = overall_err_msg
        else:
            entry.backward_pass_status = pass_status
            entry.backward_err_msg = overall_err_msg

    rows = [[
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
        MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    ]]
    for api_name, entry in entries.items():
        both_passed = entry.forward_pass_status == CompareConst.PASS and \
            entry.backward_pass_status == CompareConst.PASS
        message = "" if both_passed else entry.forward_err_msg + entry.backward_err_msg
        rows.append([api_name, entry.forward_pass_status, entry.backward_pass_status, message])

    file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
    create_directory(csv_dir)
    write_csv(rows, file_name, mode="w")
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
2
|
+
from msprobe.core.common.const import Const
|
|
3
|
+
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
|
|
4
|
+
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
5
|
+
from msprobe.mindspore.common.log import logger
|
|
6
|
+
|
|
7
|
+
class ApiInfo:
    """Forward/backward records of one api as read from api_info.json.

    ``forward_info`` and ``backward_info`` hold the raw dicts handed to the ``load_*``
    methods and stay ``None`` until loaded. The ``get_*`` helpers turn their fields into
    ``ComputeElement`` objects.
    """

    def __init__(self, api_name):
        self.api_name = api_name
        self.forward_info = None  # raw forward dict from api_info.json, None until loaded
        self.backward_info = None  # raw backward dict from api_info.json, None until loaded

    def load_forward_info(self, forward_info_dict):
        self.forward_info = forward_info_dict

    def load_backward_info(self, backward_info_dict):
        self.backward_info = backward_info_dict

    def check_forward_info(self):
        """Return True once forward info has been loaded."""
        return self.forward_info is not None

    def check_backward_info(self):
        """Return True once backward info has been loaded."""
        return self.backward_info is not None

    def get_compute_element_list(self, forward_or_backward, input_or_output):
        '''
        Args:
            forward_or_backward: str, Union["forward", "backward"]
            input_or_output: str, Union["input", "output"]

        Return:
            compute_element_list: List[ComputeElement]

        Raises:
            ApiAccuracyCheckerException (via logger.error_log_with_exp) for an unsupported
            (forward_or_backward, input_or_output) pair, or when the field is missing / not a list.
        '''
        # (direction, io) -> (source dict, json key, description used in error messages)
        mapping = {
            (Const.FORWARD, Const.INPUT): (self.forward_info, Const.INPUT_ARGS,
                                          f"input_args field of {self.api_name} forward api in api_info.json"),
            (Const.FORWARD, Const.OUTPUT): (self.forward_info, Const.OUTPUT,
                                           f"output field of {self.api_name} forward api in api_info.json"),
            (Const.BACKWARD, Const.INPUT): (self.backward_info, Const.INPUT,
                                           f"input field of {self.api_name} backward api in api_info.json"),
            (Const.BACKWARD, Const.OUTPUT): (self.backward_info, Const.OUTPUT,
                                            f"output field of {self.api_name} backward api in api_info.json"),
        }
        lookup_key = (forward_or_backward, input_or_output)
        if lookup_key not in mapping:
            # previously mapping.get() returned None and the unpacking below failed with an opaque TypeError
            err_msg = f"ApiInfo.get_compute_element_list failed: unsupported combination " \
                      f"({forward_or_backward}, {input_or_output})"
            logger.error_log_with_exp(err_msg,
                                      ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        dict_instance, key, key_desc = mapping[lookup_key]
        compute_element_info_list = check_and_get_from_json_dict(dict_instance, key, key_desc, accepted_type=list)
        compute_element_list = [ComputeElement(compute_element_info=compute_element_info)
                                for compute_element_info in compute_element_info_list]
        return compute_element_list

    def get_kwargs(self):
        '''
        Return:
            kwargs_compute_element_dict: dict{str: ComputeElement}, built from the forward
            info's input_kwargs field

        Raises:
            ApiAccuracyCheckerException (via logger.error_log_with_exp) when a key is not a
            string or a value is neither a list nor a dict.
        '''
        kwargs_dict = check_and_get_from_json_dict(self.forward_info, Const.INPUT_KWARGS,
                                                   "input_kwargs in api_info.json", accepted_type=dict)
        # validate every entry before building any ComputeElement
        for key_str, compute_element_info in kwargs_dict.items():
            if not isinstance(key_str, str):
                err_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string"
                logger.error_log_with_exp(err_msg,
                                          ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
            if not isinstance(compute_element_info, (list, dict)):
                err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or dict"
                logger.error_log_with_exp(err_msg,
                                          ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
        kwargs_compute_element_dict = {key_str: ComputeElement(compute_element_info=compute_element_info)
                                       for key_str, compute_element_info in kwargs_dict.items()}
        return kwargs_compute_element_dict
|
|
69
|
+
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
import mindspore
|
|
4
|
+
import torch
|
|
5
|
+
from mindspore import ops
|
|
6
|
+
|
|
7
|
+
from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
|
|
8
|
+
from msprobe.core.common.const import Const, MsCompareConst
|
|
9
|
+
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
10
|
+
from msprobe.mindspore.common.log import logger
|
|
11
|
+
from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
|
|
12
|
+
from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ApiInputAggregation:
    """Everything needed to invoke one api: positional inputs, keyword inputs and,
    for a backward run, the incoming gradients."""

    def __init__(self, inputs, kwargs, gradient_inputs) -> None:
        '''
        Args:
            inputs: List[ComputeElement], positional inputs of the forward call
            kwargs: dict{str: ComputeElement}, keyword inputs of the forward call
            gradient_inputs: Union[List[ComputeElement], None], gradients for a backward run,
                None when only the forward api is run
        '''
        self.inputs, self.kwargs, self.gradient_inputs = inputs, kwargs, gradient_inputs
|
|
26
|
+
|
|
27
|
+
# (api type, framework) -> parent module in which the api function is looked up by name.
# mint apis pair mindspore.mint with torch; mint functional apis pair
# mindspore.mint.nn.functional with torch.nn.functional.
api_parent_module_mapping = {
    # mint.{name} <--> torch.{name}
    (MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint,
    (MsCompareConst.MINT, Const.PT_FRAMEWORK): torch,
    # mint.nn.functional.{name} <--> torch.nn.functional.{name}
    (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
    (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional,
}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ApiRunner:
    """Callable that runs a mindspore.mint / torch api, forward or backward, on ComputeElement inputs."""

    def __call__(self, api_input_aggregation, api_name_str, forward_or_backward=Const.FORWARD,
                 api_platform=Const.MS_FRAMEWORK):
        '''
        Args:
            api_input_aggregation: ApiInputAggregation
            api_name_str: str, e.g. "MintFunctional.relu.0"
            forward_or_backward: str, Union["forward", "backward"]
            api_platform: str, Union["mindspore", "torch"]

        Return:
            outputs: list[ComputeElement]

        Description:
            run mindspore.mint/torch api
        '''
        api_type_str, api_sub_name = self.get_info_from_name(api_name_str)
        api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)

        return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)

    @staticmethod
    def get_info_from_name(api_name_str):
        '''
        Args:
            api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"

        Return:
            api_type_str: str, Union["MintFunctional", "Mint"]
            api_sub_name: str, e.g. "relu"
        '''
        api_name_list = api_name_str.split(Const.SEP)
        # expected format: <api type>.<api name>.<call index>
        if len(api_name_list) != 3:
            err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
        if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL]:
            err_msg = "ApiRunner.get_info_from_name failed: not mint or mint.nn.functional api"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))

        return api_type_str, api_sub_name

    @staticmethod
    def get_api_instance(api_type_str, api_sub_name, api_platform):
        '''
        Args:
            api_type_str: str, Union["MintFunctional", "Mint"]
            api_sub_name: str, e.g. "relu"
            api_platform: str: Union["mindspore", "torch"]

        Return:
            api_instance: function object

        Description:
            get mindspore.mint/torch api function
            mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
            mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
        '''
        api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
        # the full dotted name is only used for error messages
        module_str = "mindspore.mint." if api_platform == Const.MS_FRAMEWORK else "torch."
        submodule_str = "nn.functional." if api_type_str == MsCompareConst.MINT_FUNCTIONAL else ""
        full_api_name = module_str + submodule_str + api_sub_name
        if not hasattr(api_parent_module, api_sub_name):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        api_instance = getattr(api_parent_module, api_sub_name)
        if not callable(api_instance):
            err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not callable"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))

        return api_instance

    @staticmethod
    def run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform):
        '''
        Args:
            api_instance: function object, as returned by get_api_instance
            api_input_aggregation: ApiInputAggregation
            forward_or_backward: str, Union["forward", "backward"]
            api_platform: str, Union["mindspore", "torch"]

        Return:
            res_compute_element_list: list[ComputeElement]. These are the forward outputs, or
            for a backward run the gradients w.r.t. the forward inputs (all positional inputs
            for mindspore, only floating-point tensor inputs for torch).
        '''
        inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                       for compute_element in api_input_aggregation.inputs)
        kwargs = {key: value.get_parameter(get_origin=False, tensor_platform=api_platform)
                  for key, value in api_input_aggregation.kwargs.items()}

        if forward_or_backward == Const.FORWARD:
            forward_result = api_instance(*inputs, **kwargs)  # can be single tensor or tuple
            results = convert_to_tuple(forward_result)
        else:
            gradient_inputs = api_input_aggregation.gradient_inputs
            if gradient_inputs is None:
                err_msg = "ApiRunner.run_api failed: run backward api but gradient_inputs is missing"
                logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
            gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
                                    for compute_element in gradient_inputs)
            if api_platform == Const.MS_FRAMEWORK:
                results = ApiRunner._run_mindspore_backward(api_instance, inputs, kwargs, gradient_inputs)
            else:
                results = ApiRunner._run_torch_backward(api_instance, inputs, kwargs, gradient_inputs)

        return [ComputeElement(parameter=api_res) for api_res in results]

    @staticmethod
    def _run_mindspore_backward(api_instance, inputs, kwargs, gradient_inputs):
        """Return, as a tuple, the gradients of a mindspore api w.r.t. all its positional inputs."""
        # a single gradient is passed as the tensor itself rather than as a 1-tuple
        # (presumably to match a single-tensor forward output)
        sens = gradient_inputs[0] if len(gradient_inputs) == 1 else gradient_inputs

        def api_with_kwargs(*forward_inputs):
            return api_instance(*forward_inputs, **kwargs)

        grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
        backward_result = grad_func(*inputs, sens)  # can be single tensor or tuple
        return convert_to_tuple(backward_result)

    @staticmethod
    def _run_torch_backward(api_instance, inputs, kwargs, gradient_inputs):
        """Return the ``.grad`` of every floating-point tensor input of a torch api, in input order."""
        requires_grad_index = []
        for index, tensor in enumerate(inputs):
            if isinstance(tensor, torch.Tensor) and \
                    torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list:
                tensor.requires_grad_(True)
                requires_grad_index.append(index)

        forward_results = convert_to_tuple(api_instance(*inputs, **kwargs))
        # Back-propagate all differentiable outputs in ONE autograd call. Calling .backward()
        # once per output can free the shared graph after the first call, and outputs that do
        # not require grad (e.g. integer indices) cannot be back-propagated at all.
        differentiable_pairs = [(forward_res, gradient_in)
                                for forward_res, gradient_in in zip(forward_results, gradient_inputs)
                                if isinstance(forward_res, torch.Tensor) and forward_res.requires_grad]
        if not differentiable_pairs:
            err_msg = "ApiRunner.run_api failed: no forward output of the torch api requires grad"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
        output_tensors, grad_tensors = zip(*differentiable_pairs)
        torch.autograd.backward(list(output_tensors), grad_tensors=list(grad_tensors))

        return [inputs[index].grad for index in requires_grad_index]
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
# Module-level instance shared by callers; ApiRunner defines no instance attributes, so one instance serves every call.
api_runner = ApiRunner()
|