mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
- mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +101 -237
- msprobe/{config/config.json → config.json} +49 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +59 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +341 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +478 -283
- msprobe/core/common/log.py +76 -69
- msprobe/core/common/utils.py +385 -616
- msprobe/core/common_config.py +85 -71
- msprobe/core/compare/acc_compare.py +299 -298
- msprobe/core/compare/check.py +95 -95
- msprobe/core/compare/compare_cli.py +49 -49
- msprobe/core/compare/highlight.py +223 -222
- msprobe/core/compare/multiprocessing_compute.py +149 -149
- msprobe/core/compare/npy_compare.py +295 -295
- msprobe/core/compare/utils.py +430 -429
- msprobe/core/data_dump/data_collector.py +154 -144
- msprobe/core/data_dump/data_processor/base.py +314 -293
- msprobe/core/data_dump/data_processor/factory.py +59 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
- msprobe/core/data_dump/json_writer.py +96 -116
- msprobe/core/data_dump/scope.py +178 -178
- msprobe/core/grad_probe/constant.py +70 -70
- msprobe/core/grad_probe/grad_compare.py +171 -175
- msprobe/core/grad_probe/utils.py +64 -52
- msprobe/docs/01.installation.md +89 -0
- msprobe/docs/02.config_introduction.md +165 -0
- msprobe/docs/03.config_examples.md +247 -0
- msprobe/docs/04.acl_config_examples.md +76 -0
- msprobe/docs/05.data_dump_PyTorch.md +198 -0
- msprobe/docs/06.data_dump_MindSpore.md +243 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
- msprobe/docs/FAQ_PyTorch.md +177 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +34 -34
- msprobe/mindspore/common/const.py +106 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +81 -57
- msprobe/mindspore/compare/distributed_compare.py +75 -75
- msprobe/mindspore/compare/ms_compare.py +219 -117
- msprobe/mindspore/compare/ms_graph_compare.py +348 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +66 -74
- msprobe/mindspore/debugger/precision_debugger.py +126 -107
- msprobe/mindspore/dump/dump_tool_factory.py +35 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
- msprobe/mindspore/dump/jit_dump.py +72 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
- msprobe/mindspore/free_benchmark/common/config.py +12 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
- msprobe/mindspore/free_benchmark/common/utils.py +71 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
- msprobe/mindspore/grad_probe/global_context.py +90 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +378 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +3 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
- msprobe/pytorch/bench_functions/__init__.py +15 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
- msprobe/pytorch/bench_functions/linear.py +12 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
- msprobe/pytorch/bench_functions/rms_norm.py +15 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
- msprobe/pytorch/bench_functions/swiglu.py +55 -55
- msprobe/pytorch/common/__init__.py +2 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +20 -31
- msprobe/pytorch/common/parse_json.py +39 -39
- msprobe/pytorch/common/utils.py +305 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +34 -33
- msprobe/pytorch/compare/pt_compare.py +50 -40
- msprobe/pytorch/debugger/debugger_config.py +95 -95
- msprobe/pytorch/debugger/precision_debugger.py +125 -125
- msprobe/pytorch/free_benchmark/__init__.py +8 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +37 -37
- msprobe/pytorch/free_benchmark/common/params.py +129 -129
- msprobe/pytorch/free_benchmark/common/utils.py +102 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
- msprobe/pytorch/free_benchmark/main.py +105 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
- msprobe/pytorch/function_factory.py +76 -75
- msprobe/pytorch/functional/dump_module.py +39 -39
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/api_registry.py +161 -161
- msprobe/pytorch/hook_module/hook_module.py +120 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +30 -29
- msprobe/pytorch/hook_module/wrap_aten.py +110 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
- msprobe/pytorch/hook_module/wrap_functional.py +105 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
- msprobe/pytorch/hook_module/wrap_torch.py +86 -86
- msprobe/pytorch/hook_module/wrap_vf.py +62 -62
- msprobe/pytorch/module_processer.py +138 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +236 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
- msprobe/pytorch/online_dispatch/utils.py +130 -146
- msprobe/pytorch/parse.py +4 -4
- msprobe/pytorch/parse_tool/cli.py +32 -32
- msprobe/pytorch/parse_tool/lib/compare.py +260 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
- msprobe/pytorch/parse_tool/lib/utils.py +316 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +188 -187
- msprobe/pytorch/service.py +246 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
--- a/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py
+++ b/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py
@@ -1,581 +1,592 @@
-import argparse
-import os
-import csv
-import sys
-import time
-import gc
-from collections import namedtuple
-
-try:
-    import torch_npu
-except ImportError:
-    is_gpu = True
-    current_device = "cuda"
-else:
-    is_gpu = False
-    current_device = "npu"
-import torch
-from tqdm import tqdm
-
-from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api, UtDataInfo, \
-    get_validated_result_csv_path, get_validated_details_csv_path, exec_api
-from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
-from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
-    initialize_save_path, UtDataProcessor
-from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
-from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
-from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
-from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
-from msprobe.core.common.
-    change_mode, check_path_before_create, create_directory
-from msprobe.pytorch.common.log import logger
-from msprobe.
-from msprobe.
-from msprobe.
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.
[removed lines 29-581 of the old run_ut.py are truncated in the source diff; only the fragments above survive]
+import argparse
+import os
+import csv
+import sys
+import time
+import gc
+from collections import namedtuple
+
+try:
+    import torch_npu
+except ImportError:
+    is_gpu = True
+    current_device = "cuda"
+else:
+    is_gpu = False
+    current_device = "npu"
+import torch
+from tqdm import tqdm
+
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api, UtDataInfo, \
+    get_validated_result_csv_path, get_validated_details_csv_path, exec_api
+from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
+from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
+    initialize_save_path, UtDataProcessor, extract_basic_api_segments, ApiData
+from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
+from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
+from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
+from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
+from msprobe.core.common.file_utils import FileOpen, FileChecker, \
+    change_mode, check_path_before_create, create_directory, get_json_contents
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.pt_config import parse_json_config
+from msprobe.core.common.const import Const, FileCheckConst, CompareConst
+from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
+from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
+
+
+current_time = time.strftime("%Y%m%d%H%M%S")
+UT_ERROR_DATA_DIR = 'ut_error_data' + current_time
+RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
+DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
+RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
+                                         'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
+                                         'black_list', 'error_data_path', 'online_config'])
+
+OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
+
+not_backward_list = ['repeat_interleave']
+not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
+not_raise_dtype_set = {'type_as'}
+
+RAISE_PRECISION = {
+    torch.float16: torch.float32,
+    torch.bfloat16: torch.float32,
+    torch.float32: torch.float64
+}
+
+tqdm_params = {
+    'smoothing': 0,  # smoothing factor for the estimated remaining time, in the range 0 to 1
+    'desc': 'Processing',  # text shown in front of the progress bar
+    'leave': True,  # keep the progress bar after the iteration finishes
+    'ncols': 75,  # fixed width of the progress bar
+    'mininterval': 0.1,  # minimum interval in seconds between progress bar updates
+    'maxinterval': 1.0,  # maximum interval in seconds between progress bar updates
+    'miniters': 1,  # minimum number of iterations between progress bar updates
+    'ascii': None,  # automatically use ASCII or Unicode characters depending on the environment
+    'unit': 'it',  # iteration unit
+    'unit_scale': True,  # scale automatically according to the unit
+    'dynamic_ncols': True,  # dynamically adjust the bar width to fit the console
+    'bar_format': '{l_bar}{bar}| {n}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'  # custom progress bar format
+}
+
+
+def deal_detach(arg, to_detach=True):
+    return arg.detach() if to_detach else arg
+
+
+def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
+    '''
+    Convert the dtype of the benchmark input to raise_dtype.
+    Input:
+        api_name: API name
+        arg: benchmark input
+        raise_dtype: dtype to convert to
+    Output:
+        arg: benchmark input converted to raise_dtype
+    '''
+    if api_name in hf_32_standard_api and arg.dtype == torch.float32:
+        return arg
+    if raise_dtype is None or arg.dtype not in RAISE_PRECISION or raise_dtype == arg.dtype:
+        return arg
+    return arg.type(raise_dtype)
+
+
+def generate_device_params(input_args, input_kwargs, need_backward, api_name):
+    def recursive_arg_to_device(arg_in, to_detach):
+        if isinstance(arg_in, (list, tuple)):
+            return type(arg_in)(recursive_arg_to_device(arg, to_detach) for arg in arg_in)
+        elif isinstance(arg_in, torch.Tensor):
+            if need_backward and arg_in.requires_grad:
+                arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
+                temp_arg_in = arg_in * 1
+                arg_in = temp_arg_in.type_as(arg_in)
+                arg_in.retain_grad()
+                return arg_in
+            else:
+                return deal_detach(arg_in.clone(), to_detach).to(current_device)
+        else:
+            return arg_in
+
+    is_detach = api_name not in not_detach_set
+    device_args = recursive_arg_to_device(input_args, is_detach)
+    device_kwargs = \
+        {key: recursive_arg_to_device(value, key != "out" and is_detach) for key, value in input_kwargs.items()}
+    return device_args, device_kwargs
+
+
+def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
+    def recursive_arg_to_cpu(arg_in, to_detach, raise_dtype=None):
+        if isinstance(arg_in, (list, tuple)):
+            return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype) for arg in arg_in)
+        elif isinstance(arg_in, torch.Tensor):
+            if need_backward and arg_in.requires_grad:
+                arg_in = deal_detach(raise_bench_data_dtype(
+                    api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_()
+                temp_arg_in = arg_in * 1
+                arg_in = temp_arg_in.type_as(arg_in)
+                arg_in.retain_grad()
+                return arg_in
+            else:
+                return deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach)
+        else:
+            return arg_in
+
+    def is_tensor_with_raise_precision(arg_in, check_kwargs=False):
+        if arg_in.dtype in RAISE_PRECISION:
+            return True
+        if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]:
+            return True
+        return False
+
+    def recursive_find_dtypes(arg_in, kwargs=None, check_kwargs=False):
+        if isinstance(arg_in, (list, tuple)):
+            return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs) for arg in arg_in))
+        elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
+            return set([arg_in.dtype])
+        elif isinstance(arg_in, dict) and check_kwargs:
+            return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True) for v in arg_in.values()))
+        return set()
+
+    raise_dtype = None
+    need_raise_dtypes = recursive_find_dtypes(input_args)
+    need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
+    if len(need_raise_dtypes) == 1:
+        raise_dtype = RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32)
+    elif len(need_raise_dtypes) >= 2:
+        raise_dtype = torch.float32
+
+    raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
+    is_detach = api_name not in not_detach_set
+    cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
+    cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()}
+    return cpu_args, cpu_kwargs
+
+
+def run_ut(config):
+    logger.info("start UT test")
+    if config.online_config.is_online:
+        logger.info(f"UT task result will be saved in {config.result_csv_path}".replace(".csv", "_rank*.csv"))
+        logger.info(f"UT task details will be saved in {config.details_csv_path}".replace(".csv", "_rank*.csv"))
+    else:
+        logger.info(f"UT task result will be saved in {config.result_csv_path}")
+        logger.info(f"UT task details will be saved in {config.details_csv_path}")
+
+    if config.save_error_data:
+        logger.info(f"UT task error_datas will be saved in {config.error_data_path}")
+    compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
+
+    if config.online_config.is_online:
+        run_api_online(config, compare)
+    else:
+        with FileOpen(config.result_csv_path, 'r') as file:
+            csv_reader = csv.reader(file)
+            next(csv_reader)
+            api_name_set = {row[0] for row in csv_reader}
+        run_api_offline(config, compare, api_name_set)
+    for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
+        change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
+        change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
+        logger.info(f"UT task result csv is saved in {result_csv_path}")
+        logger.info(f"UT task details csv is saved in {details_csv_path}")
+    compare.print_pretest_result()
+
+
+def run_api_offline(config, compare, api_name_set):
+    err_column = CompareColumn()
+    for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)):
+        if api_full_name in api_name_set:
+            continue
+        if is_unsupported_api(api_full_name):
+            continue
+        _, api_name = extract_basic_api_segments(api_full_name)
+        if not api_name:
+            err_message = f"API {api_full_name} not support for run ut. SKIP."
+            logger.error(err_message)
+            fwd_compare_alg_results = err_column.to_column_value(CompareConst.SKIP, err_message)
+            result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [fwd_compare_alg_results], None, 0)
+            compare.record_results(result_info)
+            continue
+        try:
+            if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
+                continue
+            data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict)
+            is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info)
+            if config.save_error_data:
+                do_save_error_data(api_full_name, data_info, config.error_data_path, is_fwd_success, is_bwd_success)
+        except Exception as err:
+            if "expected scalar type Long" in str(err):
+                logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
+                               f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
+            else:
+                logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
+            fwd_compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
+            result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [fwd_compare_alg_results], None, 0)
+            compare.record_results(result_info)
+        finally:
+            if is_gpu:
+                torch.cuda.empty_cache()
+            else:
+                torch.npu.empty_cache()
+            gc.collect()
+
+
+def run_api_online(config, compare):
+    attl = init_attl(config.online_config)
+    dispatcher = ConsumerDispatcher(compare=compare)
+    dispatcher.start(handle_func=run_torch_api_online, config=config)
+
+    def tcp_communication_flow():
+        while True:
+            api_data = attl.recv()
+            if api_data == 'STOP_':
+                continue
+            if api_data == 'KILL_':
+                time.sleep(1)
+                logger.info("========== STOP signal received ==========")
+                dispatcher.stop()
+                attl.stop_serve()
+                time.sleep(1)
+                break
+            if not isinstance(api_data, ApiData):
+                continue
+            api_full_name = api_data.name
+            _, api_name = extract_basic_api_segments(api_full_name)
+            if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
+                continue
+            if api_data.rank in config.online_config.rank_list:
+                dispatcher.update_consume_queue(api_data)
+
+    def shared_storage_communication_flow():
+        flag_num = -1
+        while True:
+            api_data = attl.download()
+            if api_data == "start":
+                if flag_num == -1:
+                    flag_num += 1
+                flag_num += 1
+            if api_data == "end":
+                flag_num -= 1
+            if flag_num == 0:
+                dispatcher.stop()
+                break
+            if not isinstance(api_data, ApiData):
+                continue
+            api_full_name = api_data.name
+            _, api_name = extract_basic_api_segments(api_full_name)
+            if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
+                continue
+            if api_data.rank in config.online_config.rank_list:
+                dispatcher.update_consume_queue(api_data)
+
+    if config.online_config.nfs_path:
+        shared_storage_communication_flow()
+    else:
+        tcp_communication_flow()
+
+
+def blacklist_and_whitelist_filter(api_name, black_list, white_list):
+    """
+    run api(api_name) if api_name not in black_list and in white_list.
+    If api is in both black_list and white_list, black_list takes precedence.
+    return: False for exec api, True for not exec
+    """
+    if black_list and api_name in black_list:
+        return True
+    if white_list and api_name not in white_list:
+        return True
+    return False
+
+
+def is_unsupported_api(api_name):
+    split_name = api_name.split(Const.SEP)[0]
+    flag = split_name == Const.DISTRIBUTED
+    if flag:
+        logger.info(f"{split_name} api is not supported for run ut. SKIP.")
+    return flag
+
+
+def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success, is_bwd_success):
+    if not is_fwd_success or not is_bwd_success:
+        processor = UtDataProcessor(error_data_path)
+        for element in data_info.in_fwd_data_list:
+            processor.save_tensors_in_element(api_full_name + '.forward.input', element)
+        processor.save_tensors_in_element(api_full_name + '.forward.output.bench', data_info.bench_output)
+        processor.save_tensors_in_element(api_full_name + '.forward.output.device', data_info.device_output)
+        processor.save_tensors_in_element(api_full_name + '.backward.input', data_info.grad_in)
+        processor.save_tensors_in_element(api_full_name + '.backward.output.bench', data_info.bench_grad)
+        processor.save_tensors_in_element(api_full_name + '.backward.output.device', data_info.device_grad)
+
+
+def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict):
+    in_fwd_data_list = []
+    backward_message = ''
+    api_type, api_name = extract_basic_api_segments(api_full_name)
+    args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
+    in_fwd_data_list.append(args)
+    in_fwd_data_list.append(kwargs)
+    need_backward = api_full_name in backward_content
+    if not need_grad:
+        logger.warning("%s %s" % (api_full_name, Backward_Message.UNSUPPORT_BACKWARD_MESSAGE))
+        backward_message += Backward_Message.UNSUPPORT_BACKWARD_MESSAGE
+    if api_name in not_backward_list:
+        need_grad = False
+        logger.warning("%s %s" % (api_full_name, Backward_Message.NO_BACKWARD_RESULT_MESSAGE))
+        backward_message += Backward_Message.NO_BACKWARD_RESULT_MESSAGE
+    need_backward = need_backward and need_grad
+    if kwargs.get("device"):
+        del kwargs["device"]
+    cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward, api_name)
+    device_args, device_kwargs = generate_device_params(args, kwargs, need_backward, api_name)
+    bench_grad_out, device_grad_out = None, None
+    out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs)
+    device_out = exec_api(api_type, api_name, current_device, device_args, device_kwargs)
+    current_path = os.path.dirname(os.path.realpath(__file__))
+    ut_setting_path = os.path.join(current_path, "torch_ut_setting.json")
+    api_setting_dict = get_json_contents(ut_setting_path)
+    grad_input_index = api_setting_dict.get(api_name)
+    grad_index = None
+    grad, bench_grad = None, None
+    if grad_input_index is not None:
+        grad_index = grad_input_index.get('grad_index')
+
+    if need_backward:
+        if need_to_backward(grad_index, out):
+            backward_args = backward_content[api_full_name].get("input")
+            grad = gen_args(backward_args, api_name, real_data_path=real_data_path)[0]
+            bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
+            bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
+            device_grad = grad.clone().detach().to(current_device)
+            device_grad_out = run_backward(device_args, device_grad, grad_index, device_out)
+        else:
+            backward_message += Backward_Message.MULTIPLE_BACKWARD_MESSAGE
+    if api_name == "npu_fusion_attention":
+        out = out[0]
+        device_out = device_out[0]
+
+    return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
+
+
+def run_torch_api_online(api_full_name, api_data, backward_content):
+    in_fwd_data_list = []
+    api_type, api_name = extract_basic_api_segments(api_full_name)
+    args, kwargs, out = api_data.args, api_data.kwargs, api_data.result
+    in_fwd_data_list.append(args)
+    in_fwd_data_list.append(kwargs)
+    if kwargs.get("device"):
+        del kwargs["device"]
+
+    device_out = exec_api(api_type, api_name, Const.CUDA_LOWERCASE, args, kwargs)
+    device_out = move2device_exec(device_out, "cpu")
+    return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
+
+
+def get_api_info(api_info_dict, api_name, real_data_path):
+    convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
+    need_grad = True
+    if api_info_dict.get("input_kwargs") and "out" in api_info_dict.get("input_kwargs"):
+        need_grad = False
+    args, kwargs = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path)
+    return args, kwargs, need_grad
+
+
+def need_to_backward(grad_index, out):
+    if grad_index is None and isinstance(out, (list, tuple)):
+        return False
+    return True
+
+
+def run_backward(args, grad, grad_index, out):
+    if grad_index is not None:
+        out[grad_index].backward(grad)
+    else:
+        out.backward(grad)
+    args_grad = []
+    for arg in args:
+        if isinstance(arg, torch.Tensor):
+            args_grad.append(arg.grad)
+    grad_out = args_grad
+
+    return grad_out
+
+
+def initialize_save_error_data(error_data_path):
+    check_path_before_create(error_data_path)
+    create_directory(error_data_path)
+    error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
+                                          ability=FileCheckConst.WRITE_ABLE)
+    error_data_path = error_data_path_checker.common_check()
+    error_data_path = initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
+    return error_data_path
+
+
+def init_attl(config):
+    """config: OnlineConfig"""
+    attl = ATTL('gpu', ATTLConfig(is_benchmark_device=True,
+                                  connect_ip=config.host,
+                                  connect_port=config.port,
+                                  nfs_path=config.nfs_path,
+                                  tls_path=config.tls_path))
+    return attl
+
+
+def _run_ut_parser(parser):
+    parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
+                        help="<Optional> The api param tool result file: generate from api param tool, "
+                             "a json file.",
+                        required=False)
+    parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
+                        help="<optional> The ut task result out path.",
+                        required=False)
+    parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
+                        help="<optional> Save compare failed api output.", required=False)
+    parser.add_argument("-j", "--jit_compile", dest="jit_compile", action="store_true",
+                        help="<optional> whether to turn on jit compile", required=False)
+
+    class UniqueDeviceAction(argparse.Action):
+        def __call__(self, parser, namespace, values, option_string=None):
+            unique_values = set(values)
+            if len(values) != len(unique_values):
+                parser.error("device id must be unique")
+            for device_id in values:
+                if not 0 <= device_id:
+                    parser.error("device id must be greater than or equal to 0")
+            setattr(namespace, self.dest, values)
+
+    parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int,
+                        help="<optional> set device id to run ut, must be unique and in range 0-7",
+                        default=[0], required=False, action=UniqueDeviceAction)
+    parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str,
+                        help="<optional> The path of accuracy_checking_result_{timestamp}.csv, "
+                             "when run ut is interrupted, enter the file path to continue run ut.",
+                        required=False)
+    parser.add_argument("-f", "--filter_api", dest="filter_api", action="store_true",
+                        help="<optional> Whether to filter the api in the api_info_file.", required=False)
+    parser.add_argument("-config", "--config_path", dest="config_path", default="", type=str,
+                        help="<optional> The path of config.json", required=False)
+
+
+def preprocess_forward_content(forward_content):
+    processed_content = {}
+    base_keys_variants = {}
+    arg_cache = {}
+
+    for key, value in forward_content.items():
+        base_key = key.rsplit(Const.SEP, 1)[0]
+
+        if key not in arg_cache:
+            filtered_new_args = [
+                {k: v for k, v in arg.items() if k not in ['Max', 'Min']}
+                for arg in value['input_args'] if isinstance(arg, dict)
+            ]
+            arg_cache[key] = (filtered_new_args, value['input_kwargs'])
+
+        filtered_new_args, new_kwargs = arg_cache[key]
+
+        if base_key not in base_keys_variants:
+            processed_content[key] = value
+            base_keys_variants[base_key] = {key}
+        else:
+            is_duplicate = False
+            for variant in base_keys_variants.get(base_key, []):
+                try:
+                    existing_args, existing_kwargs = arg_cache.get(variant)
+                except KeyError as e:
+                    logger.error(f"KeyError: {e} when processing {key}")
+                if existing_args == filtered_new_args and existing_kwargs == new_kwargs:
+                    is_duplicate = True
+                    break
+
+            if not is_duplicate:
+                processed_content[key] = value
+                base_keys_variants[base_key].add(key)
+
+    return processed_content
+
+
+def _run_ut(parser=None):
+    if not parser:
+        parser = argparse.ArgumentParser()
+    _run_ut_parser(parser)
+    args = parser.parse_args(sys.argv[1:])
+    run_ut_command(args)
+
+
+def run_ut_command(args):
+    if not is_gpu:
+        torch.npu.set_compile_mode(jit_compile=args.jit_compile)
+    used_device = current_device + ":" + str(args.device_id[0])
+    try:
+        if is_gpu:
+            torch.cuda.set_device(used_device)
+        else:
+            torch.npu.set_device(used_device)
+    except Exception as error:
+        logger.error(f"Set device id failed. device id is: {args.device_id}")
+        raise NotImplementedError from error
+
+    # In the online pre-check scenario no dumped api info is needed, so forward_content, backward_content and real_data_path are set to None
+    # In the offline scenario, forward_content, backward_content and real_data_path are parsed from api_info_file
+    forward_content, backward_content, real_data_path = None, None, None
+    if args.api_info_file:
+        api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
+                                            ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
+        checked_api_info = api_info_file_checker.common_check()
+        forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info)
+        if args.filter_api:
+            logger.info("Start filtering the api in the forward_input_file.")
+            forward_content = preprocess_forward_content(forward_content)
+            logger.info("Finish filtering the api in the forward_input_file.")
+
+    out_path = os.path.realpath(args.out_path) if args.out_path else "./"
+    check_path_before_create(out_path)
+    create_directory(out_path)
+    out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
+    out_path = out_path_checker.common_check()
+    save_error_data = args.save_error_data
+
+    result_csv_path = os.path.join(out_path, RESULT_FILE_NAME)
+    details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME)
+    if args.result_csv_path:
+        result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
+        details_csv_path = get_validated_details_csv_path(result_csv_path)
+    white_list = msCheckerConfig.white_list
+    black_list = msCheckerConfig.black_list
+    error_data_path = msCheckerConfig.error_data_path
+    is_online = msCheckerConfig.is_online
+    nfs_path = msCheckerConfig.nfs_path
+    host = msCheckerConfig.host
+    port = msCheckerConfig.port
+    rank_list = msCheckerConfig.rank_list
+    tls_path = msCheckerConfig.tls_path
+    if args.config_path:
+        config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
+                                          FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
+        checked_config_path = config_path_checker.common_check()
+        _, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
+        white_list = task_config.white_list
+        black_list = task_config.black_list
+        error_data_path = task_config.error_data_path
+        is_online = task_config.is_online
+        nfs_path = task_config.nfs_path
+        host = task_config.host
+        port = task_config.port
+        rank_list = task_config.rank_list
+        tls_path = task_config.tls_path
+
+    if save_error_data:
+        if args.result_csv_path:
+            time_info = result_csv_path.split('.')[0].split('_')[-1]
+            global UT_ERROR_DATA_DIR
+            UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
+        error_data_path = initialize_save_error_data(error_data_path)
+    online_config = OnlineConfig(is_online, nfs_path, host, port, rank_list, tls_path)
+    run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data,
+                                args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path,
+                                online_config)
+    run_ut(run_ut_config)
+
+
+if __name__ == '__main__':
+    _run_ut()
+    logger.info("UT task completed.")
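
For reference, the rewritten run_ut.py shown above routes all settings through the RunUTConfig and OnlineConfig namedtuples and exposes its command-line flags via _run_ut_parser. The short sketch below only exercises that parser; it assumes mindstudio-probe 1.0.4 and its dependencies (torch, tqdm) are installed, and the paths ./dump.json, ./ut_out and ./config.json are hypothetical placeholders, not files shipped with the package.

    import argparse
    from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser

    # Build the same parser that _run_ut() uses internally.
    parser = argparse.ArgumentParser()
    _run_ut_parser(parser)

    # Parse a hypothetical command line; the flag names come from _run_ut_parser above.
    args = parser.parse_args(["-api_info", "./dump.json", "-o", "./ut_out", "-config", "./config.json"])
    print(args.api_info_file, args.out_path, args.config_path)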