mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +131 -237
- msprobe/__init__.py +16 -1
- msprobe/{config/config.json → config.json} +47 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +58 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +402 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +523 -283
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +86 -69
- msprobe/core/common/utils.py +371 -616
- msprobe/core/common_config.py +78 -71
- msprobe/core/compare/acc_compare.py +472 -298
- msprobe/core/compare/check.py +180 -95
- msprobe/core/compare/compare_cli.py +69 -49
- msprobe/core/compare/highlight.py +259 -222
- msprobe/core/compare/multiprocessing_compute.py +174 -149
- msprobe/core/compare/npy_compare.py +310 -295
- msprobe/core/compare/utils.py +464 -429
- msprobe/core/data_dump/data_collector.py +153 -144
- msprobe/core/data_dump/data_processor/base.py +337 -293
- msprobe/core/data_dump/data_processor/factory.py +76 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
- msprobe/core/data_dump/json_writer.py +117 -116
- msprobe/core/data_dump/scope.py +194 -178
- msprobe/core/grad_probe/constant.py +74 -70
- msprobe/core/grad_probe/grad_compare.py +170 -175
- msprobe/core/grad_probe/utils.py +77 -52
- msprobe/docs/01.installation.md +99 -0
- msprobe/docs/02.config_introduction.md +137 -0
- msprobe/docs/03.config_examples.md +237 -0
- msprobe/docs/04.acl_config_examples.md +78 -0
- msprobe/docs/05.data_dump_PyTorch.md +326 -0
- msprobe/docs/06.data_dump_MindSpore.md +285 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
- msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
- msprobe/docs/FAQ.md +189 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +2 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +58 -34
- msprobe/mindspore/common/const.py +108 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +97 -57
- msprobe/mindspore/compare/distributed_compare.py +62 -75
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +357 -117
- msprobe/mindspore/compare/ms_graph_compare.py +364 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +69 -74
- msprobe/mindspore/debugger/precision_debugger.py +150 -107
- msprobe/mindspore/dump/dump_tool_factory.py +50 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
- msprobe/mindspore/dump/jit_dump.py +96 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
- msprobe/mindspore/free_benchmark/common/config.py +27 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
- msprobe/mindspore/free_benchmark/common/utils.py +85 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
- msprobe/mindspore/grad_probe/global_context.py +100 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +297 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +23 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +30 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
- msprobe/pytorch/bench_functions/linear.py +27 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
- msprobe/pytorch/bench_functions/rms_norm.py +30 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
- msprobe/pytorch/bench_functions/swiglu.py +70 -55
- msprobe/pytorch/common/__init__.py +17 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +33 -32
- msprobe/pytorch/common/parse_json.py +54 -39
- msprobe/pytorch/common/utils.py +310 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +49 -33
- msprobe/pytorch/compare/pt_compare.py +82 -40
- msprobe/pytorch/debugger/debugger_config.py +108 -95
- msprobe/pytorch/debugger/precision_debugger.py +173 -125
- msprobe/pytorch/free_benchmark/__init__.py +23 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +65 -37
- msprobe/pytorch/free_benchmark/common/params.py +144 -129
- msprobe/pytorch/free_benchmark/common/utils.py +118 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
- msprobe/pytorch/free_benchmark/main.py +120 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
- msprobe/pytorch/function_factory.py +91 -75
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +166 -161
- msprobe/pytorch/hook_module/hook_module.py +118 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +28 -29
- msprobe/pytorch/hook_module/wrap_aten.py +111 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
- msprobe/pytorch/hook_module/wrap_functional.py +104 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
- msprobe/pytorch/hook_module/wrap_torch.py +84 -86
- msprobe/pytorch/hook_module/wrap_vf.py +60 -62
- msprobe/pytorch/module_processer.py +153 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +235 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
- msprobe/pytorch/online_dispatch/utils.py +127 -146
- msprobe/pytorch/parse.py +19 -4
- msprobe/pytorch/parse_tool/cli.py +31 -32
- msprobe/pytorch/parse_tool/lib/compare.py +259 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
- msprobe/pytorch/parse_tool/lib/utils.py +320 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +317 -187
- msprobe/pytorch/service.py +311 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -1,299 +1,473 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
if
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import multiprocessing
|
|
17
|
+
import os
|
|
18
|
+
import pandas as pd
|
|
19
|
+
from tqdm import tqdm
|
|
20
|
+
from msprobe.core.common.file_utils import load_json
|
|
21
|
+
from msprobe.core.common.const import CompareConst, Const
|
|
22
|
+
from msprobe.core.common.exceptions import FileCheckException
|
|
23
|
+
from msprobe.core.common.log import logger
|
|
24
|
+
from msprobe.core.common.utils import add_time_with_xlsx, CompareException, check_op_str_pattern_valid
|
|
25
|
+
from msprobe.core.common.file_utils import remove_path
|
|
26
|
+
from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op, check_dump_json_str, \
|
|
27
|
+
check_stack_json_str
|
|
28
|
+
from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
|
|
29
|
+
from msprobe.core.compare.utils import read_op, merge_tensor, get_un_match_accuracy, get_accuracy
|
|
30
|
+
from msprobe.core.compare.multiprocessing_compute import _handle_multi_process, ComparisonResult, _save_cmp_result
|
|
31
|
+
from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \
|
|
32
|
+
get_error_message
|
|
33
|
+
from msprobe.core.advisor.advisor import Advisor
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Comparator:
|
|
37
|
+
|
|
38
|
+
def __init__(self):
|
|
39
|
+
pass
|
|
40
|
+
|
|
41
|
+
@staticmethod
|
|
42
|
+
def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args):
|
|
43
|
+
result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0],
|
|
44
|
+
bench_ops_all.get(bench_op_name).get('struct')[0],
|
|
45
|
+
npu_ops_all.get(ms_op_name).get('struct')[1],
|
|
46
|
+
bench_ops_all.get(bench_op_name).get('struct')[1],
|
|
47
|
+
npu_ops_all.get(ms_op_name).get('struct')[2],
|
|
48
|
+
bench_ops_all.get(bench_op_name).get('struct')[2],
|
|
49
|
+
CompareConst.PASS if npu_ops_all.get(ms_op_name).get('struct')[2]
|
|
50
|
+
== bench_ops_all.get(bench_op_name).get('struct')[2]
|
|
51
|
+
else CompareConst.DIFF]
|
|
52
|
+
if args[0]:
|
|
53
|
+
result_item.extend(args[1])
|
|
54
|
+
else:
|
|
55
|
+
result_item.append(CompareConst.NONE)
|
|
56
|
+
return result_item
|
|
57
|
+
|
|
58
|
+
@staticmethod
|
|
59
|
+
def calculate_summary_data(npu_summary_data, bench_summary_data, result_item):
|
|
60
|
+
err_msg = ""
|
|
61
|
+
start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
|
|
62
|
+
warning_flag = False
|
|
63
|
+
for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
|
|
64
|
+
if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
|
|
65
|
+
diff = npu_val - bench_val
|
|
66
|
+
if bench_val != 0:
|
|
67
|
+
relative = str(abs((diff / bench_val) * 100)) + '%'
|
|
68
|
+
else:
|
|
69
|
+
relative = "N/A"
|
|
70
|
+
result_item[start_idx + i] = diff
|
|
71
|
+
result_item[start_idx + i + 4] = relative
|
|
72
|
+
magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
|
|
73
|
+
if magnitude_diff > 0.5:
|
|
74
|
+
warning_flag = True
|
|
75
|
+
else:
|
|
76
|
+
result_item[start_idx + i] = CompareConst.NONE
|
|
77
|
+
accuracy_check = CompareConst.WARNING if warning_flag else ""
|
|
78
|
+
err_msg += "Need double check api accuracy." if warning_flag else ""
|
|
79
|
+
for i in range(start_idx, len(result_item)):
|
|
80
|
+
if str(result_item[i]) in ('inf', '-inf', 'nan'):
|
|
81
|
+
result_item[i] = f'{result_item[i]}\t'
|
|
82
|
+
result_item.append(accuracy_check)
|
|
83
|
+
result_item.append(err_msg)
|
|
84
|
+
|
|
85
|
+
@classmethod
|
|
86
|
+
def make_result_table(cls, result, md5_compare, summary_compare, stack_mode):
|
|
87
|
+
if md5_compare:
|
|
88
|
+
header = CompareConst.MD5_COMPARE_RESULT_HEADER[:]
|
|
89
|
+
elif summary_compare:
|
|
90
|
+
header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
|
|
91
|
+
else:
|
|
92
|
+
header = CompareConst.COMPARE_RESULT_HEADER[:]
|
|
93
|
+
|
|
94
|
+
all_mode_bool = not (summary_compare or md5_compare)
|
|
95
|
+
if stack_mode:
|
|
96
|
+
if all_mode_bool:
|
|
97
|
+
header.append(CompareConst.STACK)
|
|
98
|
+
header.append(CompareConst.DATA_NAME)
|
|
99
|
+
else:
|
|
100
|
+
header.append(CompareConst.STACK)
|
|
101
|
+
else:
|
|
102
|
+
if all_mode_bool:
|
|
103
|
+
for row in result:
|
|
104
|
+
del row[-2]
|
|
105
|
+
header.append(CompareConst.DATA_NAME)
|
|
106
|
+
else:
|
|
107
|
+
for row in result:
|
|
108
|
+
del row[-1]
|
|
109
|
+
result_df = pd.DataFrame(result, columns=header, dtype='object')
|
|
110
|
+
return result_df
|
|
111
|
+
|
|
112
|
+
@classmethod
|
|
113
|
+
def gen_merge_list(cls, json_data, op_name, stack_json_data, summary_compare, md5_compare):
|
|
114
|
+
op_data = json_data['data'][op_name]
|
|
115
|
+
check_dump_json_str(op_data, op_name)
|
|
116
|
+
op_parsed_list = read_op(op_data, op_name)
|
|
117
|
+
|
|
118
|
+
stack_info = stack_json_data.get(op_name)
|
|
119
|
+
if stack_info is not None:
|
|
120
|
+
check_stack_json_str(stack_info, op_name)
|
|
121
|
+
op_parsed_list.append({
|
|
122
|
+
'full_op_name': op_name,
|
|
123
|
+
'full_info': stack_info
|
|
124
|
+
})
|
|
125
|
+
|
|
126
|
+
merge_list = merge_tensor(op_parsed_list, summary_compare, md5_compare)
|
|
127
|
+
return merge_list
|
|
128
|
+
|
|
129
|
+
def check_op(self, npu_dict, bench_dict, fuzzy_match):
|
|
130
|
+
a_op_name = npu_dict["op_name"]
|
|
131
|
+
b_op_name = bench_dict["op_name"]
|
|
132
|
+
graph_mode = check_graph_mode(a_op_name[0], b_op_name[0])
|
|
133
|
+
|
|
134
|
+
frame_name = getattr(self, "frame_name")
|
|
135
|
+
if frame_name == "PTComparator":
|
|
136
|
+
from msprobe.pytorch.compare.match import graph_mapping
|
|
137
|
+
if graph_mode:
|
|
138
|
+
return graph_mapping.match(a_op_name[0], b_op_name[0])
|
|
139
|
+
struct_match = check_struct_match(npu_dict, bench_dict)
|
|
140
|
+
if not fuzzy_match:
|
|
141
|
+
return a_op_name == b_op_name and struct_match
|
|
142
|
+
is_match = True
|
|
143
|
+
try:
|
|
144
|
+
is_match = fuzzy_check_op(a_op_name, b_op_name)
|
|
145
|
+
except Exception as err:
|
|
146
|
+
logger.warning("%s and %s can not fuzzy match." % (a_op_name, b_op_name))
|
|
147
|
+
is_match = False
|
|
148
|
+
return is_match and struct_match
|
|
149
|
+
|
|
150
|
+
def match_op(self, npu_queue, bench_queue, fuzzy_match):
|
|
151
|
+
for b_index, b_op in enumerate(bench_queue[0: -1]):
|
|
152
|
+
if self.check_op(npu_queue[-1], b_op, fuzzy_match):
|
|
153
|
+
return len(npu_queue) - 1, b_index
|
|
154
|
+
if self.check_op(npu_queue[-1], bench_queue[-1], fuzzy_match):
|
|
155
|
+
return len(npu_queue) - 1, len(bench_queue) - 1
|
|
156
|
+
for n_index, n_op in enumerate(npu_queue[0: -1]):
|
|
157
|
+
if self.check_op(n_op, bench_queue[-1], fuzzy_match):
|
|
158
|
+
return n_index, len(bench_queue) - 1
|
|
159
|
+
return -1, -1
|
|
160
|
+
|
|
161
|
+
def compare_process(self, file_lists, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False):
|
|
162
|
+
npu_json_path, bench_json_path, stack_json_path = file_lists
|
|
163
|
+
npu_json_data = load_json(npu_json_path)
|
|
164
|
+
bench_json_data = load_json(bench_json_path)
|
|
165
|
+
stack_json_data = load_json(stack_json_path)
|
|
166
|
+
|
|
167
|
+
if fuzzy_match:
|
|
168
|
+
logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")
|
|
169
|
+
|
|
170
|
+
npu_ops_queue = []
|
|
171
|
+
bench_ops_queue = []
|
|
172
|
+
result = []
|
|
173
|
+
|
|
174
|
+
ops_npu_iter = iter(npu_json_data['data'])
|
|
175
|
+
ops_bench_iter = iter(bench_json_data['data'])
|
|
176
|
+
read_err_npu = True
|
|
177
|
+
read_err_bench = True
|
|
178
|
+
last_npu_ops_len = 0
|
|
179
|
+
last_bench_ops_len = 0
|
|
180
|
+
|
|
181
|
+
npu_api_nums = len(npu_json_data['data'])
|
|
182
|
+
progress_bar = tqdm(total=npu_api_nums, desc="API/Module Read Progress", unit="item", ncols=100)
|
|
183
|
+
|
|
184
|
+
while True:
|
|
185
|
+
if not read_err_npu and not read_err_bench:
|
|
186
|
+
break
|
|
187
|
+
try:
|
|
188
|
+
last_npu_ops_len = len(npu_ops_queue)
|
|
189
|
+
op_name_npu = next(ops_npu_iter)
|
|
190
|
+
check_op_str_pattern_valid(op_name_npu)
|
|
191
|
+
read_err_npu = True
|
|
192
|
+
npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data,
|
|
193
|
+
summary_compare, md5_compare)
|
|
194
|
+
if npu_merge_list:
|
|
195
|
+
npu_ops_queue.append(npu_merge_list)
|
|
196
|
+
except StopIteration:
|
|
197
|
+
read_err_npu = False
|
|
198
|
+
try:
|
|
199
|
+
last_bench_ops_len = len(bench_ops_queue)
|
|
200
|
+
op_name_bench = next(ops_bench_iter)
|
|
201
|
+
check_op_str_pattern_valid(op_name_bench)
|
|
202
|
+
bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data,
|
|
203
|
+
summary_compare, md5_compare)
|
|
204
|
+
if bench_merge_list:
|
|
205
|
+
bench_ops_queue.append(bench_merge_list)
|
|
206
|
+
except StopIteration:
|
|
207
|
+
read_err_bench = False
|
|
208
|
+
|
|
209
|
+
progress_bar.update(1)
|
|
210
|
+
|
|
211
|
+
# merge all boolean expressions
|
|
212
|
+
both_empty = not npu_ops_queue and not bench_ops_queue
|
|
213
|
+
no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len)
|
|
214
|
+
if both_empty or no_change:
|
|
215
|
+
continue
|
|
216
|
+
|
|
217
|
+
# APIs in NPU and Bench models unconsistent judgment
|
|
218
|
+
if bool(npu_ops_queue) ^ bool(bench_ops_queue):
|
|
219
|
+
logger.info("Please check whether the number and calls of APIs in NPU and Bench models are consistent.")
|
|
220
|
+
break
|
|
221
|
+
|
|
222
|
+
n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue, fuzzy_match)
|
|
223
|
+
if n_match_point == -1 and b_match_point == -1:
|
|
224
|
+
continue
|
|
225
|
+
n_match_data = npu_ops_queue[n_match_point]
|
|
226
|
+
b_match_data = bench_ops_queue[b_match_point]
|
|
227
|
+
un_match_data = npu_ops_queue[0: n_match_point]
|
|
228
|
+
for npu_data in un_match_data:
|
|
229
|
+
get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
|
|
230
|
+
get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare)
|
|
231
|
+
del npu_ops_queue[0: n_match_point + 1]
|
|
232
|
+
del bench_ops_queue[0: b_match_point + 1]
|
|
233
|
+
if npu_ops_queue:
|
|
234
|
+
for npu_data in npu_ops_queue:
|
|
235
|
+
get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
|
|
236
|
+
|
|
237
|
+
result_df = self.make_result_table(result, md5_compare, summary_compare, stack_mode)
|
|
238
|
+
return result_df
|
|
239
|
+
|
|
240
|
+
def merge_data(self, json_data, stack_json_data, summary_compare, md5_compare):
    """
    Flatten every API found under json_data['data'] into one dict keyed by the
    full name of each input/kwargs/output entry.

    Args:
        json_data (dict): Dump data; its 'data' mapping is iterated for API names.
        stack_json_data: Stack data, forwarded unchanged to gen_merge_list.
        summary_compare: Forwarded unchanged to gen_merge_list.
        md5_compare: Forwarded unchanged to gen_merge_list.

    Returns:
        dict: entry name -> {'struct', 'summary', 'data_name', 'stack_info'}.
        Entries whose name is neither input, kwargs nor output are left out.
    """
    merged_ops = {}
    for api_name in json_data.get('data', {}):
        merged = self.gen_merge_list(json_data, api_name, stack_json_data, summary_compare, md5_compare)
        if not merged:
            continue

        data_names = merged.get('data_name')
        # 'input_struct' and 'output_struct' are consumed with independent cursors,
        # while 'summary' and 'data_name' are indexed by the overall position.
        next_input, next_output = 0, 0
        for position, entry_name in enumerate(merged['op_name']):
            name_parts = entry_name.split(Const.SEP)
            entry_data_name = data_names[position] if data_names else None

            if Const.INPUT in name_parts or Const.KWARGS in name_parts:
                struct = merged.get('input_struct')[next_input]
                next_input += 1
            elif Const.OUTPUT in name_parts:
                struct = merged.get('output_struct')[next_output]
                next_output += 1
            else:
                continue

            merged_ops[entry_name] = {
                'struct': struct,
                'summary': merged.get('summary')[position],
                'data_name': entry_data_name,
                'stack_info': merged.get('stack_info'),
            }
    return merged_ops
