mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
- mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +101 -182
- msprobe/__init__.py +1 -0
- msprobe/{config/config.json → config.json} +49 -27
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +124 -124
- msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
- msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +341 -241
- msprobe/core/common/exceptions.py +100 -88
- msprobe/core/common/{file_check.py → file_utils.py} +478 -265
- msprobe/core/common/log.py +76 -55
- msprobe/core/common/utils.py +385 -516
- msprobe/core/common_config.py +85 -58
- msprobe/core/compare/acc_compare.py +300 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +223 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
- msprobe/core/compare/utils.py +430 -0
- msprobe/core/data_dump/data_collector.py +154 -140
- msprobe/core/data_dump/data_processor/base.py +314 -245
- msprobe/core/data_dump/data_processor/factory.py +59 -61
- msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
- msprobe/core/data_dump/json_writer.py +96 -116
- msprobe/core/data_dump/scope.py +178 -178
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +171 -0
- msprobe/core/grad_probe/utils.py +64 -0
- msprobe/docs/01.installation.md +89 -0
- msprobe/docs/02.config_introduction.md +165 -0
- msprobe/docs/03.config_examples.md +247 -0
- msprobe/docs/04.acl_config_examples.md +76 -0
- msprobe/docs/05.data_dump_PyTorch.md +198 -0
- msprobe/docs/06.data_dump_MindSpore.md +243 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
- msprobe/docs/17.grad_probe.md +207 -0
- msprobe/docs/FAQ_PyTorch.md +177 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/grad_probe_image-1.png +0 -0
- msprobe/docs/img/grad_probe_image-2.png +0 -0
- msprobe/docs/img/grad_probe_image-3.png +0 -0
- msprobe/docs/img/grad_probe_image-4.png +0 -0
- msprobe/docs/img/grad_probe_image.png +0 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
- msprobe/mindspore/api_accuracy_checker/main.py +9 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +106 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +81 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +219 -0
- msprobe/mindspore/compare/ms_graph_compare.py +348 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +66 -51
- msprobe/mindspore/debugger/precision_debugger.py +126 -32
- msprobe/mindspore/dump/dump_tool_factory.py +35 -38
- msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
- msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
- msprobe/mindspore/dump/jit_dump.py +72 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +90 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +94 -0
- msprobe/mindspore/grad_probe/utils.py +30 -0
- msprobe/mindspore/ms_config.py +128 -78
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +378 -0
- msprobe/mindspore/task_handler_factory.py +24 -21
- msprobe/msprobe.py +105 -67
- msprobe/pytorch/__init__.py +4 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
- msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/__init__.py +2 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +20 -31
- msprobe/pytorch/common/parse_json.py +39 -37
- msprobe/pytorch/common/utils.py +305 -224
- msprobe/pytorch/compare/distributed_compare.py +66 -111
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +34 -36
- msprobe/pytorch/compare/pt_compare.py +50 -0
- msprobe/pytorch/debugger/debugger_config.py +95 -86
- msprobe/pytorch/debugger/precision_debugger.py +125 -95
- msprobe/pytorch/free_benchmark/__init__.py +8 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -67
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +37 -37
- msprobe/pytorch/free_benchmark/common/params.py +129 -129
- msprobe/pytorch/free_benchmark/common/utils.py +102 -98
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
- msprobe/pytorch/free_benchmark/main.py +105 -102
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
- msprobe/pytorch/function_factory.py +76 -0
- msprobe/pytorch/functional/dump_module.py +39 -39
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -161
- msprobe/pytorch/hook_module/hook_module.py +120 -109
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
- msprobe/pytorch/hook_module/utils.py +30 -29
- msprobe/pytorch/hook_module/wrap_aten.py +110 -100
- msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
- msprobe/pytorch/hook_module/wrap_functional.py +105 -108
- msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
- msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
- msprobe/pytorch/hook_module/wrap_torch.py +86 -88
- msprobe/pytorch/hook_module/wrap_vf.py +62 -64
- msprobe/pytorch/module_processer.py +138 -98
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +236 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -273
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
- msprobe/pytorch/online_dispatch/utils.py +130 -187
- msprobe/pytorch/parse.py +4 -4
- msprobe/pytorch/parse_tool/cli.py +32 -32
- msprobe/pytorch/parse_tool/lib/compare.py +260 -259
- msprobe/pytorch/parse_tool/lib/config.py +52 -51
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
- msprobe/pytorch/parse_tool/lib/utils.py +316 -367
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
- msprobe/pytorch/pt_config.py +188 -93
- msprobe/pytorch/service.py +246 -167
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/config/README.md +0 -397
- msprobe/mindspore/doc/dump.md +0 -65
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
- msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/dump.md +0 -207
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/core/common/utils.py
CHANGED
|
@@ -1,516 +1,385 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
|
-
# -*- coding: utf-8 -*-
|
|
3
|
-
"""
|
|
4
|
-
# Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
-
# you may not use this file except in compliance with the License.
|
|
7
|
-
# You may obtain a copy of the License at
|
|
8
|
-
#
|
|
9
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
-
#
|
|
11
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
-
# See the License for the specific language governing permissions and
|
|
15
|
-
# limitations under the License.
|
|
16
|
-
"""
|
|
17
|
-
import collections
|
|
18
|
-
import os
|
|
19
|
-
import re
|
|
20
|
-
import
|
|
21
|
-
import
|
|
22
|
-
import
|
|
23
|
-
import
|
|
24
|
-
|
|
25
|
-
from
|
|
26
|
-
from
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
if
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
if
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
raise
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
if
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
def
|
|
124
|
-
if not
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
def
|
|
175
|
-
""
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
if not
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
raise
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
def
|
|
261
|
-
""
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
""
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
logger.info("Rank id is not provided. Trying to get the rank id of the model.")
|
|
387
|
-
try:
|
|
388
|
-
local_device = next(model.parameters()).device
|
|
389
|
-
except StopIteration:
|
|
390
|
-
logger.warning('There is no parameter in the model. Fail to get rank id.')
|
|
391
|
-
return 0, False
|
|
392
|
-
if local_device.type == 'cpu':
|
|
393
|
-
logger.warning("Warning: the debugger is unable to get the rank id. "
|
|
394
|
-
"This may cause the dumpped data to be corrupted in the "
|
|
395
|
-
"case of distributed training. (You may ignore this if you are using only one card.) "
|
|
396
|
-
"Transfer the model to npu or gpu before register_hook() to avoid this warning.")
|
|
397
|
-
return 0, False
|
|
398
|
-
else:
|
|
399
|
-
return local_device.index, True
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
def generate_compare_script(dump_path, pkl_file_path, dump_switch_mode):
|
|
403
|
-
template_path = os.path.join(os.path.dirname(__file__), "compare_script.template")
|
|
404
|
-
pkl_dir = os.path.dirname(pkl_file_path)
|
|
405
|
-
compare_script_path = os.path.join(pkl_dir, "compare_data.py")
|
|
406
|
-
is_api_stack = "True" if dump_switch_mode == Const.API_STACK else "False"
|
|
407
|
-
|
|
408
|
-
try:
|
|
409
|
-
with FileOpen(template_path, 'r') as ftemp, \
|
|
410
|
-
os.fdopen(os.open(compare_script_path, Const.WRITE_FLAGS, Const.WRITE_MODES), 'w+') as fout:
|
|
411
|
-
code_temp = ftemp.read()
|
|
412
|
-
fout.write(code_temp % (pkl_file_path, dump_path, is_api_stack))
|
|
413
|
-
except OSError:
|
|
414
|
-
logger.error(f"Failed to open file. Please check file {template_path} or path {pkl_dir}.")
|
|
415
|
-
|
|
416
|
-
logger.info(f"Generate compare script successfully which is {compare_script_path}.")
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
def check_file_valid(file_path):
|
|
420
|
-
if os.path.islink(file_path):
|
|
421
|
-
logger.error('The file path {} is a soft link.'.format(file_path))
|
|
422
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
423
|
-
|
|
424
|
-
if len(os.path.realpath(file_path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(file_path)) > \
|
|
425
|
-
Const.FILE_NAME_LENGTH:
|
|
426
|
-
logger.error('The file path length exceeds limit.')
|
|
427
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
428
|
-
|
|
429
|
-
if not re.match(Const.FILE_PATTERN, os.path.realpath(file_path)):
|
|
430
|
-
logger.error('The file path {} contains special characters.'.format(file_path))
|
|
431
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
432
|
-
|
|
433
|
-
if os.path.isfile(file_path):
|
|
434
|
-
file_size = os.path.getsize(file_path)
|
|
435
|
-
if file_path.endswith(Const.PKL_SUFFIX) and file_size > Const.ONE_GB:
|
|
436
|
-
logger.error('The file {} size is greater than 1GB.'.format(file_path))
|
|
437
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
438
|
-
if file_path.endswith(Const.NUMPY_SUFFIX) and file_size > Const.TEN_GB:
|
|
439
|
-
logger.error('The file {} size is greater than 10GB.'.format(file_path))
|
|
440
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
def check_path_before_create(path):
|
|
444
|
-
if len(os.path.realpath(path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(path)) > \
|
|
445
|
-
Const.FILE_NAME_LENGTH:
|
|
446
|
-
logger.error('The file path length exceeds limit.')
|
|
447
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
448
|
-
|
|
449
|
-
if not re.match(Const.FILE_PATTERN, os.path.realpath(path)):
|
|
450
|
-
logger.error('The file path {} contains special characters.'.format(path))
|
|
451
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
def check_inplace_op(prefix):
|
|
455
|
-
if len(prefix) > Const.DISTRIBUTED_PREFIX_LENGTH:
|
|
456
|
-
return False
|
|
457
|
-
match_op = re.findall(r"Distributed\.(.+?)\.\d", prefix)
|
|
458
|
-
op_name = match_op[0] if match_op else None
|
|
459
|
-
return op_name in Const.INPLACE_LIST
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
def md5_find(data):
|
|
463
|
-
for key_op in data:
|
|
464
|
-
for api_info in data[key_op]:
|
|
465
|
-
if isinstance(data[key_op][api_info], list):
|
|
466
|
-
for data_detail in data[key_op][api_info]:
|
|
467
|
-
if data_detail and 'md5' in data_detail:
|
|
468
|
-
return True
|
|
469
|
-
elif 'md5' in data[key_op][api_info]:
|
|
470
|
-
return True
|
|
471
|
-
return False
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
def task_dumppath_get(input_param):
|
|
475
|
-
npu_json_path = input_param.get("npu_json_path", None)
|
|
476
|
-
bench_json_path = input_param.get("bench_json_path", None)
|
|
477
|
-
if not npu_json_path or not bench_json_path:
|
|
478
|
-
logger.error(f"Please check the json path is valid.")
|
|
479
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
480
|
-
with FileOpen(npu_json_path, 'r') as npu_f:
|
|
481
|
-
npu_json_data = json.load(npu_f)
|
|
482
|
-
with FileOpen(bench_json_path, 'r') as bench_f:
|
|
483
|
-
bench_json_data = json.load(bench_f)
|
|
484
|
-
if npu_json_data['task'] != bench_json_data['task']:
|
|
485
|
-
logger.error(f"Please check the dump task is consistent.")
|
|
486
|
-
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
487
|
-
if npu_json_data['task'] == Const.TENSOR:
|
|
488
|
-
summary_compare = False
|
|
489
|
-
md5_compare = False
|
|
490
|
-
elif npu_json_data['task'] == Const.STATISTICS:
|
|
491
|
-
md5_compare = md5_find(npu_json_data['data'])
|
|
492
|
-
if md5_compare:
|
|
493
|
-
summary_compare = False
|
|
494
|
-
else:
|
|
495
|
-
summary_compare = True
|
|
496
|
-
else:
|
|
497
|
-
logger.error(f"Compare is not required for overflow_check or free_benchmark.")
|
|
498
|
-
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
499
|
-
input_param['npu_dump_data_dir'] = npu_json_data['dump_data_dir']
|
|
500
|
-
input_param['bench_dump_data_dir'] = bench_json_data['dump_data_dir']
|
|
501
|
-
return summary_compare, md5_compare
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
def get_header_index(header_name, summary_compare=False):
|
|
505
|
-
if summary_compare:
|
|
506
|
-
header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
|
|
507
|
-
else:
|
|
508
|
-
header = CompareConst.COMPARE_RESULT_HEADER[:]
|
|
509
|
-
if header_name not in header:
|
|
510
|
-
logger.error(f"{header_name} not in data name")
|
|
511
|
-
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
512
|
-
return header.index(header_name)
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
def convert_tuple(data):
    """Return ``data`` unchanged if it is a tuple, otherwise wrap it in a 1-tuple."""
    if isinstance(data, tuple):
        return data
    return (data,)
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
import collections
|
|
18
|
+
import os
|
|
19
|
+
import re
|
|
20
|
+
import subprocess
|
|
21
|
+
import time
|
|
22
|
+
import json
|
|
23
|
+
from datetime import datetime, timezone
|
|
24
|
+
|
|
25
|
+
from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path)
|
|
26
|
+
from msprobe.core.common.const import Const, CompareConst
|
|
27
|
+
from msprobe.core.common.log import logger
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Lightweight record with ``type`` and ``index`` fields.
device = collections.namedtuple('device', 'type index')
# Dump-mode name prefixes.
prefixes = [
    'api_stack',
    'list',
    'range',
    'acl',
]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class CompareException(Exception):
    """
    Class for Accuracy Compare Exception

    The class attributes are numeric error codes. An instance keeps the code in
    ``self.code`` and an optional message in ``self.error_info``; ``str()`` of the
    exception returns that message.
    """
    NONE_ERROR = 0
    INVALID_PATH_ERROR = 1
    OPEN_FILE_ERROR = 2
    CLOSE_FILE_ERROR = 3
    READ_FILE_ERROR = 4
    WRITE_FILE_ERROR = 5
    INVALID_FILE_ERROR = 6
    PERMISSION_ERROR = 7
    INDEX_OUT_OF_BOUNDS_ERROR = 8
    NO_DUMP_FILE_ERROR = 9
    INVALID_DATA_ERROR = 10
    INVALID_PARAM_ERROR = 11
    INVALID_DUMP_RATIO = 12
    INVALID_DUMP_FILE = 13
    UNKNOWN_ERROR = 14
    INVALID_DUMP_MODE = 15
    PARSE_FILE_ERROR = 16
    INVALID_COMPARE_MODE = 17
    OVER_SIZE_FILE_ERROR = 18
    INVALID_SUMMARY_MODE = 19
    INVALID_TASK_ERROR = 20
    DETACH_ERROR = 21

    def __init__(self, code, error_info: str = ""):
        # No args are forwarded to Exception, so ``self.args`` stays empty; the message
        # lives only in ``error_info``.
        super().__init__()
        self.code = code
        self.error_info = error_info

    def __str__(self):
        return self.error_info
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class DumpException(CompareException):
    """Subclass of CompareException that adds no behavior of its own.

    It exists so callers can catch it separately by type; presumably it is meant for
    dump-related failures (confirm at the raise sites).
    """
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def check_mode_valid(mode, scope=None, api_list=None):
    """Validate a dump mode together with the ``scope`` / ``api_list`` it requires.

    Args:
        mode: dump mode; must be a member of ``Const.DUMP_MODE``.
        scope: list whose required length depends on ``mode``. ``None`` means ``[]``.
        api_list: list that must be non-empty in ``Const.API_LIST`` mode. ``None`` means ``[]``.

    Raises:
        ValueError: if ``scope`` or ``api_list`` is not a list, or does not satisfy
            the constraint of ``mode``.
        CompareException: ``INVALID_DUMP_MODE`` if ``mode`` is not in ``Const.DUMP_MODE``.
    """
    if scope is None:
        scope = []
    if api_list is None:
        api_list = []
    if not isinstance(scope, list):
        raise ValueError("scope param set invalid, it's must be a list.")
    if not isinstance(api_list, list):
        raise ValueError("api_list param set invalid, it's must be a list.")
    # Each checker returns the ValueError to raise, or None when the arguments fit the mode.
    mode_check = {
        Const.ALL: lambda: None,
        Const.RANGE: lambda: ValueError("set_dump_switch, scope param set invalid, it's must be [start, end].") if len(scope) != 2 else None,
        Const.LIST: lambda: ValueError("set_dump_switch, scope param set invalid, it's should not be an empty list.") if len(scope) == 0 else None,
        Const.STACK: lambda: ValueError("set_dump_switch, scope param set invalid, it's must be [start, end] or [].") if len(scope) > 2 else None,
        Const.ACL: lambda: ValueError("set_dump_switch, scope param set invalid, only one api name is supported in acl mode.") if len(scope) != 1 else None,
        Const.API_LIST: lambda: ValueError("Current dump mode is 'api_list', but the content of api_list parameter is empty or invalid.") if len(api_list) < 1 else None,
        Const.API_STACK: lambda: None,
    }
    if mode not in Const.DUMP_MODE:
        msg = "Current mode '%s' is not supported. Please use the field in %s" % \
              (mode, Const.DUMP_MODE)
        raise CompareException(CompareException.INVALID_DUMP_MODE, msg)

    checker = mode_check.get(mode)
    # A mode accepted by Const.DUMP_MODE but without a dedicated checker needs no extra
    # validation. Previously this would call None() and raise a TypeError.
    error = checker() if checker is not None else None
    if error is not None:
        raise error
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def check_switch_valid(switch):
    """Accept only the literal switch values "ON" and "OFF".

    Raises:
        CompareException: ``INVALID_PARAM_ERROR`` for any other value.
    """
    if switch in ("ON", "OFF"):
        return
    logger.error("Please set switch with 'ON' or 'OFF'.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def check_dump_mode_valid(dump_mode):
    """Normalize a dump_mode specification and return it as a new list.

    A non-list value is wrapped in a list, with a warning. If neither 'input' nor
    'output' is given, both are added. If neither 'forward' nor 'backward' is given,
    both are added. When 'all' is present, or all four directions are present, the
    canonical ``["forward", "backward", "input", "output"]`` is returned.

    The caller's list is never modified.

    Raises:
        ValueError: if any entry is not one of 'all', 'forward', 'backward', 'input', 'output'.
    """
    if not isinstance(dump_mode, list):
        logger.warning("Please set dump_mode as a list.")
        dump_mode = [dump_mode]
    else:
        # Copy so the extend() calls below do not mutate the caller's list.
        dump_mode = list(dump_mode)
    if not all(mode in ["all", "forward", "backward", "input", "output"] for mode in dump_mode):
        raise ValueError("Please set dump_mode as a list containing one or more of the following: 'all', 'forward', 'backward', 'input', 'output'.")
    if 'input' not in dump_mode and 'output' not in dump_mode:
        dump_mode.extend(['input', 'output'])
    if 'forward' not in dump_mode and 'backward' not in dump_mode:
        dump_mode.extend(['forward', 'backward'])
    if 'all' in dump_mode or set(["forward", "backward", "input", "output"]).issubset(set(dump_mode)):
        return ["forward", "backward", "input", "output"]
    return dump_mode
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def check_summary_mode_valid(summary_mode):
    """Ensure ``summary_mode`` is one of ``Const.SUMMARY_MODE``.

    Raises:
        CompareException: ``INVALID_SUMMARY_MODE`` otherwise.
    """
    if summary_mode in Const.SUMMARY_MODE:
        return
    msg = "The summary_mode is not valid"
    raise CompareException(CompareException.INVALID_SUMMARY_MODE, msg)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def check_summary_only_valid(summary_only):
    """Return ``summary_only`` unchanged if it is a bool.

    Raises:
        CompareException: ``INVALID_PARAM_ERROR`` for any non-bool value.
    """
    if isinstance(summary_only, bool):
        return summary_only
    logger.error("Params summary_only only support True or False.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def check_compare_param(input_param, output_path, summary_compare=False, md5_compare=False):
    """Validate the inputs of a compare run.

    Checks, in order:
    - the argument types;
    - the three json files;
    - the two dump-data directories (only for a full tensor compare);
    - the output directory.
    Finally each json file is opened and passed to ``check_json_file``.

    Args:
        input_param: dict with keys 'npu_json_path', 'bench_json_path', 'stack_json_path',
            and (for a tensor compare) 'npu_dump_data_dir' and 'bench_dump_data_dir'.
        output_path: output directory path.
        summary_compare: when true, the dump-data directories are not checked.
        md5_compare: when true, the dump-data directories are not checked.

    Raises:
        CompareException: ``INVALID_PARAM_ERROR`` when ``input_param`` is not a dict or
            ``output_path`` is not a str.
    """
    types_ok = isinstance(input_param, dict) and isinstance(output_path, str)
    if not types_ok:
        logger.error("Invalid input parameters")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)

    json_keys = ("npu_json_path", "bench_json_path", "stack_json_path")
    for json_key in json_keys:
        check_file_or_directory_path(input_param.get(json_key), False)
    # Tensor data directories are only needed when comparing real tensors.
    if not (summary_compare or md5_compare):
        for dir_key in ("npu_dump_data_dir", "bench_dump_data_dir"):
            check_file_or_directory_path(input_param.get(dir_key), True)
    check_file_or_directory_path(output_path, True)

    npu_path, bench_path, stack_path = (input_param.get(json_key) for json_key in json_keys)
    with FileOpen(npu_path, "r") as npu_json:
        with FileOpen(bench_path, "r") as bench_json:
            with FileOpen(stack_path, "r") as stack_json:
                check_json_file(input_param, npu_json, bench_json, stack_json)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False):
    """Require every compare configuration flag to be a bool.

    Raises:
        CompareException: ``INVALID_PARAM_ERROR`` if any flag is not a bool.
    """
    flags = (stack_mode, auto_analyze, fuzzy_match)
    if all(isinstance(flag, bool) for flag in flags):
        return
    logger.error("Invalid input parameters which should be only bool type.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def is_starts_with(string, prefix_list):
    """Return True as soon as ``string`` starts with one of the prefixes in ``prefix_list``."""
    for candidate in prefix_list:
        if string.startswith(candidate):
            return True
    return False
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _check_json(json_file_handle, file_name):
|
|
167
|
+
tensor_line = json_file_handle.readline()
|
|
168
|
+
if not tensor_line:
|
|
169
|
+
logger.error("dump file {} have empty line!".format(file_name))
|
|
170
|
+
raise CompareException(CompareException.INVALID_DUMP_FILE)
|
|
171
|
+
json_file_handle.seek(0, 0)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def check_json_file(input_param, npu_json, bench_json, stack_json):
    """Run ``_check_json`` on the npu, bench and stack file handles, in that order.

    The corresponding path from ``input_param`` is passed along for error reporting.
    """
    handles_and_keys = (
        (npu_json, "npu_json_path"),
        (bench_json, "bench_json_path"),
        (stack_json, "stack_json_path"),
    )
    for handle, path_key in handles_and_keys:
        _check_json(handle, input_param.get(path_key))
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def check_regex_prefix_format_valid(prefix):
    """
    Validate the format of a regex prefix.

    Args:
        prefix (str): The prefix string to validate.

    Returns:
        None.

    Raises:
        ValueError: if ``prefix`` is longer than ``Const.REGEX_PREFIX_MAX_LENGTH``, or if
            ``re.match`` with ``Const.REGEX_PREFIX_PATTERN`` fails on it.
    """
    max_length = Const.REGEX_PREFIX_MAX_LENGTH
    if len(prefix) > max_length:
        raise ValueError(f"Maximum length of prefix is {max_length}, while current length "
                         f"is {len(prefix)}")
    if re.match(Const.REGEX_PREFIX_PATTERN, prefix) is None:
        raise ValueError(f"prefix contains invalid characters, prefix pattern {Const.REGEX_PREFIX_PATTERN}")
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def get_dump_data_path(dump_dir):
    """
    Locate the first directory under ``dump_dir`` that contains files.

    ``dump_dir`` is validated first, then walked top-down with ``os.walk``.

    Args:
        dump_dir: dump data directory.

    Returns:
        A tuple ``(path, found)``.
        - If a directory with files exists, ``path`` is the first one in walk order
          and ``found`` is True.
        - Otherwise ``path`` is the last directory visited (None if nothing was
          visited) and ``found`` is False.
    """
    check_file_or_directory_path(dump_dir, True)
    last_visited = None
    for dir_path, _, files in os.walk(dump_dir):
        last_visited = dir_path
        if files:
            return dir_path, True
    return last_visited, False
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def execute_command(cmd):
    """
    Run ``cmd`` without a shell, echo its combined stdout/stderr, and check its exit status.

    Parameter:
        cmd: argument sequence passed to ``subprocess.Popen`` with ``shell=False``,
            e.g. ``["ls", "-l"]``.

    Exception Description:
        CompareException(INVALID_DATA_ERROR) when the command exits with a non-zero status.
    """
    # The 1-tuple keeps a tuple-typed cmd from being unpacked as several %-arguments.
    logger.info('Execute command:%s' % (cmd,))
    # The context manager closes the pipe and reaps the child even if printing raises.
    with subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as process:
        # Read until EOF instead of looping while poll() is None. The old loop dropped
        # any output still buffered in the pipe when the child exited.
        for raw_line in process.stdout:
            line = raw_line.strip()
            if line:
                print(line.decode(errors="replace"))
        process.wait()
    if process.returncode != 0:
        logger.error('Failed to execute command:%s' % " ".join(map(str, cmd)))
        raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def parse_value_by_comma(value):
    """
    Parse a comma-separated list of integers, like '1,2,4,8'.

    ``value`` is split on ``Const.COMMA`` and each item is stripped of whitespace.
    Only plain non-negative decimal integers and the literal '-1' are accepted.

    Raises:
        CompareException: ``INVALID_PARAM_ERROR`` if any item is not accepted.
    """
    value_list = []
    for value_str in value.split(Const.COMMA):
        value_str = value_str.strip()
        # Use isdecimal(), not isdigit(): isdigit() also accepts characters such as '²'
        # that int() cannot parse, which would escape as an uncaught ValueError.
        if value_str.isdecimal() or value_str == '-1':
            value_list.append(int(value_str))
        else:
            logger.error("please check your input shape.")
            raise CompareException(CompareException.INVALID_PARAM_ERROR)
    return value_list
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def add_time_as_suffix(name):
    """Return ``'<name>_<YYYYmmddHHMMSS>.csv'`` using the current local time."""
    timestamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
    return f"{name}_{timestamp}.csv"
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def add_time_with_xlsx(name):
    """Return ``'<name>_<YYYYmmddHHMMSS>.xlsx'`` using the current local time."""
    timestamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
    return f"{name}_{timestamp}.xlsx"
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def get_time():
    """Return the current UTC time formatted as ``YYYYmmdd_HHMMSS``."""
    now_utc = datetime.now(timezone.utc)
    return now_utc.strftime("%Y%m%d_%H%M%S")
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def format_value(value):
    """Round ``value`` to 12 decimal places via fixed-point text formatting and return a float."""
    fixed_text = f"{value:.12f}"
    return float(fixed_text)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def check_seed_all(seed, mode):
    """Validate the arguments of seed_all.

    ``seed`` must be an int in ``[0, Const.MAX_SEED_VALUE]`` and ``mode`` must be a bool.

    Raises:
        CompareException: ``INVALID_PARAM_ERROR`` on the first failed check.
    """
    if not isinstance(seed, int):
        logger.error("Seed must be integer.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    if not 0 <= seed <= Const.MAX_SEED_VALUE:
        logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    if not isinstance(mode, bool):
        logger.error("seed_all mode must be bool.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def get_process_rank(model):
    """Infer a rank id from the device of the model's first parameter.

    Returns:
        ``(rank, found)``.
        - ``(local_device.index, True)`` when the first parameter is on a non-CPU device.
        - ``(0, False)`` when the model has no parameters or they live on CPU.
    """
    logger.info("Rank id is not provided. Trying to get the rank id of the model.")
    try:
        local_device = next(model.parameters()).device
    except StopIteration:
        logger.warning('There is no parameter in the model. Fail to get rank id.')
        return 0, False
    if local_device.type != 'cpu':
        # NOTE(review): index may be None for a device created without an index — confirm.
        return local_device.index, True
    logger.warning("Warning: the debugger is unable to get the rank id. "
                   "This may cause the dumpped data to be corrupted in the "
                   "case of distributed training. (You may ignore this if you are using only one card.) "
                   "Transfer the model to npu or gpu before register_hook() to avoid this warning.")
    return 0, False
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def generate_compare_script(dump_path, pkl_file_path, dump_switch_mode):
    """Write ``compare_data.py`` next to ``pkl_file_path`` from the bundled template.

    The template ``compare_script.template`` lives beside this module. It is %-formatted
    with ``(pkl_file_path, dump_path, is_api_stack)``, where ``is_api_stack`` is the string
    "True" or "False" depending on whether ``dump_switch_mode == Const.API_STACK``.

    An OSError while reading the template or writing the script is logged, and the
    function returns without reporting success.
    """
    template_path = os.path.join(os.path.dirname(__file__), "compare_script.template")
    pkl_dir = os.path.dirname(pkl_file_path)
    compare_script_path = os.path.join(pkl_dir, "compare_data.py")
    is_api_stack = "True" if dump_switch_mode == Const.API_STACK else "False"

    try:
        with FileOpen(template_path, 'r') as ftemp, \
                os.fdopen(os.open(compare_script_path, Const.WRITE_FLAGS, Const.WRITE_MODES), 'w+') as fout:
            code_temp = ftemp.read()
            fout.write(code_temp % (pkl_file_path, dump_path, is_api_stack))
    except OSError:
        logger.error(f"Failed to open file. Please check file {template_path} or path {pkl_dir}.")
        # Previously execution fell through and also logged a success message.
        return

    logger.info(f"Generate compare script successfully which is {compare_script_path}.")
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def check_inplace_op(prefix):
    """Return True if ``prefix`` names a Distributed op listed in ``Const.INPLACE_LIST``.

    Prefixes longer than ``Const.DISTRIBUTED_PREFIX_LENGTH`` are rejected outright.
    The op name is the text between ``Distributed.`` and the next ``.<digit>``.
    """
    if len(prefix) > Const.DISTRIBUTED_PREFIX_LENGTH:
        return False
    found = re.search(r"Distributed\.(.+?)\.\d", prefix)
    op_name = found.group(1) if found is not None else None
    return op_name in Const.INPLACE_LIST
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def md5_find(data):
    """Return True if any API record in ``data`` carries an 'md5' entry.

    ``data`` maps an op name to a dict of fields. A field is checked as follows:
    - a list field: each dict element is checked for an 'md5' key;
    - a dict field: checked directly for an 'md5' key;
    - any other value (None, bool, number, str): skipped.
    Skipping those values fixes two problems: ``'md5' in None`` used to raise
    TypeError, and a str field was tested as a substring match.
    """
    for api_data in data.values():
        if not isinstance(api_data, dict):
            continue
        for field_value in api_data.values():
            if isinstance(field_value, list):
                if any(isinstance(detail, dict) and 'md5' in detail for detail in field_value):
                    return True
            elif isinstance(field_value, dict) and 'md5' in field_value:
                return True
    return False
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def task_dumppath_get(input_param):
    """Derive the compare flavour from the two dump.json files and fill in the data dirs.

    Both json files must have the same 'task':
    - ``Const.TENSOR`` selects a full tensor compare.
    - ``Const.STATISTICS`` selects an md5 compare if md5 entries are present,
      otherwise a summary compare.
    - Any other task is rejected.

    On success, 'npu_dump_data_dir' and 'bench_dump_data_dir' are set in ``input_param``
    to ``Const.DUMP_TENSOR_DATA`` next to each json file.

    Returns:
        ``(summary_compare, md5_compare)``.

    Raises:
        CompareException:
            - ``INVALID_PATH_ERROR`` if a json path is missing.
            - ``INVALID_TASK_ERROR`` if the tasks differ or are unsupported.
    """
    npu_path = input_param.get("npu_json_path", None)
    bench_path = input_param.get("bench_json_path", None)
    if not (npu_path and bench_path):
        logger.error("Please check the json path is valid.")
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    with FileOpen(npu_path, 'r') as npu_f:
        npu_json_data = json.load(npu_f)
    with FileOpen(bench_path, 'r') as bench_f:
        bench_json_data = json.load(bench_f)

    task = npu_json_data['task']
    if task != bench_json_data['task']:
        logger.error("Please check the dump task is consistent.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)

    if task == Const.TENSOR:
        summary_compare, md5_compare = False, False
    elif task == Const.STATISTICS:
        md5_compare = md5_find(npu_json_data['data'])
        summary_compare = not md5_compare
    else:
        logger.error("Compare is not required for overflow_check or free_benchmark.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)

    input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
    input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
    return summary_compare, md5_compare
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def get_header_index(header_name, summary_compare=False):
    """Return the column position of ``header_name`` in the compare-result header.

    Args:
        header_name: column title to look up.
        summary_compare: when true, search ``CompareConst.SUMMARY_COMPARE_RESULT_HEADER``;
            otherwise search ``CompareConst.COMPARE_RESULT_HEADER``.

    Raises:
        CompareException: ``INVALID_PARAM_ERROR`` if the column is not present.
    """
    if summary_compare:
        columns = CompareConst.SUMMARY_COMPARE_RESULT_HEADER
    else:
        columns = CompareConst.COMPARE_RESULT_HEADER
    try:
        return columns.index(header_name)
    except ValueError:
        logger.error(f"{header_name} not in data name")
        raise CompareException(CompareException.INVALID_PARAM_ERROR) from None
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def convert_tuple(data):
    """Return ``data`` unchanged if it is a tuple, otherwise wrap it in a 1-tuple."""
    if isinstance(data, tuple):
        return data
    return (data,)
|