mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
- mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +101 -182
- msprobe/__init__.py +1 -0
- msprobe/{config/config.json → config.json} +49 -27
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +124 -124
- msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
- msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +341 -241
- msprobe/core/common/exceptions.py +100 -88
- msprobe/core/common/{file_check.py → file_utils.py} +478 -265
- msprobe/core/common/log.py +76 -55
- msprobe/core/common/utils.py +385 -516
- msprobe/core/common_config.py +85 -58
- msprobe/core/compare/acc_compare.py +300 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +223 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
- msprobe/core/compare/utils.py +430 -0
- msprobe/core/data_dump/data_collector.py +154 -140
- msprobe/core/data_dump/data_processor/base.py +314 -245
- msprobe/core/data_dump/data_processor/factory.py +59 -61
- msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
- msprobe/core/data_dump/json_writer.py +96 -116
- msprobe/core/data_dump/scope.py +178 -178
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +171 -0
- msprobe/core/grad_probe/utils.py +64 -0
- msprobe/docs/01.installation.md +89 -0
- msprobe/docs/02.config_introduction.md +165 -0
- msprobe/docs/03.config_examples.md +247 -0
- msprobe/docs/04.acl_config_examples.md +76 -0
- msprobe/docs/05.data_dump_PyTorch.md +198 -0
- msprobe/docs/06.data_dump_MindSpore.md +243 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
- msprobe/docs/17.grad_probe.md +207 -0
- msprobe/docs/FAQ_PyTorch.md +177 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/grad_probe_image-1.png +0 -0
- msprobe/docs/img/grad_probe_image-2.png +0 -0
- msprobe/docs/img/grad_probe_image-3.png +0 -0
- msprobe/docs/img/grad_probe_image-4.png +0 -0
- msprobe/docs/img/grad_probe_image.png +0 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
- msprobe/mindspore/api_accuracy_checker/main.py +9 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +106 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +81 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +219 -0
- msprobe/mindspore/compare/ms_graph_compare.py +348 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +66 -51
- msprobe/mindspore/debugger/precision_debugger.py +126 -32
- msprobe/mindspore/dump/dump_tool_factory.py +35 -38
- msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
- msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
- msprobe/mindspore/dump/jit_dump.py +72 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +90 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +94 -0
- msprobe/mindspore/grad_probe/utils.py +30 -0
- msprobe/mindspore/ms_config.py +128 -78
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +378 -0
- msprobe/mindspore/task_handler_factory.py +24 -21
- msprobe/msprobe.py +105 -67
- msprobe/pytorch/__init__.py +4 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
- msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/__init__.py +2 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +20 -31
- msprobe/pytorch/common/parse_json.py +39 -37
- msprobe/pytorch/common/utils.py +305 -224
- msprobe/pytorch/compare/distributed_compare.py +66 -111
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +34 -36
- msprobe/pytorch/compare/pt_compare.py +50 -0
- msprobe/pytorch/debugger/debugger_config.py +95 -86
- msprobe/pytorch/debugger/precision_debugger.py +125 -95
- msprobe/pytorch/free_benchmark/__init__.py +8 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -67
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +37 -37
- msprobe/pytorch/free_benchmark/common/params.py +129 -129
- msprobe/pytorch/free_benchmark/common/utils.py +102 -98
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
- msprobe/pytorch/free_benchmark/main.py +105 -102
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
- msprobe/pytorch/function_factory.py +76 -0
- msprobe/pytorch/functional/dump_module.py +39 -39
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -161
- msprobe/pytorch/hook_module/hook_module.py +120 -109
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
- msprobe/pytorch/hook_module/utils.py +30 -29
- msprobe/pytorch/hook_module/wrap_aten.py +110 -100
- msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
- msprobe/pytorch/hook_module/wrap_functional.py +105 -108
- msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
- msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
- msprobe/pytorch/hook_module/wrap_torch.py +86 -88
- msprobe/pytorch/hook_module/wrap_vf.py +62 -64
- msprobe/pytorch/module_processer.py +138 -98
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +236 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -273
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
- msprobe/pytorch/online_dispatch/utils.py +130 -187
- msprobe/pytorch/parse.py +4 -4
- msprobe/pytorch/parse_tool/cli.py +32 -32
- msprobe/pytorch/parse_tool/lib/compare.py +260 -259
- msprobe/pytorch/parse_tool/lib/config.py +52 -51
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
- msprobe/pytorch/parse_tool/lib/utils.py +316 -367
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
- msprobe/pytorch/pt_config.py +188 -93
- msprobe/pytorch/service.py +246 -167
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/config/README.md +0 -397
- msprobe/mindspore/doc/dump.md +0 -65
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
- msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +0 -182
- msprobe/pytorch/doc/dump.md +0 -207
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc/在线精度比对.md +0 -90
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -1,39 +1,39 @@
|
|
|
1
|
-
from typing import Any
|
|
2
|
-
|
|
3
|
-
from msprobe.pytorch.free_benchmark import logger
|
|
4
|
-
from msprobe.pytorch.free_benchmark.common.enums import DeviceType
|
|
5
|
-
from msprobe.pytorch.free_benchmark.common.params import DataParams, make_unequal_row
|
|
6
|
-
from msprobe.pytorch.free_benchmark.common.utils import Tools
|
|
7
|
-
from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
|
|
8
|
-
from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class CheckerHandler(FuzzHandler):
    """Result handler that checks perturbed outputs without altering them.

    For non-NPU fuzz devices, inconsistent results are appended to
    ``self.unequal_rows``; NPU comparison is delegated to the inherited
    ``cmp_output_npu``. ``handle`` always returns the original result.
    """

    def other_compare(self, data_params: DataParams) -> bool:
        """Compare the original and perturbed results with ``SingleCompare``.

        When the two results differ, a row built by ``make_unequal_row`` is
        appended to ``self.unequal_rows``.

        Returns:
            True if the results are consistent, False otherwise.
        """
        is_consistent = SingleCompare().compare_seq(
            data_params.original_result, data_params.perturbed_result
        )
        if not is_consistent:
            self.unequal_rows.append(
                make_unequal_row(data_params, self.params)
            )
        # Report the outcome, as the ``-> bool`` annotation promises.
        return is_consistent

    def get_threshold(self, dtype):
        """Return the default comparison threshold for ``dtype``."""
        return self._get_default_threshold(dtype)

    def handle(self, data_params: DataParams) -> Any:
        """Compare the perturbed result, then return the original result.

        Perturbed results that are bools, or that ``Tools.is_float_tensor``
        rejects, are not compared.
        """
        if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
            data_params.perturbed_result
        ):
            return data_params.original_result
        try:
            if self.params.fuzz_device == DeviceType.NPU:
                self.cmp_output_npu(data_params)
            else:
                self.other_compare(data_params)
        except Exception as e:
            # Comparison errors are logged, not raised, so the caller still
            # receives the original result.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"an exception was raised when comparing the result: {e}"
            )
        return data_params.original_result
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from msprobe.pytorch.free_benchmark import logger
|
|
4
|
+
from msprobe.pytorch.free_benchmark.common.enums import DeviceType
|
|
5
|
+
from msprobe.pytorch.free_benchmark.common.params import DataParams, make_unequal_row
|
|
6
|
+
from msprobe.pytorch.free_benchmark.common.utils import Tools
|
|
7
|
+
from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
|
|
8
|
+
from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CheckerHandler(FuzzHandler):
    """Result handler that checks perturbed outputs without altering them.

    For non-NPU fuzz devices, inconsistent results are appended to
    ``self.unequal_rows``; NPU comparison is delegated to the inherited
    ``cmp_output_npu``. ``handle`` always returns the original result.
    """

    def other_compare(self, data_params: DataParams) -> bool:
        """Compare the original and perturbed results with ``SingleCompare``.

        When the two results differ, a row built by ``make_unequal_row`` is
        appended to ``self.unequal_rows``.

        Returns:
            True if the results are consistent, False otherwise.
        """
        is_consistent = SingleCompare().compare_seq(
            data_params.original_result, data_params.perturbed_result
        )
        if not is_consistent:
            self.unequal_rows.append(
                make_unequal_row(data_params, self.params)
            )
        # Report the outcome, as the ``-> bool`` annotation promises.
        return is_consistent

    def get_threshold(self, dtype):
        """Return the default comparison threshold for ``dtype``."""
        return self._get_default_threshold(dtype)

    def handle(self, data_params: DataParams) -> Any:
        """Compare the perturbed result, then return the original result.

        Perturbed results that are bools, or that ``Tools.is_float_tensor``
        rejects, are not compared.
        """
        if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
            data_params.perturbed_result
        ):
            return data_params.original_result
        try:
            if self.params.fuzz_device == DeviceType.NPU:
                self.cmp_output_npu(data_params)
            else:
                self.other_compare(data_params)
        except Exception as e:
            # Comparison errors are logged, not raised, so the caller still
            # receives the original result.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"an exception was raised when comparing the result: {e}"
            )
        return data_params.original_result