|
|
265
|
+
|
|
266
|
+
def get_accuracy(self, npu_ops_all, bench_ops_all, summary_compare, md5_compare):
    """
    Build one result row per (npu op, bench op) pair listed in self.data_mapping_dict.

    Pairs for which either side is missing from its dump are skipped with a warning.

    Args:
        npu_ops_all (dict): NPU entries keyed by op name; each value is read for
            'struct', 'summary', 'stack_info' and 'data_name'.
        bench_ops_all (dict): Bench entries keyed by op name, same layout.
        summary_compare: When truthy, rows get eight placeholder columns and are
            completed by self.calculate_summary_data.
        md5_compare: When truthy, each row is produced by self.get_result_md5_compare.

    Returns:
        list: The result rows, in the iteration order of self.data_mapping_dict.
    """
    result = []
    for ms_op_name, bench_op_name in self.data_mapping_dict.items():
        if ms_op_name not in npu_ops_all:
            logger.warning(f'Can not find npu op name : `{ms_op_name}` in npu dump json file.')
            continue
        # Fix: the bench name must be looked up in the bench dump. The original
        # tested npu_ops_all here, so the warning was skipped or misfired.
        if bench_op_name not in bench_ops_all:
            logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.')
            continue

        npu_op = npu_ops_all.get(ms_op_name)
        bench_op = bench_ops_all.get(bench_op_name)
        npu_stack_info = npu_op.get("stack_info", None)
        bench_stack_info = bench_op.get("stack_info", None)
        has_stack = npu_stack_info and bench_stack_info

        if md5_compare:
            result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all,
                                                      bench_ops_all, has_stack, npu_stack_info))
            continue

        npu_struct = npu_op.get('struct')
        bench_struct = bench_op.get('struct')
        # Placeholder cells for the metric columns: 8 in summary mode, 5 otherwise.
        placeholder_count = 8 if summary_compare else 5
        result_item = [ms_op_name, bench_op_name,
                       npu_struct[0], bench_struct[0],
                       npu_struct[1], bench_struct[1]]
        result_item.extend([" "] * placeholder_count)

        npu_summary_data = npu_op.get("summary")
        result_item.extend(npu_summary_data)
        bench_summary_data = bench_op.get("summary")
        result_item.extend(bench_summary_data)

        if summary_compare:
            self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item)
        else:
            result_item.append(CompareConst.ACCURACY_CHECK_YES)
            result_item.append("")

        if has_stack:
            result_item.extend(npu_stack_info)
        else:
            result_item.append(CompareConst.NONE)

        # md5_compare is always falsy here (handled above), so only summary mode
        # suppresses the data-name column.
        if not summary_compare:
            result_item.append(npu_op.get("data_name", None))
        result.append(result_item)
    return result
