mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/精度工具数据dump标准性能基线报告.md +0 -0
msprobe/core/compare/multiprocessing_compute.py (new file, matching the +149 -0 entry above). The single hunk `@@ -0,0 +1,149 @@` adds the whole module:

```python
import multiprocessing
from dataclasses import dataclass
from functools import partial
import numpy as np
import pandas as pd
from msprobe.core.common.log import logger
from msprobe.core.common.utils import CompareException
from msprobe.core.common.const import CompareConst


def _handle_multi_process(func, input_parma, result_df, lock):
    process_num = int((multiprocessing.cpu_count() + 1) / 2)
    op_name_mapping_dict = read_dump_data(result_df)

    df_chunk_size = len(result_df) // process_num
    if df_chunk_size > 0:
        df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
    else:
        df_chunks = [result_df]

    results = []
    pool = multiprocessing.Pool(process_num)

    def err_call(args):
        logger.error('multiprocess compare failed! Reason: {}'.format(args))
        try:
            pool.terminate()
        except OSError as e:
            logger.error("pool terminate failed")

    for process_idx, df_chunk in enumerate(df_chunks):
        idx = df_chunk_size * process_idx
        result = pool.apply_async(func,
                                  args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma),
                                  error_callback=err_call)
        results.append(result)
    final_results = [r.get() for r in results]
    pool.close()
    pool.join()
    return pd.concat(final_results, ignore_index=True)


def _ms_graph_handle_multi_process(func, result_df, mode):
    process_num = int((multiprocessing.cpu_count() + 1) // 2)
    df_chunk_size = len(result_df) // process_num
    if df_chunk_size > 0:
        df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
    else:
        df_chunks = [result_df]

    results = []
    pool = multiprocessing.Pool(process_num)

    def err_call(args):
        logger.error('multiprocess compare failed! Reason: {}'.format(args))
        try:
            pool.terminate()
        except OSError as e:
            logger.error("pool terminate failed")

    for df_chunk in df_chunks:
        result = pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call)
        results.append(result)
    final_results = [r.get() for r in results]
    pool.close()
    pool.join()
    return pd.concat(final_results, ignore_index=True)


def read_dump_data(result_df):
    try:
        npu_dump_name_list = result_df.iloc[0:, 0].tolist()
        npu_dump_tensor_list = result_df.iloc[0:, -1].tolist()
        op_name_mapping_dict = {}
        for index, _ in enumerate(npu_dump_name_list):
            npu_dump_name = npu_dump_name_list[index]
            npu_dump_tensor = npu_dump_tensor_list[index]
            op_name_mapping_dict[npu_dump_name] = [npu_dump_tensor, npu_dump_tensor]
        return op_name_mapping_dict
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
    except IndexError as e:
        logger.error('result dataframe elements can not be access.')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e


@dataclass
class ComparisonResult:
    cos_result: list
    max_err_result: list
    max_relative_err_result: list
    err_msgs: list
    one_thousand_err_ratio_result: list
    five_thousand_err_ratio_result: list


def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
    """
    Save comparison results into the result DataFrame with thread safety.
    Args:
        offset: offset for index
        result: data struct of ComparisonResult
        result_df: result of DataFrame
        lock: thread lock

    Returns:
        comparison results in DataFrame
    """

    lock.acquire()
    try:
        for i, _ in enumerate(result.cos_result):
            process_index = i + offset
            result_df.loc[process_index, CompareConst.COSINE] = result.cos_result[i]
            result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
            result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
            result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
            result_df.loc[process_index, CompareConst.ACCURACY] = check_accuracy(result.cos_result[i], result.max_err_result[i])
            result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result.one_thousand_err_ratio_result[i]
            result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result.five_thousand_err_ratio_result[i]
        return result_df
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
    except IndexError as e:
        logger.error('result dataframe elements can not be access.')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
    finally:
        lock.release()


def check_accuracy(cos, max_abs_err):
    if cos == CompareConst.SHAPE_UNMATCH:
        return CompareConst.ACCURACY_CHECK_UNMATCH
    if cos == CompareConst.NONE or max_abs_err == CompareConst.NONE:
        return CompareConst.NONE
    if cos == "N/A" or max_abs_err == "N/A":
        return CompareConst.ACCURACY_CHECK_NO
    try:
        cos, max_abs_err = float(cos), float(max_abs_err)
    except ValueError:
        logger.warning("Cosine or MaxAbsErr can not get float value.")
        return CompareConst.NONE
    if cos < CompareConst.COS_THRESHOLD and max_abs_err > CompareConst.MAX_ABS_ERR_THRESHOLD:
        return CompareConst.ACCURACY_CHECK_NO
    if cos < CompareConst.COS_MAX_THRESHOLD or max_abs_err > CompareConst.MAX_ABS_ERR_MAX_THRESHOLD:
        return CompareConst.ACCURACY_CHECK_NO
    return CompareConst.ACCURACY_CHECK_YES
```
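`_handle_multi_process` splits the result DataFrame into per-process chunks, runs the supplied compare function in a `multiprocessing.Pool`, and concatenates whatever each worker returns. A minimal usage sketch, assuming the module path follows the file list above; the `compare_chunk` worker, the two column names, and the `input_parma` keys are illustrative placeholders, not part of the package:

```python
import multiprocessing

import pandas as pd

from msprobe.core.compare.multiprocessing_compute import _handle_multi_process


def compare_chunk(start_index, op_name_mapping_dict, df_chunk, lock, input_parma):
    # A real worker would load the dumped tensors named in op_name_mapping_dict,
    # compute the metrics, and fill them into df_chunk; here it is returned unchanged.
    return df_chunk


if __name__ == "__main__":
    # Toy result table: the first column holds the NPU op name and the last column
    # the dump data name, which is what read_dump_data() extracts.
    result_df = pd.DataFrame({
        "NPU Name": ["Functional.relu.0.forward.input.0"],
        "Data Name": ["Functional.relu.0.forward.input.0.npy"],
    })
    lock = multiprocessing.Manager().RLock()
    merged = _handle_multi_process(compare_chunk, {"npu_json_path": "dump.json"}, result_df, lock)
    print(merged)
```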
msprobe/{pytorch → core}/compare/npy_compare.py (moved and updated, matching the +55 -4 entry above). Three hunks; three of the removed lines are truncated in the source view:

```diff
@@ -2,10 +2,10 @@ import abc
 import numpy as np
 from msprobe.core.common.utils import format_value
 from msprobe.core.common.const import Const, CompareConst
-from msprobe.
+from msprobe.core.common.log import logger
 
 
-def
+def handle_inf_nan(n_value, b_value):
     """Handle inf and nan data."""
     n_inf = np.isinf(n_value)
     b_inf = np.isinf(b_value)
@@ -54,7 +54,7 @@ def reshape_value(n_value, b_value):
     return n_value, b_value
 
 
-def get_error_message(n_value, b_value, op_name, error_flag, error_file=None):
+def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None):
     """Get the error message for abnormal cases."""
     if error_flag:
         if n_value == CompareConst.READ_NONE:
@@ -71,11 +71,62 @@ def get_error_message(n_value, b_value, op_name, error_flag, error_file=None):
     if not n_value.shape:
         return "This is type of scalar data, can not compare."
     if n_value.dtype != b_value.dtype:
-        logger.warning("Dtype of NPU and bench Tensor do not match: {}".format(
+        logger.warning("Dtype of NPU and bench Tensor do not match: {}".format(npu_op_name))
         return "Dtype of NPU and bench Tensor do not match."
     return ""
 
 
+def npy_data_check(n_value, b_value):
+    error_message = ""
+    if n_value is None or b_value is None:
+        error_message += "Dump file not found.\n"
+    if n_value == "" or b_value == "":
+        error_message += "Dump file not found.\n"
+
+    # Check whether n_value and b_value are empty
+    if not error_message and (n_value.size == 0 or b_value.size == 0):
+        error_message += "This is empty data, can not compare.\n"
+
+    if not error_message:
+        if not n_value.shape or not b_value.shape:
+            error_message += "This is type of scalar data, can not compare.\n"
+        if n_value.shape != b_value.shape:
+            error_message += "Shape of NPU and bench Tensor do not match.\n"
+        if n_value.dtype != b_value.dtype:
+            error_message += "Dtype of NPU and bench Tensor do not match. Skipped.\n"
+
+    if not error_message:
+        n_value, b_value = handle_inf_nan(n_value, b_value)  # check whether there is nan/inf data
+        if CompareConst.NAN in (n_value, b_value):
+            error_message += "The position of inf or nan in NPU and bench Tensor do not match.\n"
+    if error_message == "":
+        error_flag = False
+    else:
+        error_flag = True
+    return error_flag, error_message
+
+
+def statistics_data_check(result_dict):
+    error_message = ""
+
+    if result_dict.get(CompareConst.NPU_NAME) is None or result_dict.get(CompareConst.BENCH_NAME) is None:
+        error_message += "Dump file not found.\n"
+
+    if not result_dict.get(CompareConst.NPU_SHAPE) or not result_dict.get(CompareConst.BENCH_SHAPE):
+        error_message += "This is type of scalar data, can not compare.\n"
+    elif result_dict.get(CompareConst.NPU_SHAPE) != result_dict.get(CompareConst.BENCH_SHAPE):
+        error_message += "Tensor shapes do not match.\n"
+
+    if result_dict.get(CompareConst.NPU_DTYPE) != result_dict.get(CompareConst.BENCH_DTYPE):
+        error_message += "Dtype of NPU and bench Tensor do not match. Skipped.\n"
+
+    if error_message == "":
+        error_flag = False
+    else:
+        error_flag = True
+    return error_flag, error_message
+
+
 class TensorComparisonBasic(abc.ABC):
     """Comparison template for npy data from NPU and bench."""
     @abc.abstractmethod
```
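The added `npy_data_check` and `statistics_data_check` helpers both follow a flag-plus-message pattern: findings are appended to a single string and a boolean flag is derived at the end. A self-contained sketch of that pattern (the `precheck` helper is hypothetical, not library code):

```python
import numpy as np


def precheck(n_value, b_value):
    # Accumulate human-readable findings, then derive one error flag at the end,
    # mirroring the structure of npy_data_check / statistics_data_check above.
    error_message = ""
    if n_value is None or b_value is None:
        error_message += "Dump file not found.\n"
    if not error_message and (n_value.size == 0 or b_value.size == 0):
        error_message += "This is empty data, can not compare.\n"
    if not error_message:
        if n_value.shape != b_value.shape:
            error_message += "Shape of NPU and bench Tensor do not match.\n"
        if n_value.dtype != b_value.dtype:
            error_message += "Dtype of NPU and bench Tensor do not match. Skipped.\n"
    return bool(error_message), error_message


error_flag, message = precheck(np.zeros((2, 3), dtype=np.float32),
                               np.zeros((2, 4), dtype=np.float16))
print(error_flag)   # True
print(message)      # shape mismatch and dtype mismatch findings
```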
msprobe/core/compare/utils.py (new file, matching the +429 -0 entry above). The single hunk `@@ -0,0 +1,429 @@` adds the shared comparison helpers:

```python
import os
import re
import numpy as np
from msprobe.core.common.const import Const, CompareConst
from msprobe.core.common.utils import CompareException, check_file_or_directory_path, check_regex_prefix_format_valid, logger


def extract_json(dirname, stack_json=False):
    json_path = ''
    for fname in os.listdir(dirname):
        if fname == "construct.json":
            continue
        full_path = os.path.join(dirname, fname)
        if full_path.endswith('.json'):
            json_path = full_path
            if not stack_json and 'stack' not in json_path:
                break
            if stack_json and 'stack' in json_path:
                break

    # Provide robustness on invalid directory inputs
    if not json_path:
        logger.error(f'No file is found in dump dir {dirname}. ')
        raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
    return json_path


def check_and_return_dir_contents(dump_dir, prefix):
    """
    check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
    pattern: ^{prefix}(?:0|[0-9][1-9]*)?$

    Args:
        dump_dir (str): dump dir
        prefix (str): prefix for the patterns, prefix should be less than 20 characters and alphanumeric/-/_ only

    Returns:
        content [list]: dir contents
    Raises:
        CompareException: invalid path
        ValueError: prefix not match the patterns

    """
    check_regex_prefix_format_valid(prefix)
    check_file_or_directory_path(dump_dir, True)
    contents = os.listdir(dump_dir)
    pattern = re.compile(rf'^{prefix}(?:0|[0-9][1-9]*)?$')
    for name in contents:
        if not pattern.match(name):
            logger.error(
                f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump "
                f"output. Please check and delete irrelevant files in {dump_dir} and try again."
            )
            raise CompareException(CompareException.INVALID_PATH_ERROR)
    return contents


def rename_api(npu_name, process):
    npu_split = npu_name.split(process)
    torch_func_index, in_out = npu_split[0], npu_split[1]
    torch_func_split = torch_func_index.rsplit(Const.SEP, 2)
    torch_func = str(torch_func_split[0]) + str(in_out)
    return torch_func


def read_op(op_data, op_name):
    op_parsed_list = Const.DEFAULT_LIST
    if Const.FORWARD in op_name:
        if Const.INPUT_ARGS in op_data:
            input_item = op_data[Const.INPUT_ARGS]
            input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
            op_parsed_list = input_parsed_list.copy()
            input_parsed_list.clear()
        if Const.INPUT_KWARGS in op_data:
            kwargs_item = op_data[Const.INPUT_KWARGS]
            if isinstance(kwargs_item, dict) and "type" in kwargs_item or isinstance(kwargs_item, list):
                kwarg_parsed_list = op_item_parse(kwargs_item, op_name + '.input', None)
                op_parsed_list += kwarg_parsed_list
                kwarg_parsed_list.clear()
            elif kwargs_item:
                for kwarg in kwargs_item:
                    kwarg_parsed_list = op_item_parse(kwargs_item[kwarg], op_name + '.input.' + kwarg, None)
                    op_parsed_list += kwarg_parsed_list
                    kwarg_parsed_list.clear()
        if Const.OUTPUT in op_data:
            output_item = op_data[Const.OUTPUT]
            output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
            op_parsed_list += output_parsed_list
            output_parsed_list.clear()
    if Const.BACKWARD in op_name:
        if Const.INPUT in op_data:
            input_item = op_data[Const.INPUT]
            input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
            op_parsed_list = input_parsed_list.copy()
            input_parsed_list.clear()
        if Const.OUTPUT in op_data:
            output_item = op_data[Const.OUTPUT]
            output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
            op_parsed_list += output_parsed_list
            output_parsed_list.clear()
    return op_parsed_list


def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
    if item_list is None:
        item_list = []
    if item is None or (isinstance(item, dict) and not item):
        if not top_bool:
            tmp = {'full_op_name': op_name + '.' + str(index), 'Max': None, 'Min': None, 'Mean': None, 'Norm': None,
                   'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'}
        else:
            tmp = {'full_op_name': op_name + '.0', 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None,
                   'shape': None, 'md5': None, 'data_name': '-1'}
        item_list.append(tmp)
        return item_list
    if index is None:
        if isinstance(item, dict):
            full_op_name = op_name + '.0'
        else:
            full_op_name = op_name
    else:
        full_op_name = op_name + Const.SEP + str(index)
    if isinstance(item, dict):
        if 'type' not in item:
            for kwarg in item:
                kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None)
                item_list += kwarg_parsed_list
                kwarg_parsed_list.clear()
        elif 'dtype' in item:
            parsed_item = item
            parsed_item['full_op_name'] = full_op_name
            item_list.append(parsed_item)
        elif 'type' in item:
            parsed_item = {}
            if item['type'] == 'torch.Size':
                parsed_item['full_op_name'] = full_op_name
                parsed_item['dtype'] = 'torch.Size'
                parsed_item['shape'] = str(item['value'])
                parsed_item['md5'] = None
                parsed_item['Max'] = None
                parsed_item['Min'] = None
                parsed_item['Mean'] = None
                parsed_item['Norm'] = None
                parsed_item['data_name'] = '-1'
                item_list.append(parsed_item)
            elif item['type'] == 'slice':
                parsed_item['full_op_name'] = full_op_name
                parsed_item['dtype'] = 'slice'
                parsed_item['shape'] = str(np.shape(np.array(item['value'])))
                parsed_item['md5'] = None
                parsed_item['Max'] = None
                parsed_item['Min'] = None
                parsed_item['Mean'] = None
                parsed_item['Norm'] = None
                parsed_item['data_name'] = '-1'
                item_list.append(parsed_item)
            else:
                parsed_item['full_op_name'] = full_op_name
                parsed_item['dtype'] = str(type(item['value']))
                parsed_item['shape'] = '[]'
                parsed_item['md5'] = None
                parsed_item['Max'] = item['value']
                parsed_item['Min'] = item['value']
                parsed_item['Mean'] = item['value']
                parsed_item['Norm'] = item['value']
                parsed_item['data_name'] = '-1'
                item_list.append(parsed_item)
        else:
            resolve_api_special_parameters(item, full_op_name, item_list)
    else:
        for j, item_spec in enumerate(item):
            op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False)
    return item_list


def resolve_api_special_parameters(data_dict, full_op_name, item_list):
    """
    Function Description:
        Parse data in the following format, a special format of API parameters
        {
           "last_hidden_state": {
            "type": "torch.Tensor",
            "dtype": "torch.bfloat16",
            ...
           },
           "loss": {
            "type": "torch.Tensor",
            "dtype": "torch.float32",
            ...
           }
        }
    Parameter:
        data_dict: data in dict format
        full_op_name: full name string of the parameter
        item_list: collection of parameter info
    """
    for key, value in data_dict.items():
        if isinstance(value, dict):
            parsed_item = value
            parts = full_op_name.split(Const.SEP)
            parts.insert(-1, key)
            full_op_name_new = ".".join(parts)
            parsed_item['full_op_name'] = full_op_name_new
            item_list.append(parsed_item)


def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False):
    def get_accuracy_core(n_start, n_len, b_start, b_len, key):
        min_len = min(n_len, b_len)
        npu_stack_info = n_dict.get("stack_info", None)
        bench_stack_info = b_dict.get("stack_info", None)
        has_stack = npu_stack_info and bench_stack_info

        all_mode_bool = not (summary_compare or md5_compare)
        if all_mode_bool:
            npu_data_name = n_dict.get("data_name", None)
            bench_data_name = b_dict.get("data_name", None)

        for index in range(min_len):

            n_name = n_dict['op_name'][n_start + index]
            b_name = b_dict['op_name'][b_start + index]
            n_struct = n_dict[key][index]
            b_struct = b_dict[key][index]
            err_msg = ""
            if md5_compare:
                result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
                               n_struct[2], b_struct[2],
                               CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF]
                if has_stack and index == 0 and key == "input_struct":
                    result_item.extend(npu_stack_info)
                else:
                    result_item.append(CompareConst.NONE)
                result.append(result_item)
                continue

            if summary_compare:
                result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
                               " ", " ", " ", " ", " ", " ", " ", " "]
            else:
                result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
                               " ", " ", " ", " ", " "]

            npu_summary_data = n_dict.get("summary")[n_start + index]
            result_item.extend(npu_summary_data)
            bench_summary_data = b_dict.get("summary")[b_start + index]
            result_item.extend(bench_summary_data)

            if summary_compare:
                start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
                warning_flag = False
                for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
                    if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
                        diff = npu_val - bench_val
                        if bench_val != 0:
                            relative = str(abs((diff / bench_val) * 100)) + '%'
                        else:
                            relative = "N/A"
                        result_item[start_idx + i] = diff
                        result_item[start_idx + i + 4] = relative
                        magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
                        if magnitude_diff > 0.5:
                            warning_flag = True
                    else:
                        result_item[start_idx + i] = CompareConst.NONE
                accuracy_check = CompareConst.WARNING if warning_flag else ""
                err_msg += "Need double check api accuracy." if warning_flag else ""
                for i in range(start_idx, len(result_item)):
                    if str(result_item[i]) in ('inf', '-inf', 'nan'):
                        result_item[i] = f'{result_item[i]}\t'

            result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES)
            result_item.append(err_msg)
            if has_stack and index == 0 and key == "input_struct":
                result_item.extend(npu_stack_info)
            else:
                result_item.append(CompareConst.NONE)
            if all_mode_bool:
                result_item.append(npu_data_name[n_start + index])

            result.append(result_item)

        if n_len > b_len:
            for index in range(b_len, n_len):
                n_name = n_dict['op_name'][n_start + index]
                n_struct = n_dict[key][index]
                if md5_compare:
                    result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
                                   n_struct[1], CompareConst.NAN, n_struct[2], CompareConst.NAN, CompareConst.NAN]
                    result.append(result_item)
                    continue
                result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
                               n_struct[1], CompareConst.NAN, " ", " ", " ", " ", " "]
                summary_data = n_dict.get("summary")[n_start + index]
                result_item.extend(summary_data)
                summary_data = [CompareConst.NAN for _ in range(len(n_dict.get("summary")[0]))]
                result_item.extend(summary_data)

                err_msg = ""
                result_item.append(CompareConst.ACCURACY_CHECK_YES)
                result_item.append(err_msg)

                if has_stack and index == 0 and key == "input_struct":
                    result_item.extend(npu_stack_info)
                else:
                    result_item.append(CompareConst.NONE)
                if all_mode_bool:
                    result_item.append(npu_data_name[n_start + index])

                result.append(result_item)

    n_num = len(n_dict['op_name'])
    b_num = len(b_dict['op_name'])
    n_num_input = len([name for name in n_dict['op_name'] if Const.INPUT in name])
    b_num_input = len([name for name in b_dict['op_name'] if Const.INPUT in name])
    n_num_kwarg = len([name for name in n_dict['op_name'] if 'kwarg' in name])
    b_num_kwarg = len([name for name in b_dict['op_name'] if 'kwarg' in name])
    n_num_output = n_num - n_num_input - n_num_kwarg
    b_num_output = b_num - b_num_input - b_num_kwarg
    get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
    get_accuracy_core(n_num_input, n_num_kwarg, b_num_input, b_num_kwarg, "kwargs_struct")
    get_accuracy_core(n_num_input + n_num_kwarg, n_num_output, b_num_input + b_num_kwarg, b_num_output, 'output_struct')


def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
    index_out = 0
    npu_stack_info = n_dict.get("stack_info", None)
    bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A
    err_msg = CompareConst.NO_BENCH
    accuracy_check_res = CompareConst.N_A
    for index, n_name in enumerate(n_dict["op_name"]):
        if n_name.find("input") != -1:
            n_struct = n_dict["input_struct"][index]
        else:
            n_struct = n_dict["output_struct"][index_out]
            index_out += 1

        result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
        if md5_compare:
            result_item.extend([CompareConst.N_A] * 3)
            if npu_stack_info and index == 0:
                result_item.extend(npu_stack_info)
            else:
                result_item.append(CompareConst.NONE)
            result.append(result_item)
            continue
        if summary_compare:
            result_item.extend([CompareConst.N_A] * 8)
        else:
            result_item.extend([CompareConst.N_A] * 5)
        npu_summary_data = n_dict.get("summary")[index]
        result_item.extend(npu_summary_data)
        bench_summary_data = [CompareConst.N_A] * 4
        result_item.extend(bench_summary_data)
        result_item.append(accuracy_check_res)
        result_item.append(err_msg)
        if npu_stack_info and index == 0:
            result_item.extend(npu_stack_info)
        else:
            result_item.append(CompareConst.NONE)
        if not md5_compare and not summary_compare and result_item[1] == CompareConst.N_A:
            result_item.extend(["-1"])
        result.append(result_item)


def merge_tensor(tensor_list, summary_compare, md5_compare):
    op_dict = {}
    op_dict["op_name"] = []
    op_dict["input_struct"] = []
    op_dict["kwargs_struct"] = []
    op_dict["output_struct"] = []
    op_dict["summary"] = []
    op_dict["stack_info"] = []

    all_mode_bool = not (summary_compare or md5_compare)
    if all_mode_bool:
        op_dict["data_name"] = []

    for tensor in tensor_list:
        if len(tensor) == 2:
            op_dict['stack_info'].append(tensor['full_info'])
            break
        op_dict["op_name"].append(tensor['full_op_name'])
        if not md5_compare:
            if tensor['full_op_name'].find("input") != -1:
                op_dict["input_struct"].append((tensor['dtype'], tensor['shape']))
            elif tensor['full_op_name'].find("kwarg") != -1:
                op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape']))
            elif tensor['full_op_name'].find("output") != -1:
                op_dict["output_struct"].append((tensor['dtype'], tensor['shape']))
        else:
            if tensor['full_op_name'].find("input") != -1:
                op_dict["input_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
            elif tensor['full_op_name'].find("kwarg") != -1:
                op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
            elif tensor['full_op_name'].find("output") != -1:
                op_dict["output_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))

        op_dict["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']])

        if all_mode_bool:
            op_dict["data_name"].append(tensor['data_name'])

    if not op_dict["kwargs_struct"]:
        del op_dict["kwargs_struct"]
    return op_dict if op_dict["op_name"] else {}


def _compare_parser(parser):
    parser.add_argument("-i", "--input_path", dest="input_path", type=str,
                        help="<Required> The compare input path, a dict json.", required=True)
    parser.add_argument("-o", "--output_path", dest="output_path", type=str,
                        help="<Required> The compare task result out path.", required=True)
    parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
                        help="<optional> Whether to save stack info.", required=False)
    parser.add_argument("-c", "--compare_only", dest="compare_only", action="store_true",
                        help="<optional> Whether to give advisor.", required=False)
    parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true",
                        help="<optional> Whether to perform a fuzzy match on the api name.", required=False)
    parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True,
                        help="<optional> The cell mapping file path.", required=False)
    parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True,
                        help="<optional> The api mapping file path.", required=False)
```
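`_compare_parser` only registers flags on a parser supplied by the caller, so wiring it into argparse is enough to reproduce the compare CLI surface. A sketch under the assumption that the hunk above is msprobe/core/compare/utils.py (matched by its +429 line count); the program name and argument values are made up:

```python
import argparse

from msprobe.core.compare.utils import _compare_parser

parser = argparse.ArgumentParser(prog="msprobe compare")
_compare_parser(parser)

# Parse a made-up command line: comparison config in, result directory out,
# with stack info collection and fuzzy API-name matching switched on.
args = parser.parse_args(["-i", "compare_input.json", "-o", "./output", "-s", "-f"])
print(args.input_path, args.output_path, args.stack_mode, args.fuzzy_match, args.cell_mapping)
```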