|
|
@@ -1,24 +1,24 @@
|
|
|
1
|
-
from typing import Any
|
|
2
|
-
|
|
3
|
-
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
4
|
-
from msprobe.pytorch.free_benchmark.common.utils import Tools
|
|
5
|
-
from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
|
|
6
|
-
from msprobe.pytorch.free_benchmark import logger
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class FixHandler(FuzzHandler):
    """Result handler that substitutes the perturbed output for the original.

    The perturbed output is converted back to the original's form with
    ``Tools.convert_fuzz_output_to_origin``.
    """

    def get_threshold(self, dtype):
        """Return the default comparison threshold for ``dtype``."""
        return self._get_default_threshold(dtype)

    def handle(self, data_params: DataParams) -> Any:
        """Return the perturbed result converted to the original's form.

        If the conversion raises, the error (including its message) is logged
        on rank 0. Execution then continues past the ``except`` clause, to the
        method's closing ``return data_params.original_result``.
        """
        try:
            return Tools.convert_fuzz_output_to_origin(
                data_params.original_result, data_params.perturbed_result
            )
        except Exception as e:
            # Include the exception so the cause of the failed fix is visible.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name} "
                f"Fix output failed: {e}"
            )
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
4
|
+
from msprobe.pytorch.free_benchmark.common.utils import Tools
|
|
5
|
+
from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
|
|
6
|
+
from msprobe.pytorch.free_benchmark import logger
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class FixHandler(FuzzHandler):
    """Result handler that substitutes the perturbed output for the original.

    The perturbed output is converted back to the original's form with
    ``Tools.convert_fuzz_output_to_origin``.
    """

    def get_threshold(self, dtype):
        """Return the default comparison threshold for ``dtype``."""
        return self._get_default_threshold(dtype)

    def handle(self, data_params: DataParams) -> Any:
        """Return the perturbed result converted to the original's form.

        If the conversion raises, the error (including its message) is logged
        on rank 0. Execution then continues past the ``except`` clause, to the
        method's closing ``return data_params.original_result``.
        """
        try:
            return Tools.convert_fuzz_output_to_origin(
                data_params.original_result, data_params.perturbed_result
            )
        except Exception as e:
            # Include the exception so the cause of the failed fix is visible.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name} "
                f"Fix output failed: {e}"
            )