|
|
310
|
+
|
|
311
|
+
def compare_process_custom(self, file_lists, stack_mode, summary_compare=False, md5_compare=False):
    """
    Compare two dumps using the explicit name mapping in self.data_mapping_dict.

    Args:
        file_lists: Three paths, unpacked as (npu json, bench json, stack json).
        stack_mode: Forwarded to make_result_table.
        summary_compare: Forwarded to merge_data, get_accuracy and make_result_table.
        md5_compare: Forwarded to merge_data, get_accuracy and make_result_table.

    Returns:
        Whatever self.make_result_table returns for the collected rows.
    """
    npu_path, bench_path, stack_path = file_lists
    # Load in the same order as the paths are given: npu, bench, stack.
    npu_dump, bench_dump, stack_dump = (load_json(path) for path in (npu_path, bench_path, stack_path))

    # The same stack data is used when merging both the NPU and the bench dump.
    npu_entries = self.merge_data(npu_dump, stack_dump, summary_compare, md5_compare)
    bench_entries = self.merge_data(bench_dump, stack_dump, summary_compare, md5_compare)

    rows = self.get_accuracy(npu_entries, bench_entries, summary_compare, md5_compare)
    return self.make_result_table(rows, md5_compare, summary_compare, stack_mode)
|
|
323
|
+
|
|
324
|
+
def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
    """
    Load the dumped data of one NPU/bench op pair and compute its comparison metrics.

    Args:
        npu_op_name: NPU op name; also the key looked up in op_name_mapping_dict.
        bench_op_name: Bench op name used to build the bench file name.
        op_name_mapping_dict (dict): Maps npu_op_name to a sequence whose element
            at index 1 is the data name ('-1' or -1 means no real data).
        input_param (dict): Read for "npu_dump_data_dir" and "bench_dump_data_dir".

    Returns:
        list: The list returned by compare_ops_apply with the error message appended.
    """
    data_name = op_name_mapping_dict[npu_op_name][1]
    error_file = None
    relative_err = None

    # '-1' / -1 marks an entry that has no real data path.
    has_real_data = not (data_name == '-1' or data_name == -1)
    error_flag = not has_real_data

    if has_real_data:
        try:
            load = self.read_npy_data
            npu_dir_key, bench_dir_key = "npu_dump_data_dir", "bench_dump_data_dir"
            if self.frame_name == "MSComparator":
                n_value = load(input_param.get(npu_dir_key), npu_op_name + Const.NUMPY_SUFFIX)
                if self.cross_frame:
                    b_value = load(input_param.get(bench_dir_key),
                                   bench_op_name + Const.PT_SUFFIX, load_pt_file=True)
                else:
                    b_value = load(input_param.get(bench_dir_key),
                                   bench_op_name + Const.NUMPY_SUFFIX)
            else:
                n_value = load(input_param.get(npu_dir_key), npu_op_name + Const.PT_SUFFIX)
                b_value = load(input_param.get(bench_dir_key), bench_op_name + Const.PT_SUFFIX)
        except IOError as error:
            error_file = error.filename
            error_flag = True
        except FileCheckException:
            error_file = data_name
            error_flag = True

    if error_flag:
        # Missing data or a failed read: both sides fall back to the READ_NONE marker.
        n_value = b_value = CompareConst.READ_NONE

    n_value, b_value, error_flag = get_error_type(n_value, b_value, error_flag)
    if not error_flag:
        relative_err = get_relative_err(n_value, b_value)
        n_value, b_value = reshape_value(n_value, b_value)

    err_msg = get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=error_file)
    result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=relative_err)

    names_differ = npu_op_name != bench_op_name
    if names_differ and bench_op_name != CompareConst.N_A:
        err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
    result_list.append(err_msg)
    return result_list
