mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -30
- mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +101 -182
- msprobe/__init__.py +1 -0
- msprobe/{config/config.json → config.json} +49 -27
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +124 -124
- msprobe/{pytorch → core}/advisor/advisor_const.py +59 -59
- msprobe/{pytorch → core}/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +341 -241
- msprobe/core/common/exceptions.py +100 -88
- msprobe/core/common/{file_check.py → file_utils.py} +478 -265
- msprobe/core/common/log.py +76 -55
- msprobe/core/common/utils.py +385 -516
- msprobe/core/common_config.py +85 -58
- msprobe/core/compare/acc_compare.py +300 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +223 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +295 -244
- msprobe/core/compare/utils.py +430 -0
- msprobe/core/data_dump/data_collector.py +154 -140
- msprobe/core/data_dump/data_processor/base.py +314 -245
- msprobe/core/data_dump/data_processor/factory.py +59 -61
- msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -346
- msprobe/core/data_dump/json_writer.py +96 -116
- msprobe/core/data_dump/scope.py +178 -178
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +171 -0
- msprobe/core/grad_probe/utils.py +64 -0
- msprobe/docs/01.installation.md +89 -0
- msprobe/docs/02.config_introduction.md +165 -0
- msprobe/docs/03.config_examples.md +247 -0
- msprobe/docs/04.acl_config_examples.md +76 -0
- msprobe/docs/05.data_dump_PyTorch.md +198 -0
- msprobe/docs/06.data_dump_MindSpore.md +243 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
- msprobe/docs/17.grad_probe.md +207 -0
- msprobe/docs/FAQ_PyTorch.md +177 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/grad_probe_image-1.png +0 -0
- msprobe/docs/img/grad_probe_image-2.png +0 -0
- msprobe/docs/img/grad_probe_image-3.png +0 -0
- msprobe/docs/img/grad_probe_image-4.png +0 -0
- msprobe/docs/img/grad_probe_image.png +0 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +255 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +156 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +239 -0
- msprobe/mindspore/api_accuracy_checker/main.py +9 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +80 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +106 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +81 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +219 -0
- msprobe/mindspore/compare/ms_graph_compare.py +348 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +66 -51
- msprobe/mindspore/debugger/precision_debugger.py +126 -32
- msprobe/mindspore/dump/dump_tool_factory.py +35 -38
- msprobe/mindspore/dump/hook_cell/api_registry.py +118 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -0
- msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
- msprobe/mindspore/dump/jit_dump.py +72 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +90 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +94 -0
- msprobe/mindspore/grad_probe/utils.py +30 -0
- msprobe/mindspore/ms_config.py +128 -78
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -32
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +378 -0
- msprobe/mindspore/task_handler_factory.py +24 -21
- msprobe/msprobe.py +105 -67
- msprobe/pytorch/__init__.py +4 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +53 -50
- msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -224
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -216
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -545
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -345
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -248
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -328
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -203
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -127
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -493
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -7
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/__init__.py +2 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +20 -31
- msprobe/pytorch/common/parse_json.py +39 -37
- msprobe/pytorch/common/utils.py +305 -224
- msprobe/pytorch/compare/distributed_compare.py +66 -111
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +34 -36
- msprobe/pytorch/compare/pt_compare.py +50 -0
- msprobe/pytorch/debugger/debugger_config.py +95 -86
- msprobe/pytorch/debugger/precision_debugger.py +125 -95
- msprobe/pytorch/free_benchmark/__init__.py +8 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -67
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +37 -37
- msprobe/pytorch/free_benchmark/common/params.py +129 -129
- msprobe/pytorch/free_benchmark/common/utils.py +102 -98
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -183
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
- msprobe/pytorch/free_benchmark/main.py +105 -102
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -203
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -31
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
- msprobe/pytorch/function_factory.py +76 -0
- msprobe/pytorch/functional/dump_module.py +39 -39
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -161
- msprobe/pytorch/hook_module/hook_module.py +120 -109
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1876
- msprobe/pytorch/hook_module/utils.py +30 -29
- msprobe/pytorch/hook_module/wrap_aten.py +110 -100
- msprobe/pytorch/hook_module/wrap_distributed.py +78 -75
- msprobe/pytorch/hook_module/wrap_functional.py +105 -108
- msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -73
- msprobe/pytorch/hook_module/wrap_tensor.py +71 -72
- msprobe/pytorch/hook_module/wrap_torch.py +86 -88
- msprobe/pytorch/hook_module/wrap_vf.py +62 -64
- msprobe/pytorch/module_processer.py +138 -98
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +236 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -273
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -186
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
- msprobe/pytorch/online_dispatch/utils.py +130 -187
- msprobe/pytorch/parse.py +4 -4
- msprobe/pytorch/parse_tool/cli.py +32 -32
- msprobe/pytorch/parse_tool/lib/compare.py +260 -259
- msprobe/pytorch/parse_tool/lib/config.py +52 -51
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
- msprobe/pytorch/parse_tool/lib/utils.py +316 -367
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -90
- msprobe/pytorch/pt_config.py +188 -93
- msprobe/pytorch/service.py +246 -167
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/config/README.md +0 -397
- msprobe/mindspore/doc/dump.md +0 -65
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -269
- msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/dump.md +0 -207
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -176
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc/在线精度比对.md +0 -90
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
import mindspore
|
|
4
|
+
import torch
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
8
|
+
from msprobe.mindspore.common.log import logger
|
|
9
|
+
from msprobe.core.common.const import CompareConst, MsCompareConst
|
|
10
|
+
|
|
11
|
+
class CompareResult:
    """Outcome of one compare algorithm applied to a bench/tested pair.

    The three constructor arguments are stored unchanged; no validation is
    performed here.
    """

    def __init__(self, compare_value, pass_status, err_msg):
        # Metric value computed by the algorithm.
        self.compare_value = compare_value
        # Status string chosen by the algorithm for this comparison.
        self.pass_status = pass_status
        # Human-readable explanation associated with pass_status.
        self.err_msg = err_msg

    def __repr__(self):
        # Readable form so results show their content in logs and debuggers.
        return (f"{type(self).__name__}(compare_value={self.compare_value!r}, "
                f"pass_status={self.pass_status!r}, err_msg={self.err_msg!r})")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BaseCompareAlgorithm(ABC):
    """Template for one accuracy metric computed between two compute elements.

    Calling an instance checks validity, computes the metric, classifies it,
    and wraps everything in a ``CompareResult``. Subclasses set
    ``compare_algorithm_name`` and implement ``check_validity``,
    ``run_compare`` and ``check_pass``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = None
        # algorithm name -> pass status -> message attached to the result
        self.err_msg_mapping = {
            CompareConst.COSINE: {
                CompareConst.PASS: "",
                CompareConst.ERROR: f"cosine similarity is less than threshold: {CompareConst.COS_THRESHOLD} ",
                CompareConst.SKIP: "two inputs are not valid for computing cosine similarity, skip comparing ",
            },
            CompareConst.MAX_ABS_ERR: {
                CompareConst.PASS: "",
                CompareConst.ERROR: "max absolute difference is greater than " \
                                    f"threshold: {CompareConst.MAX_ABS_ERR_THRESHOLD} ",
                CompareConst.SKIP: "two inputs are not valid for computing max absolute difference, skip comparing ",
            },
            CompareConst.MAX_RELATIVE_ERR: {
                CompareConst.PASS: "",
                CompareConst.ERROR: "",
                CompareConst.SKIP: "",
            },
        }

    def __call__(self, bench_compute_element, tested_compute_element):
        """Run the comparison and package the outcome.

        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            compare_result: CompareResult
        """
        if self.check_validity(bench_compute_element, tested_compute_element):
            compare_value = self.run_compare(bench_compute_element, tested_compute_element)
            pass_status = self.check_pass(compare_value)
        else:
            logger.warning(f"not suitable for computing {self.compare_algorithm_name}, skip this.")
            compare_value = None
            pass_status = CompareConst.SKIP

        # An algorithm name with no entry in the mapping yields err_msg=None
        # instead of an AttributeError on the missing inner dict.
        err_msg = self.err_msg_mapping.get(self.compare_algorithm_name, {}).get(pass_status)

        compare_result = CompareResult(compare_value, pass_status, err_msg)
        return compare_result

    @staticmethod
    def convert_to_np_float64_ndarray(tensor):
        """Return ``tensor`` as a float64 numpy array.

        Args:
            tensor: torch.Tensor or mindspore.Tensor

        Return:
            ndarray: numpy.ndarray with dtype float64

        Any other input type is reported through ``logger.error_log_with_exp``
        with an ``ApiAccuracyCheckerException`` of kind ``UnsupportType``.
        """
        if isinstance(tensor, torch.Tensor):
            # detach() and cpu() are required before numpy(): torch refuses the
            # conversion for tensors that require grad or live off the CPU.
            # copy=True keeps the returned array independent of the input.
            ndarray = tensor.detach().cpu().to(torch.float64, copy=True).numpy()
        elif isinstance(tensor, mindspore.Tensor):
            ndarray = tensor.astype(mindspore.float64).numpy()
        else:
            err_msg = "BaseCompareAlgorithm.convert_to_np_float64_ndarray failed: " \
                      "input is not mindspore.Tensor or torch.Tensor"
            logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
        return ndarray

    @staticmethod
    def check_two_tensor(bench_compute_element, tested_compute_element):
        """Return True when both elements hold tensors with equal shapes.

        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            check_res: boolean
        """
        bench_parameter = bench_compute_element.get_parameter()
        tested_parameter = tested_compute_element.get_parameter()

        bench_is_tensor = isinstance(bench_parameter, (mindspore.Tensor, torch.Tensor))
        tested_is_tensor = isinstance(tested_parameter, (mindspore.Tensor, torch.Tensor))
        shape_same = bench_compute_element.get_shape() == tested_compute_element.get_shape()
        return bench_is_tensor and tested_is_tensor and shape_same

    @abstractmethod
    def check_validity(self, bench_compute_element, tested_compute_element):
        """Decide whether the two elements can be compared by this algorithm.

        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            check_res: boolean
        """
        raise NotImplementedError

    @abstractmethod
    def run_compare(self, bench_compute_element, tested_compute_element):
        """Compute the metric value for the two elements.

        Args:
            bench_compute_element: ComputeElement
            tested_compute_element: ComputeElement

        Return:
            compare_value: float/int
        """
        raise NotImplementedError

    @abstractmethod
    def check_pass(self, compare_value):
        """Classify a metric value into a pass status.

        Args:
            compare_value: float/int

        Return:
            pass_status: str
        """
        raise NotImplementedError
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class CosineSimilarityCompareAlgorithm(BaseCompareAlgorithm):
    """Cosine similarity between the bench tensor and the tested tensor."""

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.COSINE

    def check_validity(self, bench_compute_element, tested_compute_element):
        """Comparable only when ``check_two_tensor`` accepts the pair."""
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        """Return (EPSILON + dot) / (EPSILON + |bench| * |tested|) in float64."""
        bench, tested = (
            self.convert_to_np_float64_ndarray(element.get_parameter())
            for element in (bench_compute_element, tested_compute_element)
        )

        norm_product = np.linalg.norm(bench) * np.linalg.norm(tested)
        # EPSILON is added to both numerator and denominator.
        numerator = MsCompareConst.EPSILON + np.dot(bench.flatten(), tested.flatten())
        return numerator / (MsCompareConst.EPSILON + norm_product)

    def check_pass(self, compare_value):
        """PASS when the similarity is strictly above COS_THRESHOLD, else ERROR."""
        above_threshold = compare_value > CompareConst.COS_THRESHOLD
        return CompareConst.PASS if above_threshold else CompareConst.ERROR
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class MaxAbsoluteDiffCompareAlgorithm(BaseCompareAlgorithm):
    """Largest element-wise absolute difference between bench and tested."""

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.MAX_ABS_ERR

    def check_validity(self, bench_compute_element, tested_compute_element):
        """Comparable only when ``check_two_tensor`` accepts the pair."""
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        """Return max(|bench - tested|) computed on float64 arrays."""
        bench = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        tested = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())
        return np.max(np.abs(bench - tested))

    def check_pass(self, compare_value):
        """PASS when the difference is strictly below MAX_ABS_ERR_THRESHOLD, else ERROR."""
        within_limit = compare_value < CompareConst.MAX_ABS_ERR_THRESHOLD
        return CompareConst.PASS if within_limit else CompareConst.ERROR
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class MaxRelativeDiffCompareAlgorithm(BaseCompareAlgorithm):
    """Largest element-wise difference relative to the bench magnitude."""

    def __init__(self) -> None:
        super().__init__()
        self.compare_algorithm_name = CompareConst.MAX_RELATIVE_ERR

    def check_validity(self, bench_compute_element, tested_compute_element):
        """Comparable only when ``check_two_tensor`` accepts the pair."""
        return self.check_two_tensor(bench_compute_element, tested_compute_element)

    def run_compare(self, bench_compute_element, tested_compute_element):
        """Return max(|bench - tested| / |bench|) computed on float64 arrays."""
        bench = self.convert_to_np_float64_ndarray(bench_compute_element.get_parameter())
        tested = self.convert_to_np_float64_ndarray(tested_compute_element.get_parameter())

        # Where bench is exactly zero the divisor becomes EPSILON, avoiding a
        # division by zero; everywhere else it is |bench|.
        zero_guard = (bench == 0) * MsCompareConst.EPSILON
        divisor = np.abs(bench) + zero_guard
        return np.max(np.abs(bench - tested) / divisor)

    def check_pass(self, compare_value):
        """PASS when the value is strictly below MAX_RELATIVE_ERR_THRESHOLD, else ERROR."""
        within_limit = compare_value < CompareConst.MAX_RELATIVE_ERR_THRESHOLD
        return CompareConst.PASS if within_limit else CompareConst.ERROR
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
# Module-level registry: one shared instance of each compare algorithm,
# keyed by the same CompareConst name the instance sets as its
# compare_algorithm_name.
compare_algorithms = {
    CompareConst.COSINE: CosineSimilarityCompareAlgorithm(),
    CompareConst.MAX_ABS_ERR: MaxAbsoluteDiffCompareAlgorithm(),
    CompareConst.MAX_RELATIVE_ERR: MaxRelativeDiffCompareAlgorithm(),
}
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
def add_api_accuracy_checker_argument(parser):
    """Register the API accuracy checker options on ``parser``.

    Adds the required ``-api_info/--api_info_file`` option and the optional
    ``-o/--out_path`` option, which defaults to ``"./"``.
    """
    option_specs = (
        (
            ("-api_info", "--api_info_file"),
            {
                "dest": "api_info_file",
                "type": str,
                "required": True,
                "help": "<Required> The api param tool result file: generate from api param tool, "
                        "a json file.",
            },
        ),
        (
            ("-o", "--out_path"),
            {
                "dest": "out_path",
                "default": "./",
                "type": str,
                "required": False,
                "help": "<optional> The ut task result out path.",
            },
        ),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import mindspore
|
|
4
|
+
import torch
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from msprobe.mindspore.common.log import logger
|
|
8
|
+
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
9
|
+
from msprobe.core.common.file_utils import load_npy
|
|
10
|
+
from msprobe.mindspore.api_accuracy_checker.type_mapping import (dtype_str_to_np_dtype, api_info_type_str_to_type,
|
|
11
|
+
ms_dtype_to_dtype_str, torch_dtype_to_dtype_str,
|
|
12
|
+
dtype_str_to_ms_dtype, dtype_str_to_np_dtype,
|
|
13
|
+
dtype_str_to_torch_dtype, type_to_api_info_type_str,
|
|
14
|
+
DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE, TUPLE_TYPE_STR,
|
|
15
|
+
MINDSPORE_TENSOR_TYPE_STR, float_dtype_str_list,
|
|
16
|
+
int_dtype_str_list)
|
|
17
|
+
from msprobe.core.common.const import Const
|
|
18
|
+
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class MstensorMetaData:
    """Plain record holding the five metadata fields given to the constructor.

    Values are stored unchanged; nothing is validated or converted here.
    """

    def __init__(self, dtype_str, npy_path, maximum, minimum, shape) -> None:
        self.dtype_str, self.npy_path = dtype_str, npy_path
        self.maximum, self.minimum = maximum, minimum
        self.shape = shape
|
|
28
|
+
|
|
29
|
+
class ComputeElement:
|
|
30
|
+
def __init__(self, compute_element_info=None, parameter=None):
    """Initialise from a concrete parameter or from a compute-element description.

    ``parameter`` takes precedence when it is not None. Otherwise a list or
    dict ``compute_element_info`` is parsed, and None selects the null
    initialisation. Anything else is reported through
    ``logger.error_log_with_exp`` with an ``UnsupportType`` exception.
    """
    self.supported_parameter_type = (*type_to_api_info_type_str.keys(), torch.Tensor, tuple)

    if parameter is not None:
        self._init_with_parameter(parameter)
        return
    if isinstance(compute_element_info, (list, dict)):
        self._init_from_compute_element_info(compute_element_info)
        return
    if compute_element_info is None:
        self._init_from_null_compute_element_info()
        return
    logger.error_log_with_exp(
        "ComputeElement.__init__ failed: not init with parameter or compute_element info is not (list, dict)",
        ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
|
|
42
|
+
|
|
43
|
+
@staticmethod
def transfer_to_torch_tensor(ms_tensor):
    """Convert a MindSpore tensor into a torch tensor of the matching dtype.

    Args:
        ms_tensor: mindspore.Tensor
    Return:
        torch_tensor: torch.Tensor

    A dtype with no torch counterpart is reported through
    ``logger.error_log_with_exp`` with an ``UnsupportType`` exception.
    """
    dtype_str = ms_dtype_to_dtype_str.get(ms_tensor.dtype)
    if dtype_str in dtype_str_to_torch_dtype:
        torch_dtype = dtype_str_to_torch_dtype.get(dtype_str)
    else:
        err_msg = f"ComputeElement.transfer_to_torch_tensor failed: no matching torch dtype for {dtype_str}"
        logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))

    # The data passes through numpy in a 64-bit "middle" dtype and is cast to
    # the final torch dtype afterwards.
    middle_dtype = (
        mindspore.float64 if dtype_str in float_dtype_str_list
        else mindspore.int64 if dtype_str in int_dtype_str_list
        else mindspore.uint64
    )
    as_numpy = ms_tensor.astype(middle_dtype).numpy()
    return torch.from_numpy(as_numpy).to(torch_dtype)
|
|
68
|
+
|
|
69
|
+
@staticmethod
|
|
70
|
+
def transfer_to_mindspore_tensor(torch_tensor):
|
|
71
|
+
'''
|
|
72
|
+
Args:
|
|
73
|
+
torch_tensor: torch.Tensor
|
|
74
|
+
|
|
75
|
+
Return:
|
|
76
|
+
ms_tensor: mindspore.Tensor
|
|
77
|
+
'''
|
|
78
|
+
torch_dtype = torch_tensor.dtype
|
|
79
|
+
dtype_str = torch_dtype_to_dtype_str.get(torch_dtype)
|
|
80
|
+
if dtype_str not in dtype_str_to_ms_dtype:
|
|
81
|
+
err_msg = \
|
|
82
|
+
f"ComputeElement._transfer_to_mindspore_tensor failed: no matching mindspore dtype for {dtype_str}"
|
|
83
|
+
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
|
|
84
|
+
else:
|
|
85
|
+
ms_dtype = dtype_str_to_ms_dtype.get(dtype_str)
|
|
86
|
+
|
|
87
|
+
if dtype_str in float_dtype_str_list:
|
|
88
|
+
middle_dtype = torch.float64
|
|
89
|
+
elif dtype_str in int_dtype_str_list:
|
|
90
|
+
middle_dtype = torch.int64
|
|
91
|
+
np_ndarray = torch_tensor.to(middle_dtype, copy=True).numpy()
|
|
92
|
+
ms_tensor = mindspore.Tensor.from_numpy(np_ndarray).astype(ms_dtype)
|
|
93
|
+
return ms_tensor
|
|
94
|
+
|
|
95
|
+
@staticmethod
|
|
96
|
+
def convert_inf_to_real_num(value, dtype_str):
|
|
97
|
+
if value == float("inf"):
|
|
98
|
+
np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
|
|
99
|
+
value = np.finfo(np_dtype).max
|
|
100
|
+
elif value == float("-inf"):
|
|
101
|
+
np_dtype = dtype_str_to_np_dtype.get(dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
|
|
102
|
+
value = np.finfo(np_dtype).min
|
|
103
|
+
return value
|
|
104
|
+
|
|
105
|
+
def get_parameter(self, get_origin=True, tensor_platform=Const.MS_FRAMEWORK):
|
|
106
|
+
'''
|
|
107
|
+
Args:
|
|
108
|
+
get_origin: boolean
|
|
109
|
+
tensor_platform: str, Union["mindspore", "pytorch"]
|
|
110
|
+
|
|
111
|
+
Return:
|
|
112
|
+
parameter: Union[int, float, str, slice, tuple, torch.Tensor, mindspore.Tensor]
|
|
113
|
+
'''
|
|
114
|
+
if self.parameter is None:
|
|
115
|
+
return self.parameter
|
|
116
|
+
if isinstance(self.parameter, tuple):
|
|
117
|
+
return tuple([compute_element.get_parameter(get_origin=get_origin, tensor_platform=tensor_platform)
|
|
118
|
+
for compute_element in self.parameter])
|
|
119
|
+
elif isinstance(self.parameter, self.supported_parameter_type):
|
|
120
|
+
parameter_tmp = self.parameter
|
|
121
|
+
elif isinstance(self.parameter, MstensorMetaData):
|
|
122
|
+
mstensor_meta_data = self.parameter
|
|
123
|
+
ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
|
|
124
|
+
if global_context.get_is_constructed():
|
|
125
|
+
np_dtype = dtype_str_to_np_dtype.get(mstensor_meta_data.dtype_str, DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE)
|
|
126
|
+
ndarray = self._construct_ndarray(mstensor_meta_data.shape, mstensor_meta_data.maximum,
|
|
127
|
+
mstensor_meta_data.minimum, np_dtype)
|
|
128
|
+
else:
|
|
129
|
+
ndarray = load_npy(mstensor_meta_data.npy_path)
|
|
130
|
+
parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype)
|
|
131
|
+
else:
|
|
132
|
+
err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \
|
|
133
|
+
"(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
|
|
134
|
+
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
|
|
135
|
+
|
|
136
|
+
# if necessary, do transfer
|
|
137
|
+
if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
|
|
138
|
+
parameter = self.transfer_to_torch_tensor(parameter_tmp)
|
|
139
|
+
elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform ==Const.MS_FRAMEWORK:
|
|
140
|
+
parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
|
|
141
|
+
else:
|
|
142
|
+
parameter = parameter_tmp
|
|
143
|
+
|
|
144
|
+
return parameter
|
|
145
|
+
|
|
146
|
+
def get_shape(self):
|
|
147
|
+
return self.shape
|
|
148
|
+
|
|
149
|
+
def get_dtype(self):
|
|
150
|
+
return self.dtype_str
|
|
151
|
+
|
|
152
|
+
def _construct_ndarray(self, shape, maximum, minimum, np_dtype):
|
|
153
|
+
shape = tuple(shape)
|
|
154
|
+
np.random.seed(42)
|
|
155
|
+
if np_dtype == np.bool_:
|
|
156
|
+
ndarray = np.random.rand(*shape) > 0.5
|
|
157
|
+
else:
|
|
158
|
+
maximum = self.convert_inf_to_real_num(maximum, np_dtype)
|
|
159
|
+
minimum = self.convert_inf_to_real_num(minimum, np_dtype)
|
|
160
|
+
ndarray = np.random.uniform(minimum, maximum, shape).astype(np_dtype)
|
|
161
|
+
return ndarray
|
|
162
|
+
|
|
163
|
+
def _init_from_null_compute_element_info(self):
|
|
164
|
+
self.parameter = None
|
|
165
|
+
self.shape = tuple()
|
|
166
|
+
self.dtype = "None"
|
|
167
|
+
|
|
168
|
+
def _init_from_compute_element_info(self, compute_element_info):
|
|
169
|
+
'''
|
|
170
|
+
Args:
|
|
171
|
+
compute_element_info: Union[list, dict]
|
|
172
|
+
|
|
173
|
+
Return:
|
|
174
|
+
void
|
|
175
|
+
|
|
176
|
+
init member attributes: self.shape, self.dtype_str, self.parameter
|
|
177
|
+
'''
|
|
178
|
+
if isinstance(compute_element_info, list):
|
|
179
|
+
self.shape = tuple()
|
|
180
|
+
self.dtype_str = TUPLE_TYPE_STR
|
|
181
|
+
self.parameter = tuple([ComputeElement(compute_element_info=sub_info)
|
|
182
|
+
for sub_info in compute_element_info])
|
|
183
|
+
else:
|
|
184
|
+
type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json",
|
|
185
|
+
accepted_type=str, accepted_value=api_info_type_str_to_type.keys())
|
|
186
|
+
|
|
187
|
+
if type_str == MINDSPORE_TENSOR_TYPE_STR:
|
|
188
|
+
self._init_from_mstensor_compute_element_info(compute_element_info)
|
|
189
|
+
else: # type_str in ("slice", "int", "float", "bool")
|
|
190
|
+
value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json")
|
|
191
|
+
self.shape = tuple()
|
|
192
|
+
self.dtype_str = type_str
|
|
193
|
+
self.parameter = slice(*tuple(value)) if type_str == "slice" else value
|
|
194
|
+
|
|
195
|
+
def _init_from_mstensor_compute_element_info(self, compute_element_info):
|
|
196
|
+
'''
|
|
197
|
+
do not load real tensor, only record meta data
|
|
198
|
+
'''
|
|
199
|
+
dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json",
|
|
200
|
+
accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
|
|
201
|
+
shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json",
|
|
202
|
+
accepted_type=(list,))
|
|
203
|
+
if global_context.get_is_constructed():
|
|
204
|
+
maximum = check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json",
|
|
205
|
+
accepted_type=(int, float))
|
|
206
|
+
minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json",
|
|
207
|
+
accepted_type=(int, float))
|
|
208
|
+
|
|
209
|
+
npy_path = None
|
|
210
|
+
else:
|
|
211
|
+
maximum, minimum = None, None
|
|
212
|
+
data_name = check_and_get_from_json_dict(compute_element_info, "data_name",
|
|
213
|
+
"data_name field in api_info.json", accepted_type=(str,))
|
|
214
|
+
npy_path = os.path.join(global_context.get_dump_data_dir(), data_name)
|
|
215
|
+
mstensor_meta_data = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape)
|
|
216
|
+
self.parameter = mstensor_meta_data
|
|
217
|
+
self.dtype_str = dtype_str
|
|
218
|
+
self.shape = tuple(shape)
|
|
219
|
+
|
|
220
|
+
def _init_with_parameter(self, parameter):
|
|
221
|
+
self.parameter = parameter
|
|
222
|
+
if not isinstance(parameter, self.supported_parameter_type):
|
|
223
|
+
err_msg = "ComputeElement._init_with_parameter failed: " \
|
|
224
|
+
"parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
|
|
225
|
+
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
|
|
226
|
+
if isinstance(parameter, mindspore.Tensor):
|
|
227
|
+
self.shape = tuple(parameter.shape)
|
|
228
|
+
self.dtype_str = ms_dtype_to_dtype_str.get(parameter.dtype)
|
|
229
|
+
elif isinstance(parameter, torch.Tensor):
|
|
230
|
+
self.shape = tuple(parameter.shape)
|
|
231
|
+
self.dtype_str = torch_dtype_to_dtype_str.get(parameter.dtype)
|
|
232
|
+
elif isinstance(parameter, tuple):
|
|
233
|
+
self.shape = tuple()
|
|
234
|
+
self.dtype_str = TUPLE_TYPE_STR
|
|
235
|
+
self.parameter = tuple([ComputeElement(parameter=param) for param in parameter])
|
|
236
|
+
else:
|
|
237
|
+
self.shape = tuple()
|
|
238
|
+
self.dtype_str = \
|
|
239
|
+
TUPLE_TYPE_STR if isinstance(parameter, tuple) else type_to_api_info_type_str.get(type(parameter))
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def api_checker_main(args):
    '''
    Run the API accuracy check end to end.

    Args:
        args: object providing the attributes api_info_file and out_path

    Parses args.api_info_file, runs and compares the APIs, then writes the detail csv and the
    result csv with args.out_path.
    '''
    checker = ApiAccuracyChecker()
    info_file = args.api_info_file
    checker.parse(info_file)
    checker.run_and_compare()
    # both reports receive the same out_path argument
    checker.to_detail_csv(args.out_path)
    checker.to_result_csv(args.out_path)
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from mindspore.common import dtype as mstype
|
|
2
|
+
import numpy as np
|
|
3
|
+
import mindspore
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
# ---- dtype names -------------------------------------------------------------------------------
# signed integers
INT4 = "Int4"
INT8 = "Int8"
INT16 = "Int16"
INT32 = "Int32"
INT64 = "Int64"
# unsigned integers
UINT8 = "UInt8"
UINT16 = "UInt16"
UINT32 = "UInt32"
UINT64 = "UInt64"
# floating point
FLOAT16 = "Float16"
FLOAT32 = "Float32"
FLOAT64 = "Float64"
BFLOAT16 = "BFloat16"
# boolean
BOOL = "Bool"


# ---- dtype name <-> mindspore dtype ------------------------------------------------------------
dtype_str_to_ms_dtype = {
    INT8: mstype.int8,
    UINT8: mstype.uint8,
    INT16: mstype.int16,
    UINT16: mstype.uint16,
    INT32: mstype.int32,
    UINT32: mstype.uint32,
    INT64: mstype.int64,
    UINT64: mstype.uint64,
    FLOAT16: mstype.float16,
    FLOAT32: mstype.float32,
    FLOAT64: mstype.float64,
    BOOL: mstype.bool_,
    BFLOAT16: mstype.bfloat16,
    INT4: mstype.qint4x2,
}
ms_dtype_to_dtype_str = dict(zip(dtype_str_to_ms_dtype.values(), dtype_str_to_ms_dtype.keys()))


# ---- dtype name <-> numpy dtype (no entry for BFloat16 and Int4) --------------------------------
dtype_str_to_np_dtype = {
    INT8: np.int8,
    UINT8: np.uint8,
    INT16: np.int16,
    UINT16: np.uint16,
    INT32: np.int32,
    UINT32: np.uint32,
    INT64: np.int64,
    UINT64: np.uint64,
    FLOAT16: np.float16,
    FLOAT32: np.float32,
    FLOAT64: np.float64,
    BOOL: np.bool_,
}
np_dtype_to_dtype_str = dict(zip(dtype_str_to_np_dtype.values(), dtype_str_to_np_dtype.keys()))

# ---- dtype name <-> torch dtype (UInt8 is the only unsigned entry, no Int4) ---------------------
dtype_str_to_torch_dtype = {
    INT8: torch.int8,
    UINT8: torch.uint8,
    INT16: torch.int16,
    INT32: torch.int32,
    INT64: torch.int64,
    FLOAT16: torch.float16,
    FLOAT32: torch.float32,
    FLOAT64: torch.float64,
    BOOL: torch.bool,
    BFLOAT16: torch.bfloat16,
}
torch_dtype_to_dtype_str = dict(zip(dtype_str_to_torch_dtype.values(), dtype_str_to_torch_dtype.keys()))

# ---- type names of non-dtype values --------------------------------------------------------------
MINDSPORE_TENSOR_TYPE_STR = "mindspore.Tensor"
BOOL_TYPE_STR = "bool"
INT_TYPE_STR = "int"
FLOAT_TYPE_STR = "float"
SLICE_TYPE_STR = "slice"
TUPLE_TYPE_STR = "tuple"
STR_TYPE_STR = "str"

# type name -> python / mindspore type; TUPLE_TYPE_STR intentionally has no entry here
api_info_type_str_to_type = {
    MINDSPORE_TENSOR_TYPE_STR: mindspore.Tensor,
    BOOL_TYPE_STR: bool,
    INT_TYPE_STR: int,
    FLOAT_TYPE_STR: float,
    SLICE_TYPE_STR: slice,
    STR_TYPE_STR: str,
}
type_to_api_info_type_str = dict(zip(api_info_type_str_to_type.values(), api_info_type_str_to_type.keys()))

# numpy dtypes used as fallback; all three are float64
# NOTE(review): float64 for the INT / UINT defaults looks like it may be unintended - confirm
DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE = np.float64
DEFAULT_CONSTRUCT_NP_INT_DTYPE = np.float64
DEFAULT_CONSTRUCT_NP_UINT_DTYPE = np.float64

# ---- dtype name groups ---------------------------------------------------------------------------
float_dtype_str_list = [FLOAT16, FLOAT32, FLOAT64, BFLOAT16]

# Bool and Int4 are grouped with the signed integers
int_dtype_str_list = [INT8, INT16, INT32, INT64, BOOL, INT4]

uint_dtype_str_list = [UINT8, UINT16, UINT32, UINT64]
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
2
|
+
from msprobe.core.common.const import Const
|
|
3
|
+
from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list
|
|
4
|
+
from msprobe.mindspore.common.log import logger
|
|
5
|
+
|
|
6
|
+
def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_type=None, accepted_value=None):
    '''
    Fetch dict_instance[key] and validate it.

    Args:
        dict_instance: dict, dict parsed from input json
        key: str
        key_description: str, human readable name of the field, used in error messages
        accepted_type: tuple, the value must be an instance of it (skipped when None)
        accepted_value: Union[tuple, list], the value must be contained in it (skipped when None)

    Return:
        value, the corresponding value of "key" in "dict_instance"

    Exception:
        an ApiAccuracyCheckerException.ParseJsonFailed error is handed to logger.error_log_with_exp when
        1. dict_instance is not a dict
        2. value is None
        3. value is not accepted type
        4. value is not accepted value
    '''
    failure = ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed)
    if not isinstance(dict_instance, dict):
        logger.error_log_with_exp("check_and_get_from_json_dict failed: input is not a dict", failure)

    value = dict_instance.get(key)

    # the checks are ordered: only the first failing one is reported
    reason = None
    if value is None:
        reason = f"check_and_get_from_json_dict failed: {key_description} is missing"
    elif accepted_type is not None and not isinstance(value, accepted_type):
        reason = f"check_and_get_from_json_dict failed: {key_description} is not accepted type: {accepted_type}"
    elif accepted_value is not None and value not in accepted_value:
        reason = f"check_and_get_from_json_dict failed: {key_description} is not accepted value: {accepted_value}"

    if reason is not None:
        logger.error_log_with_exp(reason, failure)
    return value
|
|
41
|
+
|
|
42
|
+
def convert_to_tuple(input):
    '''
    Args:
        input: any object

    Return:
        tuple: the members of input when it is a tuple or a list, otherwise a one-element tuple
        holding input itself
    '''
    is_sequence = isinstance(input, (tuple, list))
    return tuple(input) if is_sequence else (input,)
|
|
48
|
+
|
|
49
|
+
def trim_output_compute_element_list(compute_element_list, forward_or_backward):
    '''
    Drop the outputs that are not kept for comparison.

    Args:
        compute_element_list: List[ComputeElement]
        forward_or_backward: str, Union["forward", "backward"]

    Return:
        list: the remaining elements, in their original order
    '''
    def should_keep(element):
        # trim case 1: parameter is None
        if element.get_parameter() is None:
            return False
        # trim case 2: backward output has non float parameter
        if forward_or_backward == Const.BACKWARD and element.get_dtype() not in float_dtype_str_list:
            return False
        return True

    return [element for element in compute_element_list if should_keep(element)]
|
|
63
|
+
|
|
64
|
+
class GlobalContext:
    '''
    Holder of two run-wide settings: is_constructed and dump_data_dir.
    '''

    def __init__(self):
        # defaults until init() is called
        self.is_constructed, self.dump_data_dir = True, ""

    def init(self, is_constructed, dump_data_dir):
        '''
        Args:
            is_constructed: stored as self.is_constructed
            dump_data_dir: stored as self.dump_data_dir
        '''
        self.is_constructed = is_constructed
        self.dump_data_dir = dump_data_dir

    def get_is_constructed(self):
        '''
        Return:
            the stored is_constructed value
        '''
        return self.is_constructed

    def get_dump_data_dir(self):
        '''
        Return:
            the stored dump_data_dir value
        '''
        return self.dump_data_dir


# module level instance
global_context = GlobalContext()
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from msprobe.core.data_dump.scope import ModuleRangeScope
|
|
2
|
+
from msprobe.core.common.const import Const
|
|
3
|
+
from msprobe.mindspore.common.log import logger
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CellProcessor:
    '''
    Builds begin / end hooks that give every cell call a numbered name and report it to the scope.
    '''
    # call counter per name prefix; a class attribute, hence shared by all instances
    cell_count = {}

    def __init__(self, scope):
        '''
        Args:
            scope: kept only when it is a ModuleRangeScope, otherwise self.scope is None
        '''
        self.scope = scope if isinstance(scope, ModuleRangeScope) else None

    @staticmethod
    def set_cell_count(cell_name):
        '''
        Args:
            cell_name: str

        Return:
            int: 0 on the first call for cell_name, then incremented by 1 on every further call
        '''
        counts = CellProcessor.cell_count
        if cell_name in counts:
            counts[cell_name] += 1
        else:
            counts[cell_name] = 0
        return counts[cell_name]

    def node_hook(self, name_prefix, start_or_stop, **kwargs):
        '''
        Args:
            name_prefix: str, prefix of the name given to the cell call
            start_or_stop: compared with Const.START to choose the hook
            kwargs: accepted but not used

        Return:
            begin_hook when start_or_stop equals Const.START, otherwise end_hook
        '''
        def begin_hook(cell, input):
            call_index = self.set_cell_count(name_prefix)
            full_name = name_prefix + Const.SEP + str(call_index)
            # remembered on the cell so that end_hook can close the same name
            cell.mindstudio_reserved_name = full_name
            if self.scope:
                self.scope.begin_module(full_name)

        def end_hook(cell, input, output):
            if self.scope:
                self.scope.end_module(cell.mindstudio_reserved_name)

        if Const.START == start_or_stop:
            return begin_hook
        return end_hook
|