|
|
24
24
|
return data_params.original_result
|
|
@@ -1,31 +1,30 @@
|
|
|
1
|
-
from msprobe.pytorch.free_benchmark import FreeBenchmarkException
|
|
2
|
-
from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig
|
|
3
|
-
from msprobe.pytorch.free_benchmark.common.enums import HandlerType
|
|
4
|
-
from msprobe.pytorch.free_benchmark.common.params import HandlerParams
|
|
5
|
-
from msprobe.pytorch.free_benchmark.result_handlers.check_handler import CheckerHandler
|
|
6
|
-
from msprobe.pytorch.free_benchmark.result_handlers.preheat_handler import PreheatHandler
|
|
7
|
-
from msprobe.pytorch.free_benchmark.result_handlers.fix_handler import FixHandler
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class FuzzHandlerFactory:
    """Map a ``HandlerType`` to the result-handler class implementing it."""

    result_handlers = {
        HandlerType.CHECK: CheckerHandler,
        HandlerType.FIX: FixHandler,
        HandlerType.PREHEAT: PreheatHandler,
    }

    @staticmethod
    def create(params: HandlerParams):
        """Instantiate the handler selected by ``params``.

        If ``params.preheat_config`` has a truthy ``PreheatConfig.IF_PREHEAT``
        entry, the preheat handler is used regardless of
        ``params.handler_type``.

        Raises:
            FreeBenchmarkException: if no handler class is registered for the
                selected type.
        """
        if_preheat = params.preheat_config.get(PreheatConfig.IF_PREHEAT)
        if not if_preheat:
            handler = FuzzHandlerFactory.result_handlers.get(params.handler_type)
        else:
            handler = FuzzHandlerFactory.result_handlers.get(HandlerType.PREHEAT)
        # Without this guard, an unknown type would surface as an opaque
        # "'NoneType' object is not callable" TypeError.
        if not handler:
            raise FreeBenchmarkException(
                FreeBenchmarkException.UnsupportedType,
                f"无标杆工具支持 [ {HandlerType.CHECK}、{HandlerType.FIX}] 形式",
            )
        return handler(params)