|
|
367
|
+
|
|
368
|
+
def compare_core(self, input_parma, output_path, **kwargs):
    """
    Compare an NPU dump with a bench dump and write a highlighted xlsx report.

    Args:
        input_parma (dict): Read for "npu_json_path", "bench_json_path" and
            "stack_json_path"; also passed on to the multi-process comparison.
        output_path (str): Directory in which the report file is created.
        **kwargs: Optional settings:
            - stack_mode (default False): forwarded to the compare step.
            - auto_analyze (default True): run Advisor analysis on the result.
            - suffix (default ''): appended to the "compare_result" file name.
            - fuzzy_match (default False): forwarded to compare_process.
            - summary_compare (default False): summary comparison mode.
            - md5_compare (default False): MD5 comparison mode.

    Returns:
        None. Returns early, without writing a report, when the result table is empty.
    """
    stack_mode = kwargs.get('stack_mode', False)
    auto_analyze = kwargs.get('auto_analyze', True)
    suffix = kwargs.get('suffix', '')
    fuzzy_match = kwargs.get('fuzzy_match', False)
    summary_compare = kwargs.get('summary_compare', False)
    md5_compare = kwargs.get('md5_compare', False)

    logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
    report_name = add_time_with_xlsx("compare_result" + suffix)
    report_path = os.path.join(os.path.realpath(output_path), report_name)
    remove_path(report_path)

    json_paths = [input_parma.get(key) for key in ("npu_json_path", "bench_json_path", "stack_json_path")]
    if self.data_mapping:
        # An explicit data mapping selects the mapping-driven comparison.
        result_df = self.compare_process_custom(json_paths, stack_mode, summary_compare, md5_compare)
    else:
        result_df = self.compare_process(json_paths, stack_mode, fuzzy_match, summary_compare, md5_compare)

    if not result_df.values.tolist():
        logger.warning("Can`t match any op.")
        return

    if not (md5_compare or summary_compare):
        result_df = self._do_multi_process(input_parma, result_df)

    logger.info("Highlight suspicious API/Module start.")
    highlight_dict = {'red_rows': [], 'yellow_rows': []}
    find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare)
    highlight_rows_xlsx(result_df, highlight_dict, report_path)
    logger.info("Highlight suspicious API/Module finish.")

    if auto_analyze:
        Advisor(result_df, output_path, suffix).analysis()
