mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +131 -237
- msprobe/__init__.py +16 -1
- msprobe/{config/config.json → config.json} +47 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +58 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +402 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +523 -283
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +86 -69
- msprobe/core/common/utils.py +371 -616
- msprobe/core/common_config.py +78 -71
- msprobe/core/compare/acc_compare.py +472 -298
- msprobe/core/compare/check.py +180 -95
- msprobe/core/compare/compare_cli.py +69 -49
- msprobe/core/compare/highlight.py +259 -222
- msprobe/core/compare/multiprocessing_compute.py +174 -149
- msprobe/core/compare/npy_compare.py +310 -295
- msprobe/core/compare/utils.py +464 -429
- msprobe/core/data_dump/data_collector.py +153 -144
- msprobe/core/data_dump/data_processor/base.py +337 -293
- msprobe/core/data_dump/data_processor/factory.py +76 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
- msprobe/core/data_dump/json_writer.py +117 -116
- msprobe/core/data_dump/scope.py +194 -178
- msprobe/core/grad_probe/constant.py +74 -70
- msprobe/core/grad_probe/grad_compare.py +170 -175
- msprobe/core/grad_probe/utils.py +77 -52
- msprobe/docs/01.installation.md +99 -0
- msprobe/docs/02.config_introduction.md +137 -0
- msprobe/docs/03.config_examples.md +237 -0
- msprobe/docs/04.acl_config_examples.md +78 -0
- msprobe/docs/05.data_dump_PyTorch.md +326 -0
- msprobe/docs/06.data_dump_MindSpore.md +285 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
- msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
- msprobe/docs/FAQ.md +189 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +2 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +58 -34
- msprobe/mindspore/common/const.py +108 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +97 -57
- msprobe/mindspore/compare/distributed_compare.py +62 -75
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +357 -117
- msprobe/mindspore/compare/ms_graph_compare.py +364 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +69 -74
- msprobe/mindspore/debugger/precision_debugger.py +150 -107
- msprobe/mindspore/dump/dump_tool_factory.py +50 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
- msprobe/mindspore/dump/jit_dump.py +96 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
- msprobe/mindspore/free_benchmark/common/config.py +27 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
- msprobe/mindspore/free_benchmark/common/utils.py +85 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
- msprobe/mindspore/grad_probe/global_context.py +100 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +297 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +23 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +30 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
- msprobe/pytorch/bench_functions/linear.py +27 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
- msprobe/pytorch/bench_functions/rms_norm.py +30 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
- msprobe/pytorch/bench_functions/swiglu.py +70 -55
- msprobe/pytorch/common/__init__.py +17 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +33 -32
- msprobe/pytorch/common/parse_json.py +54 -39
- msprobe/pytorch/common/utils.py +310 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +49 -33
- msprobe/pytorch/compare/pt_compare.py +82 -40
- msprobe/pytorch/debugger/debugger_config.py +108 -95
- msprobe/pytorch/debugger/precision_debugger.py +173 -125
- msprobe/pytorch/free_benchmark/__init__.py +23 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +65 -37
- msprobe/pytorch/free_benchmark/common/params.py +144 -129
- msprobe/pytorch/free_benchmark/common/utils.py +118 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
- msprobe/pytorch/free_benchmark/main.py +120 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
- msprobe/pytorch/function_factory.py +91 -75
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +166 -161
- msprobe/pytorch/hook_module/hook_module.py +118 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +28 -29
- msprobe/pytorch/hook_module/wrap_aten.py +111 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
- msprobe/pytorch/hook_module/wrap_functional.py +104 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
- msprobe/pytorch/hook_module/wrap_torch.py +84 -86
- msprobe/pytorch/hook_module/wrap_vf.py +60 -62
- msprobe/pytorch/module_processer.py +153 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +235 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
- msprobe/pytorch/online_dispatch/utils.py +127 -146
- msprobe/pytorch/parse.py +19 -4
- msprobe/pytorch/parse_tool/cli.py +31 -32
- msprobe/pytorch/parse_tool/lib/compare.py +259 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
- msprobe/pytorch/parse_tool/lib/utils.py +320 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +317 -187
- msprobe/pytorch/service.py +311 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -1,295 +1,310 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
if
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
if
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
b_value = b_value
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
def apply(self, n_value, b_value, error_flag, relative_err=None):
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
if
|
|
177
|
-
return CompareConst.
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
n_value
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
if
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
if
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import abc
|
|
17
|
+
import numpy as np
|
|
18
|
+
from msprobe.core.common.utils import format_value
|
|
19
|
+
from msprobe.core.common.const import Const, CompareConst
|
|
20
|
+
from msprobe.core.common.log import logger
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def handle_inf_nan(n_value, b_value):
    """Reconcile inf/nan entries between the NPU and bench arrays.

    - Neither array holds inf/nan: both are returned untouched.
    - Both arrays hold inf at exactly the same positions AND nan at exactly the same
      positions: those entries are overwritten with 0 *in place* and the same array
      objects are returned.
    - Otherwise the layouts disagree and ``(CompareConst.NAN, CompareConst.NAN)`` is returned.
    """
    npu_inf, bench_inf = np.isinf(n_value), np.isinf(b_value)
    npu_nan, bench_nan = np.isnan(n_value), np.isnan(b_value)

    npu_special = npu_inf | npu_nan
    bench_special = bench_inf | bench_nan
    if not (np.any(npu_special) or np.any(bench_special)):
        return n_value, b_value

    layouts_match = np.array_equal(npu_inf, bench_inf) and np.array_equal(npu_nan, bench_nan)
    if not layouts_match:
        return CompareConst.NAN, CompareConst.NAN

    # Same positions on both sides: zero them so later arithmetic only sees finite values.
    n_value[npu_special] = 0
    b_value[bench_special] = 0
    return n_value, b_value
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_error_type(n_value, b_value, error_flag):
    """Classify an NPU/bench data pair before it is compared.

    Args:
        n_value: NPU-side ndarray (ignored when ``error_flag`` is already set).
        b_value: bench-side ndarray.
        error_flag: True when the caller has already flagged a problem with this pair;
            it is reported as ``CompareConst.READ_NONE``.

    Returns:
        tuple: ``(n_value, b_value, error_flag)``. On failure both values are replaced by
        the same ``CompareConst`` marker (``READ_NONE``, ``NONE``, ``SHAPE_UNMATCH`` or
        ``NAN``) and the flag is True. 0-d data is returned unchanged with the flag False.
        Otherwise the arrays pass through ``handle_inf_nan``, which zeroes matching inf/nan
        entries in place, and are returned with the flag False.
    """
    if error_flag:
        return CompareConst.READ_NONE, CompareConst.READ_NONE, True
    if n_value.size == 0:  # nothing was read for the NPU side
        return CompareConst.NONE, CompareConst.NONE, True
    if n_value.shape != b_value.shape:  # NPU and bench layouts differ
        return CompareConst.SHAPE_UNMATCH, CompareConst.SHAPE_UNMATCH, True
    if not n_value.shape:  # 0-d (scalar) data skips the inf/nan reconciliation
        return n_value, b_value, False

    n_value, b_value = handle_inf_nan(n_value, b_value)
    # handle_inf_nan returns either the arrays or the CompareConst.NAN marker string.
    # Detect the marker by type instead of `is`-identity with a string constant, which
    # only worked because the very same object happened to be returned.
    if isinstance(n_value, str) or isinstance(b_value, str):
        return CompareConst.NAN, CompareConst.NAN, True
    return n_value, b_value, False
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def reshape_value(n_value, b_value):
    """Prepare an NPU/bench pair for the numeric metrics.

    N-d arrays are flattened to 1-D and cast to float. 0-d (scalar) arrays keep their
    shape; they are cast to float only when the NPU value is boolean, and are otherwise
    returned as-is.
    """
    if n_value.shape:
        flat_npu = n_value.reshape(-1).astype(float)
        flat_bench = b_value.reshape(-1).astype(float)
        return flat_npu, flat_bench

    if n_value.dtype != bool:
        return n_value, b_value
    return n_value.astype(float), b_value.astype(float)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None):
    """Describe why an NPU/bench pair cannot be (fully) compared.

    Args:
        n_value: NPU-side ndarray, or a ``CompareConst`` marker string when ``error_flag`` is set.
        b_value: bench-side ndarray.
        npu_op_name: operator name, used only in the dtype-mismatch warning.
        error_flag: True when ``n_value`` carries a failure marker.
        error_file: optional dump file path reported when the data could not be read.

    Returns:
        str: a human-readable message, or ``""`` when there is nothing to report.
    """
    if not error_flag:
        if not n_value.shape:
            return "This is type of scalar data, can not compare."
        if n_value.dtype == b_value.dtype:
            return ""
        logger.warning("Dtype of NPU and bench Tensor do not match: {}".format(npu_op_name))
        return "Dtype of NPU and bench Tensor do not match."

    if n_value == CompareConst.READ_NONE:
        return "Dump file: {} not found.".format(error_file) if error_file else CompareConst.NO_BENCH
    if n_value == CompareConst.NONE:
        return "This is empty data, can not compare."
    if n_value == CompareConst.SHAPE_UNMATCH:
        return "Shape of NPU and bench Tensor do not match. Skipped."
    if n_value == CompareConst.NAN:
        return "The position of inf or nan in NPU and bench Tensor do not match."
    return ""
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def npy_data_check(n_value, b_value):
    """Validate an NPU/bench ndarray pair before numeric comparison.

    The checks run in stages, and a stage only runs when every earlier stage passed:
      1. both inputs must be ``np.ndarray``;
      2. neither may be empty;
      3. neither may be 0-d, and shapes and dtypes must match (all three are reported together);
      4. inf/nan positions must agree, checked via ``handle_inf_nan``. That helper also
         zeroes matching inf/nan entries of the input arrays in place.

    Returns:
        tuple: ``(error_flag, error_message)``. ``error_message`` concatenates one
        newline-terminated line per problem, and ``error_flag`` is True iff it is non-empty.
    """
    problems = []
    if not (isinstance(n_value, np.ndarray) and isinstance(b_value, np.ndarray)):
        problems.append("Dump file is not ndarray.\n")
    elif n_value.size == 0 or b_value.size == 0:
        problems.append("This is empty data, can not compare.\n")
    else:
        if not n_value.shape or not b_value.shape:
            problems.append("This is type of scalar data, can not compare.\n")
        if n_value.shape != b_value.shape:
            problems.append("Shape of NPU and bench Tensor do not match.\n")
        if n_value.dtype != b_value.dtype:
            problems.append("Dtype of NPU and bench Tensor do not match. Skipped.\n")
        if not problems:
            # handle_inf_nan yields the marker string 'Nan' instead of arrays when inf/nan
            # positions disagree, so a non-ndarray result signals an unresolvable mismatch.
            npu_checked, bench_checked = handle_inf_nan(n_value, b_value)
            if not (isinstance(npu_checked, np.ndarray) and isinstance(bench_checked, np.ndarray)):
                problems.append("The position of inf or nan in NPU and bench Tensor do not match.\n")

    error_message = "".join(problems)
    return bool(error_message), error_message
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def statistics_data_check(result_dict):
    """Validate the metadata of an NPU/bench pair stored in ``result_dict``.

    The function reports the following problems:
      - either side's name is missing (``None``);
      - either shape is empty or missing (scalar data);
      - both shapes are present but differ;
      - the dtypes differ.

    Returns:
        tuple: ``(error_flag, error_message)``. ``error_message`` concatenates one
        newline-terminated line per problem, and ``error_flag`` is True iff it is non-empty.
    """
    npu_name = result_dict.get(CompareConst.NPU_NAME)
    bench_name = result_dict.get(CompareConst.BENCH_NAME)
    npu_shape = result_dict.get(CompareConst.NPU_SHAPE)
    bench_shape = result_dict.get(CompareConst.BENCH_SHAPE)

    problems = []
    if npu_name is None or bench_name is None:
        problems.append("Dump file not found.\n")

    if not (npu_shape and bench_shape):
        problems.append("This is type of scalar data, can not compare.\n")
    elif npu_shape != bench_shape:
        problems.append("Tensor shapes do not match.\n")

    if result_dict.get(CompareConst.NPU_DTYPE) != result_dict.get(CompareConst.BENCH_DTYPE):
        problems.append("Dtype of NPU and bench Tensor do not match. Skipped.\n")

    error_message = "".join(problems)
    return bool(error_message), error_message
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class TensorComparisonBasic(abc.ABC):
    """Abstract interface shared by every NPU-vs-bench tensor metric.

    Subclasses implement ``apply`` and return a ``(result, message)`` pair. The message is
    an empty string when there is nothing to report.
    """

    @abc.abstractmethod
    def apply(self, n_value, b_value, error_flag, relative_err=None):
        """Compute the metric for one NPU/bench pair.

        Args:
            n_value: NPU-side data, or a ``CompareConst`` marker string when ``error_flag`` is set.
            b_value: bench-side data.
            error_flag: True when the pair was flagged as not comparable upstream.
            relative_err: optional precomputed element-wise relative error.
        """
        raise NotImplementedError()
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class GetCosineSimilarity(TensorComparisonBasic):
    """Cosine similarity of the NPU and bench data: ``n . b / (|n| * |b|)``."""

    @staticmethod
    def correct_data(result):
        """Round a result above ``CompareConst.COSINE_THRESHOLD`` to 6 decimal places.

        ``CompareConst.NAN`` and results at or below the threshold are returned unchanged.
        """
        if result != CompareConst.NAN and float(result) > CompareConst.COSINE_THRESHOLD:
            return round(float(result), 6)
        return result

    def apply(self, n_value, b_value, error_flag, relative_err=None):
        """Return ``(cosine, message)`` for one NPU/bench pair.

        Flagged pairs map to a fixed placeholder. 0-d and single-element data are reported as
        unsupported. If both norms are within ``Const.FLOAT_EPSILON`` the pair counts as
        identical (1.0). If only one norm is, or the result is NaN, ``CompareConst.NAN`` is
        returned with an explanation. Results above 0.99999 are reported as exactly 1.0.
        """
        if error_flag:
            if n_value == CompareConst.READ_NONE:
                return CompareConst.NONE, ''
            if n_value == CompareConst.NONE:
                return CompareConst.UNSUPPORTED, ''
            if n_value == CompareConst.SHAPE_UNMATCH:
                return CompareConst.SHAPE_UNMATCH, ''
            if n_value == CompareConst.NAN:
                return "N/A", ''

        if not n_value.shape:
            return CompareConst.UNSUPPORTED, ''

        with np.errstate(divide='ignore', invalid='ignore'):
            if len(n_value) == 1:
                return CompareConst.UNSUPPORTED, "This tensor is scalar."
            dot_product = n_value.dot(b_value)
            npu_norm = np.linalg.norm(n_value)
            bench_norm = np.linalg.norm(b_value)

            npu_all_zero = npu_norm <= Const.FLOAT_EPSILON
            bench_all_zero = bench_norm <= Const.FLOAT_EPSILON
            if npu_all_zero and bench_all_zero:
                return 1.0, ''
            if npu_all_zero:
                return CompareConst.NAN, 'Cannot compare by Cosine Similarity, All the data is Zero in npu dump data.'
            if bench_all_zero:
                return CompareConst.NAN, 'Cannot compare by Cosine Similarity, All the data is Zero in Bench dump data.'

            cosine = dot_product / (npu_norm * bench_norm)
            if np.isnan(cosine):
                return CompareConst.NAN, 'Cannot compare by Cosine Similarity, the dump data has NaN.'
            cosine = self.correct_data(format_value(cosine))
            if float(cosine) > 0.99999:
                return 1.0, ''
            return cosine, ''
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class GetMaxAbsErr(TensorComparisonBasic):
    """Maximum absolute element-wise difference ``max(|n - b|)``."""

    def apply(self, n_value, b_value, error_flag, relative_err=None):
        """Return ``(max_abs_error, "")`` for one NPU/bench pair.

        Flagged pairs map their marker to a fixed placeholder. ``relative_err`` is accepted
        for interface compatibility and is not used.
        """
        if error_flag:
            marker_outcomes = (
                (CompareConst.READ_NONE, CompareConst.NONE),
                (CompareConst.NONE, 0),
                (CompareConst.SHAPE_UNMATCH, CompareConst.SHAPE_UNMATCH),
                (CompareConst.NAN, "N/A"),
            )
            for marker, outcome in marker_outcomes:
                if n_value == marker:
                    return outcome, ""

        abs_diff = np.abs(n_value - b_value)
        return format_value(np.max(abs_diff)), ""
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def get_relative_err(n_value, b_value):
    """Return the element-wise absolute relative error ``|(n - b) / b|``.

    Bench data whose dtype is not in ``CompareConst.FLOAT_TYPE`` is cast (together with the
    NPU data) to float first. Where the bench value is exactly 0, the machine epsilon of the
    bench dtype is added to both the NPU and the bench entry so the division is defined.

    The arithmetic runs on copies and the inputs are left untouched. This matters because the
    same arrays are reused for the other comparison metrics; previously the eps shift leaked
    back into the caller's data.

    Args:
        n_value: NPU-side ndarray.
        b_value: bench-side ndarray with the same shape as ``n_value``.

    Returns:
        np.ndarray: element-wise ``abs((n - b) / b)``.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        if b_value.dtype not in CompareConst.FLOAT_TYPE:
            n_value, b_value = n_value.astype(float), b_value.astype(float)

        bench = b_value.copy()
        npu = n_value.copy()
        if not np.issubdtype(npu.dtype, np.inexact):
            # Integer/bool NPU data next to float bench data: promote it, otherwise the
            # in-place addition of a float eps below raises a numpy casting error.
            npu = npu.astype(np.result_type(npu.dtype, bench.dtype))

        zero_mask = (bench == 0)
        eps = np.finfo(bench.dtype).eps
        bench[zero_mask] += eps
        npu[zero_mask] += eps
        relative_err = np.divide(npu - bench, bench)
    return np.abs(relative_err)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
class GetMaxRelativeErr(TensorComparisonBasic):
    """Maximum element-wise relative error between the NPU and bench data."""

    def apply(self, n_value, b_value, error_flag, relative_err=None):
        """Return ``(max_relative_error, message)`` for one NPU/bench pair.

        The precomputed ``relative_err`` is used when given; otherwise it is derived with
        ``get_relative_err``. A NaN maximum yields ``CompareConst.NAN`` with an explanation.
        """
        if error_flag:
            if n_value == CompareConst.READ_NONE:
                return CompareConst.NONE, ''
            if n_value == CompareConst.NONE:
                return 0, ''
            if n_value == CompareConst.SHAPE_UNMATCH:
                return CompareConst.SHAPE_UNMATCH, ''
            if n_value == CompareConst.NAN:
                return "N/A", ''

        errors = get_relative_err(n_value, b_value) if relative_err is None else relative_err
        worst = np.max(np.abs(errors))
        if not np.isnan(worst):
            return format_value(worst), ''
        return CompareConst.NAN, 'Cannot compare by MaxRelativeError, the data contains nan in dump data.'
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class GetThousandErrRatio(TensorComparisonBasic):
    """Ratio of elements whose relative error is below one-thousandth.

    The threshold is ``CompareConst.THOUSAND_RATIO_THRESHOLD``.
    """

    def apply(self, n_value, b_value, error_flag, relative_err=None):
        """Return ``(ratio, "")`` for one NPU/bench pair.

        0-d data and an empty relative-error array both yield ``CompareConst.NAN``.
        """
        if error_flag:
            marker_outcomes = (
                (CompareConst.READ_NONE, CompareConst.NONE),
                (CompareConst.NONE, 0),
                (CompareConst.SHAPE_UNMATCH, CompareConst.SHAPE_UNMATCH),
                (CompareConst.NAN, "N/A"),
            )
            for marker, outcome in marker_outcomes:
                if n_value == marker:
                    return outcome, ""

        if not n_value.shape:
            return CompareConst.NAN, ""
        errors = relative_err if relative_err is not None else get_relative_err(n_value, b_value)
        total = np.size(errors)
        if not total:
            return CompareConst.NAN, ""
        within = np.sum(errors < CompareConst.THOUSAND_RATIO_THRESHOLD)
        return format_value(within / total), ""
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
class GetFiveThousandErrRatio(TensorComparisonBasic):
    """Ratio of elements whose relative error is below five-thousandths.

    The threshold is ``CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD``.
    """

    def apply(self, n_value, b_value, error_flag, relative_err=None):
        """Return ``(ratio, "")`` for one NPU/bench pair.

        0-d data and an empty relative-error array both yield ``CompareConst.NAN``.
        """
        if error_flag:
            if n_value == CompareConst.READ_NONE:
                return CompareConst.NONE, ""
            if n_value == CompareConst.NONE:
                return 0, ""
            if n_value == CompareConst.SHAPE_UNMATCH:
                return CompareConst.SHAPE_UNMATCH, ""
            if n_value == CompareConst.NAN:
                return "N/A", ""

        if not n_value.shape:
            return CompareConst.NAN, ""
        if relative_err is None:
            relative_err = get_relative_err(n_value, b_value)
        element_count = np.size(relative_err)
        if not element_count:
            return CompareConst.NAN, ""
        below_threshold = np.sum(relative_err < CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD)
        ratio = below_threshold / element_count
        return format_value(ratio), ""
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class CompareOps:
    """Registry of the tensor metrics computed for every compared NPU/bench pair.

    The insertion order of ``compare_ops`` is the order in which ``compare_ops_apply``
    runs the metrics and reports their values.
    """

    compare_ops = dict(
        cosine_similarity=GetCosineSimilarity(),
        max_abs_error=GetMaxAbsErr(),
        max_relative_error=GetMaxRelativeErr(),
        one_thousand_err_ratio=GetThousandErrRatio(),
        five_thousand_err_ratio=GetFiveThousandErrRatio(),
    )
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=None):
    """Run every metric registered in ``CompareOps.compare_ops`` on one NPU/bench pair.

    Args:
        n_value: NPU-side data, or a failure marker when ``error_flag`` is set.
        b_value: bench-side data.
        error_flag: forwarded unchanged to every metric.
        err_msg: message accumulated so far; each metric's message is appended to it.
        relative_err: optional precomputed relative error, forwarded to every metric.

    Returns:
        tuple: ``(result_list, err_msg)``. ``result_list`` holds the metric values in
        registry order, and ``err_msg`` is the extended message.
    """
    outcomes = [
        metric.apply(n_value, b_value, error_flag, relative_err=relative_err)
        for metric in CompareOps.compare_ops.values()
    ]
    result_list = [value for value, _ in outcomes]
    extra_messages = "".join(message for _, message in outcomes)
    return result_list, err_msg + extra_messages
|