|
|
1
|
+
from msprobe.pytorch.free_benchmark import FreeBenchmarkException
|
|
2
|
+
from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig
|
|
3
|
+
from msprobe.pytorch.free_benchmark.common.enums import HandlerType
|
|
4
|
+
from msprobe.pytorch.free_benchmark.common.params import HandlerParams
|
|
5
|
+
from msprobe.pytorch.free_benchmark.result_handlers.check_handler import CheckerHandler
|
|
6
|
+
from msprobe.pytorch.free_benchmark.result_handlers.preheat_handler import PreheatHandler
|
|
7
|
+
from msprobe.pytorch.free_benchmark.result_handlers.fix_handler import FixHandler
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class FuzzHandlerFactory:
    """Factory that resolves ``HandlerParams`` to a concrete fuzz result handler."""

    result_handlers = {
        HandlerType.CHECK: CheckerHandler,
        HandlerType.FIX: FixHandler,
        HandlerType.PREHEAT: PreheatHandler,
    }

    @staticmethod
    def create(params: HandlerParams):
        """Build the handler for ``params``.

        A truthy ``PreheatConfig.IF_PREHEAT`` entry in ``params.preheat_config``
        forces the preheat handler; otherwise ``params.handler_type`` selects
        the handler.

        Raises:
            FreeBenchmarkException: when the selected type has no registered
                handler class.
        """
        preheat_enabled = params.preheat_config.get(PreheatConfig.IF_PREHEAT)
        handler_key = HandlerType.PREHEAT if preheat_enabled else params.handler_type
        handler_cls = FuzzHandlerFactory.result_handlers.get(handler_key)
        if handler_cls:
            return handler_cls(params)
        raise FreeBenchmarkException(
            FreeBenchmarkException.UnsupportedType,
            f"无标杆工具支持 [ {HandlerType.CHECK}、{HandlerType.FIX}] 形式",
        )
|
|
@@ -1,170 +1,170 @@
|
|
|
1
|
-
import math
|
|
2
|
-
from typing import Any
|
|
3
|
-
|
|
4
|
-
from msprobe.pytorch.free_benchmark import logger
|
|
5
|
-
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
6
|
-
from msprobe.pytorch.free_benchmark.common.counter import preheat_counter
|
|
7
|
-
from msprobe.pytorch.free_benchmark.common.enums import DeviceType
|
|
8
|
-
from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams
|
|
9
|
-
from msprobe.pytorch.free_benchmark.common.utils import Tools
|
|
10
|
-
from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
|
|
11
|
-
from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class PreheatHandler(FuzzHandler):
|
|
15
|
-
|
|
16
|
-
def __init__(self, params: HandlerParams) -> None:
    """Initialise the handler and cache the API name key.

    The key is ``Tools.get_pure_api_name`` applied to the API name, and
    ``preheat_counter`` is indexed by it.
    """
    super().__init__(params)
    full_api_name = self.params.api_name
    self.pure_name = Tools.get_pure_api_name(full_api_name)
|
|
19
|
-
|
|
20
|
-
def get_threshold(self, dtype):
    """Return the threshold ``preheat_counter`` holds for this API and ``dtype``."""
    api_key = self.pure_name
    return preheat_counter.get_api_thd(api_key, dtype)
|
|
22
|
-
|
|
23
|
-
def compare_npu_and_cpu(self, data_params: DataParams):
    """Re-run the original function on CPU and compare with the original result.

    Positional and keyword arguments are both passed through
    ``Tools.convert_device_and_dtype(..., DeviceType.CPU, change_dtype=True)``
    before ``data_params.origin_func`` is invoked.

    Returns:
        The value of ``SingleCompare().compare_seq`` for the original result
        versus the CPU result.
    """
    cpu_args, cpu_kwargs = (
        Tools.convert_device_and_dtype(payload, DeviceType.CPU, change_dtype=True)
        for payload in (data_params.args, data_params.kwargs)
    )
    cpu_result = data_params.origin_func(*cpu_args, **cpu_kwargs)
    comparator = SingleCompare()
    return comparator.compare_seq(data_params.original_result, cpu_result)
|
|
32
|
-
|
|
33
|
-
def preheat(self, max_fuzz_ratio, cpu_consistent, first_dtype):
    """Record one preheat sample and re-tune the threshold when enough exist.

    The pair ``(max_fuzz_ratio, cpu_consistent)`` is stored under
    ``(self.pure_name, first_dtype)`` in ``preheat_counter``.
    ``_adjust_threshold`` runs only when ``_need_adjust_threshold`` returns
    True.
    """
    # Keep every output ratio of the current step together with its
    # NPU-vs-CPU comparison outcome.
    sample_record = (max_fuzz_ratio, cpu_consistent)
    preheat_counter.update_preheat_record(self.pure_name, first_dtype, sample_record)
    if not self._need_adjust_threshold():
        return
    self._adjust_threshold()