|
|
425
|
+
|
|
426
|
+
def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
    """
    Run compare_by_op for every row of result_df and store the collected metrics.

    Args:
        idx: Forwarded to _save_cmp_result.
        dump_path_dict (dict): Op-name mapping handed to compare_by_op.
        result_df: Table whose column 0 holds the NPU op name and column 1 the
            bench op name of each row.
        lock: Forwarded to _save_cmp_result.
        input_param (dict): Read for "is_print_compare_log"; also passed to compare_by_op.

    Returns:
        The return value of _save_cmp_result.
    """
    cos_result = []
    max_err_result = []
    max_relative_err_result = []
    err_mess = []
    one_thousand_err_ratio_result = []
    five_thousand_err_ratio_result = []
    is_print_compare_log = input_param.get("is_print_compare_log")

    # Built from adjacent literals. The original used a backslash continuation
    # inside the string, which put the next line's indentation into the message.
    result_log_template = (
        "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, "
        "one_thousand_err_ratio {}, "
        "five_thousand_err_ratio {}"
    )

    for i in range(len(result_df)):
        npu_op_name = result_df.iloc[i, 0]
        bench_op_name = result_df.iloc[i, 1]
        if is_print_compare_log:
            logger.info("start compare: {}".format(npu_op_name))
        cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \
            self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
        if is_print_compare_log:
            logger.info(result_log_template.format(npu_op_name, cos_sim, max_abs_err, max_relative_err,
                                                   err_msg, one_thousand_err_ratio, five_thousand_err_ratio))
        cos_result.append(cos_sim)
        max_err_result.append(max_abs_err)
        max_relative_err_result.append(max_relative_err)
        err_mess.append(err_msg)
        one_thousand_err_ratio_result.append(one_thousand_err_ratio)
        five_thousand_err_ratio_result.append(five_thousand_err_ratio)

    cr = ComparisonResult(
        cos_result=cos_result,
        max_err_result=max_err_result,
        max_relative_err_result=max_relative_err_result,
        err_msgs=err_mess,
        one_thousand_err_ratio_result=one_thousand_err_ratio_result,
        five_thousand_err_ratio_result=five_thousand_err_ratio_result
    )

    return _save_cmp_result(idx, cr, result_df, lock)
|
|
464
|
+
|
|
465
|
+
def _do_multi_process(self, input_parma, result_df):
    """
    Hand result_df to _handle_multi_process with self.compare_ops as the worker.

    Args:
        input_parma: Forwarded unchanged to _handle_multi_process.
        result_df: Table to process.

    Returns:
        The table returned by _handle_multi_process.

    Raises:
        CompareException: With INVALID_DATA_ERROR, chained from any ValueError
            raised while processing.
    """
    try:
        return _handle_multi_process(self.compare_ops, input_parma, result_df,
                                     multiprocessing.Manager().RLock())
    except ValueError as err:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from err
|
|
299
473
|
|