|
|
42
|
-
|
|
43
|
-
def handle(self, data_params: DataParams) -> Any:
    """Run one preheat iteration for this API call; return the original result.

    Flow:
      * Bool or non-float-tensor perturbed results are ignored.
      * At step 0 the call is only counted (``add_one_step_used_api``).
      * Otherwise, NPU outputs are compared and ``data_params.is_consistent``
        is updated. Processing stops once the configured ``preheat_step`` is
        <= ``self.params.step``.
      * If ``data_params.grad_unequal_flag`` is falsy, the flag is set,
        ``is_consistent`` is forced to False, and the method returns.
      * Otherwise the call is counted. If it is selected for sampling, it is
        compared against a CPU run, and the sample is passed to ``preheat``
        when ``preheat_counter.get_api_preheat`` allows it.
      * If the original output contains no tensor (``get_first_tensor_dtype``
        raises RuntimeError), a warning is logged and no preheat record is
        made.
    """
    if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
        data_params.perturbed_result
    ):
        return data_params.original_result

    if self.params.step == 0:
        preheat_counter.add_one_step_used_api(self.pure_name)
        return data_params.original_result

    # The current API at this step still takes part in preheating.
    npu_consistent, max_fuzz_ratio = self.cmp_output_npu(data_params)
    data_params.is_consistent = npu_consistent

    preheat_counter.check_step(self.params.step)

    if self.params.preheat_config.get("preheat_step") <= self.params.step:
        return data_params.original_result

    if not data_params.grad_unequal_flag:
        data_params.grad_unequal_flag = True
        data_params.is_consistent = False
        return data_params.original_result
    preheat_counter.add_api_called_time(self.pure_name)

    if not self._is_take_a_sample():
        return data_params.original_result

    # Treat the CPU comparison as consistent if the comparison itself fails.
    cpu_consistent = True
    try:
        cpu_consistent = self.compare_npu_and_cpu(data_params)
    except Exception as e:
        logger.warning_on_rank_0(
            f"[msprobe] Free Benchmark: For {self.params.api_name}, "
            f"an exception was raised when comparing to cpu: {e}"
        )
    try:
        first_dtype = Tools.get_first_tensor_dtype(data_params.original_result)
    except RuntimeError:
        logger.warning_on_rank_0(
            f"[msprobe] Free Benchmark: For {self.params.api_name}, "
            f"the output sequence does not contain tensors."
        )
        # Without a dtype there is no key for the preheat record. Returning
        # here avoids reaching the lookup below with ``first_dtype`` unbound
        # (UnboundLocalError).
        return data_params.original_result
    if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)):
        self.preheat(max_fuzz_ratio, cpu_consistent, first_dtype)

    return data_params.original_result
|
|
91
|
-
|
|
92
|
-
def _is_take_a_sample(self) -> bool:
    """Return True if the current call of this API is selected for sampling.

    The current call index (from ``preheat_counter``) is looked up in the set
    built by ``_get_need_sample_set``. A selected call is logged and increments
    this API's sample counter.
    """
    need_sample_set = self._get_need_sample_set()
    curr_called_seq = preheat_counter.get_api_called_time(self.pure_name)
    res = curr_called_seq in need_sample_set
    if res:
        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
        # The pieces are implicitly concatenated; keep a separator between them.
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: preheat sample in step{self.params.step}, "
            f"api_name {self.params.api_name}, "
            f"curr_called_seq: {curr_called_seq}/{total_count}"
        )
        preheat_counter.add_api_sample_time(self.pure_name)
    return res
|
|
105
|
-
|
|
106
|
-
def _get_sample_count_per_step(self) -> int:
    """Return how many calls of this API should be sampled in each step.

    This is ``ceil(calls_per_step / preheat_step)``, capped at
    ``preheat_config["max_sample"]``. ``calls_per_step`` is the count recorded
    during step 0.
    """
    total_count = preheat_counter.get_one_step_used_api(self.pure_name)
    preheat_step = self.params.preheat_config.get("preheat_step")
    max_sample = self.params.preheat_config.get("max_sample")
    return min(math.ceil(total_count / preheat_step), max_sample)
|
|
114
|
-
|
|
115
|
-
def _get_need_sample_set(self):
    """Build the set of 1-based call indices of this API to sample in the current step.

    The k-th sample (k starting at 0) is call ``(preheat_step * k + step) % total``.
    A result of 0 wraps to ``total``, so every index falls in ``1..total``.
    """
    total = preheat_counter.get_one_step_used_api(self.pure_name)
    per_step = self._get_sample_count_per_step()
    stride = self.params.preheat_config.get("preheat_step")
    return {
        (stride * offset + self.params.step) % total or total
        for offset in range(per_step)
    }
|
|
130
|
-
|
|
131
|
-
def _need_adjust_threshold(self) -> bool:
    """Tell whether enough samples have been taken to recalibrate this API's threshold."""
    target = self._get_sample_count_per_step()
    return preheat_counter.get_api_sample_time(self.pure_name) >= target
|
|
136
|
-
|
|
137
|
-
def _adjust_threshold_for_dtype(self, dtype_str, compare_result):
    """Compute a new ratio threshold for one dtype of this API from preheat samples.

    Args:
        dtype_str: dtype key used by ``preheat_counter``.
        compare_result: iterable of ``(max_fuzz_ratio, cpu_consistent)`` pairs.

    Returns:
        The new threshold. When no samples are recorded, the current threshold
        is returned unchanged.

    Side effects:
        Calls ``preheat_counter.set_api_preheat(..., is_preheat=False)`` to stop
        preheating this dtype when the samples settle the threshold.
    """
    con_ratio = [ratio for ratio, is_consistent in compare_result if is_consistent]
    incon_ratio = [
        ratio for ratio, is_consistent in compare_result if not is_consistent
    ]
    old_thd = preheat_counter.get_api_thd(self.pure_name, dtype_str)
    new_thd = old_thd
    if not con_ratio and not incon_ratio:
        # Nothing recorded: keep the current threshold (min() of [] would raise).
        return new_thd
    if con_ratio and incon_ratio:
        # Both consistent and inconsistent samples exist: if they are separable,
        # settle on the smallest inconsistent ratio and stop preheating.
        if min(incon_ratio) > max(con_ratio):
            new_thd = min(min(incon_ratio), old_thd)
            preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
    elif con_ratio:
        # Only consistent samples: widen the threshold if one of them exceeds it,
        # otherwise tighten it (assumes API_THD_STEP > 1 — confirm in ThresholdConfig).
        if max(con_ratio) > old_thd:
            new_thd = 1 + ((old_thd - 1) * ThresholdConfig.API_THD_STEP)
        else:
            new_thd = 1 + ((old_thd - 1) / ThresholdConfig.API_THD_STEP)
    else:
        # Only inconsistent samples: settle on the smallest one and stop preheating.
        new_thd = min(min(incon_ratio), old_thd)
        preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
    return new_thd
|
|
159
|
-
|
|
160
|
-
def _adjust_threshold(self):
    """Recalibrate and store the threshold of every dtype recorded for this API."""
    records = preheat_counter.preheat_record[self.pure_name]
    for dtype_str, samples in records.items():
        adjusted = self._adjust_threshold_for_dtype(dtype_str, samples)
        dtype = preheat_counter.dtype_map.get(dtype_str)
        default_thd = self._get_default_threshold(dtype)
        preheat_counter.update_api_thd(self.pure_name, dtype_str, adjusted, default_thd)
|
|
1
|
+
import math
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from msprobe.pytorch.free_benchmark import logger
|
|
5
|
+
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
6
|
+
from msprobe.pytorch.free_benchmark.common.counter import preheat_counter
|
|
7
|
+
from msprobe.pytorch.free_benchmark.common.enums import DeviceType
|
|
8
|
+
from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams
|
|
9
|
+
from msprobe.pytorch.free_benchmark.common.utils import Tools
|
|
10
|
+
from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
|
|
11
|
+
from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class PreheatHandler(FuzzHandler):
    """Fuzz handler that calibrates per-API, per-dtype ratio thresholds ("preheat").

    - Step 0 only counts how many times each API is called.
    - In later steps below ``preheat_config["preheat_step"]``, a subset of calls
      is also re-run on CPU. The NPU perturbation ratio and the NPU-vs-CPU
      consistency verdict are recorded in ``preheat_counter``.
    - Once the number of samples reaches the per-step sample count, the stored
      thresholds are adjusted.

    ``handle`` always returns the unperturbed result.
    """

    def __init__(self, params: HandlerParams) -> None:
        super().__init__(params)
        self.pure_name = Tools.get_pure_api_name(self.params.api_name)

    def get_threshold(self, dtype):
        """Return the current preheat threshold for this API and ``dtype``."""
        return preheat_counter.get_api_thd(self.pure_name, dtype)

    def compare_npu_and_cpu(self, data_params: DataParams):
        """Re-run the original function on CPU and compare it with the NPU output.

        Args and kwargs are converted with ``Tools.convert_device_and_dtype``
        (``change_dtype=True``). The result of ``SingleCompare().compare_seq``
        is returned.
        """
        args = Tools.convert_device_and_dtype(
            data_params.args, DeviceType.CPU, change_dtype=True
        )
        kwargs = Tools.convert_device_and_dtype(
            data_params.kwargs, DeviceType.CPU, change_dtype=True
        )
        cpu_result = data_params.origin_func(*args, **kwargs)
        return SingleCompare().compare_seq(data_params.original_result, cpu_result)

    def preheat(self, max_fuzz_ratio, cpu_consistent, first_dtype):
        """Record one ``(ratio, cpu_consistent)`` sample and adjust thresholds when due."""
        preheat_counter.update_preheat_record(
            self.pure_name,
            first_dtype,
            (max_fuzz_ratio, cpu_consistent),
        )
        if self._need_adjust_threshold():
            self._adjust_threshold()

    def handle(self, data_params: DataParams) -> Any:
        """Feed one API call into the preheat process and return the unperturbed output.

        Only ``data_params.is_consistent`` / ``data_params.grad_unequal_flag``
        and the shared ``preheat_counter`` state are updated.
        """
        # Only floating-point tensor outputs can be compared by perturbation ratio.
        if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
            data_params.perturbed_result
        ):
            return data_params.original_result

        # Step 0 only counts how many times this API is used within one step.
        if self.params.step == 0:
            preheat_counter.add_one_step_used_api(self.pure_name)
            return data_params.original_result

        npu_consistent, max_fuzz_ratio = self.cmp_output_npu(data_params)
        data_params.is_consistent = npu_consistent

        preheat_counter.check_step(self.params.step)

        # Preheating stops once the configured number of preheat steps is reached.
        if self.params.preheat_config.get("preheat_step") <= self.params.step:
            return data_params.original_result

        # The first pass for this DataParams only sets grad_unequal_flag and marks
        # the result inconsistent; sampling happens on a later pass.
        # NOTE(review): intent of this two-pass protocol is defined by the caller — confirm.
        if not data_params.grad_unequal_flag:
            data_params.grad_unequal_flag = True
            data_params.is_consistent = False
            return data_params.original_result
        preheat_counter.add_api_called_time(self.pure_name)

        if not self._is_take_a_sample():
            return data_params.original_result

        # Re-run the API on CPU to decide whether the NPU output is trustworthy.
        cpu_consistent = True
        try:
            cpu_consistent = self.compare_npu_and_cpu(data_params)
        except Exception as e:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"when campare to cpu exception raise {e}"
            )
        try:
            first_dtype = Tools.get_first_tensor_dtype(data_params.original_result)
        except RuntimeError:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"the output sequence does not contain tensors."
            )
            # Without a tensor dtype there is no threshold bucket to record into;
            # falling through would use an unbound ``first_dtype``.
            return data_params.original_result
        if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)):
            self.preheat(max_fuzz_ratio, cpu_consistent, first_dtype)

        return data_params.original_result

    def _is_take_a_sample(self) -> bool:
        """Return True if the current call of this API is selected for sampling.

        A selected call is logged and increments this API's sample counter.
        """
        need_sample_set = self._get_need_sample_set()
        curr_called_seq = preheat_counter.get_api_called_time(self.pure_name)
        res = curr_called_seq in need_sample_set
        if res:
            total_count = preheat_counter.get_one_step_used_api(self.pure_name)
            # The pieces are implicitly concatenated; keep a separator between them.
            logger.info_on_rank_0(
                f"[msprobe] Free benchmark: preheat sample in step{self.params.step}, "
                f"api_name {self.params.api_name}, "
                f"curr_called_seq: {curr_called_seq}/{total_count}"
            )
            preheat_counter.add_api_sample_time(self.pure_name)
        return res

    def _get_sample_count_per_step(self) -> int:
        """Return how many calls of this API should be sampled in each step.

        This is ``ceil(calls_per_step / preheat_step)``, capped at
        ``preheat_config["max_sample"]``.
        """
        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
        preheat_step = self.params.preheat_config.get("preheat_step")
        max_sample = self.params.preheat_config.get("max_sample")
        return min(math.ceil(total_count / preheat_step), max_sample)

    def _get_need_sample_set(self):
        """Return the set of 1-based call indices of this API to sample in the current step.

        The i-th sample is call ``(preheat_step * (i - 1) + step) % total``.
        A result of 0 wraps to ``total``, so indices fall in ``1..total``.
        """
        total_count = preheat_counter.get_one_step_used_api(self.pure_name)
        sample_count_per_step = self._get_sample_count_per_step()
        need_sample_set = set()
        preheat_step = self.params.preheat_config.get("preheat_step")
        for i in range(1, sample_count_per_step + 1):
            count = (preheat_step * (i - 1) + self.params.step) % total_count
            if count == 0:
                count = total_count
            need_sample_set.add(count)
        return need_sample_set

    def _need_adjust_threshold(self) -> bool:
        """Return True once this API has been sampled at least the per-step sample count."""
        sample_count_per_step = self._get_sample_count_per_step()
        sampled_time = preheat_counter.get_api_sample_time(self.pure_name)
        return sampled_time >= sample_count_per_step

    def _adjust_threshold_for_dtype(self, dtype_str, compare_result):
        """Compute a new ratio threshold for one dtype from ``(ratio, cpu_consistent)`` samples.

        When no samples are recorded, the current threshold is returned
        unchanged. ``preheat_counter.set_api_preheat(..., is_preheat=False)``
        is called when the samples settle the threshold.
        """
        con_ratio = [ratio for ratio, is_consistent in compare_result if is_consistent]
        incon_ratio = [
            ratio for ratio, is_consistent in compare_result if not is_consistent
        ]
        old_thd = preheat_counter.get_api_thd(self.pure_name, dtype_str)
        new_thd = old_thd
        if not con_ratio and not incon_ratio:
            # Nothing recorded: keep the current threshold (min() of [] would raise).
            return new_thd
        if con_ratio and incon_ratio:
            # Both consistent and inconsistent samples exist: if they are separable,
            # settle on the smallest inconsistent ratio and stop preheating.
            if min(incon_ratio) > max(con_ratio):
                new_thd = min(min(incon_ratio), old_thd)
                preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
        elif con_ratio:
            # Only consistent samples: widen the threshold if one of them exceeds it,
            # otherwise tighten it (assumes API_THD_STEP > 1 — confirm in ThresholdConfig).
            if max(con_ratio) > old_thd:
                new_thd = 1 + ((old_thd - 1) * ThresholdConfig.API_THD_STEP)
            else:
                new_thd = 1 + ((old_thd - 1) / ThresholdConfig.API_THD_STEP)
        else:
            # Only inconsistent samples: settle on the smallest one and stop preheating.
            new_thd = min(min(incon_ratio), old_thd)
            preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
        return new_thd

    def _adjust_threshold(self):
        """Recalibrate and store the threshold of every dtype recorded for this API."""
        for dtype_str, compare_result in preheat_counter.preheat_record[
            self.pure_name
        ].items():
            new_thd = self._adjust_threshold_for_dtype(dtype_str, compare_result)
            threshold = self._get_default_threshold(
                preheat_counter.dtype_map.get(dtype_str)
            )
            preheat_counter.update_api_thd(
                self.pure_name, dtype_str, new_thd, threshold
            )
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
from msprobe.pytorch.common.utils import logger
|
|
2
|
+
from msprobe.pytorch.bench_functions.apply_adam_w import npu_apply_adam_w
|
|
3
|
+
from msprobe.pytorch.bench_functions.confusion_transpose import npu_confusion_transpose, \
|
|
4
|
+
npu_confusion_transpose_backward
|
|
5
|
+
from msprobe.pytorch.bench_functions.fast_gelu import npu_fast_gelu, npu_fast_gelu_backward
|
|
6
|
+
from msprobe.pytorch.bench_functions.layer_norm_eval import npu_layer_norm_eval
|
|
7
|
+
from msprobe.pytorch.bench_functions.linear import npu_linear, npu_linear_backward
|
|
8
|
+
from msprobe.pytorch.bench_functions.matmul_backward import matmul_backward
|
|
9
|
+
from msprobe.pytorch.bench_functions.npu_fusion_attention import npu_fusion_attention, npu_fusion_attention_grad, \
|
|
10
|
+
gpu_fusion_attention
|
|
11
|
+
from msprobe.pytorch.bench_functions.rms_norm import npu_rms_norm, npu_rms_norm_backward
|
|
12
|
+
from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotary_mul_backward
|
|
13
|
+
from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \
|
|
14
|
+
npu_scaled_masked_softmax_backward
|
|
15
|
+
from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward, swish_grad, swish
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Register(dict):
    """Name-keyed registry of callables.

    Calling the instance with an iterable of functions registers each one under
    its ``__name__``. Entries live in the underlying ``dict`` itself, so every
    dict API sees them: ``[]``, ``in``, ``get``, ``len``, ``bool``, iteration,
    ``keys``/``values``/``items``, and constructor arguments. Previously only
    the few overridden methods saw them.
    """

    @property
    def _dict(self):
        # Backward-compatible alias: entries used to be kept in a separate
        # ``_dict`` attribute; it now refers to the registry itself.
        return self

    def __call__(self, target_func_list):
        """Register every callable in ``target_func_list``; returns None."""
        for target in target_func_list:
            self.register(target)

    def register(self, target):
        """Register ``target`` under ``target.__name__`` and return it.

        An existing entry with the same name is overridden, with a warning.

        Raises:
            Exception: if ``target`` is not callable.
        """
        if not callable(target):
            raise Exception(f"The func {target} is not callable.")
        key = target.__name__
        if key in self:
            logger.warning(f"{target.__name__} has been registered before, so we will override it.")
        self[key] = target
        return target
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Forward bench implementations of NPU custom operators, keyed by function name.
npu_custom_functions = Register()
npu_custom_functions((
    npu_apply_adam_w,
    npu_confusion_transpose,
    npu_fast_gelu,
    npu_layer_norm_eval,
    npu_linear,
    npu_fusion_attention,
    npu_rms_norm,
    npu_rotary_mul,
    npu_scaled_masked_softmax,
    npu_swiglu,
    gpu_fusion_attention,
))

# Backward (gradient) bench implementations of NPU custom operators, keyed by function name.
npu_custom_grad_functions = Register()
npu_custom_grad_functions((
    npu_confusion_transpose_backward,
    npu_fast_gelu_backward,
    npu_linear_backward,
    matmul_backward,
    npu_fusion_attention_grad,
    npu_rms_norm_backward,
    npu_rotary_mul_backward,
    npu_scaled_masked_softmax_backward,
    npu_swiglu_backward,
))
|