mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import argparse
|
|
2
2
|
import os
|
|
3
3
|
import csv
|
|
4
|
-
import re
|
|
5
4
|
import sys
|
|
6
5
|
import time
|
|
7
6
|
import gc
|
|
@@ -18,28 +17,35 @@ else:
|
|
|
18
17
|
import torch
|
|
19
18
|
from tqdm import tqdm
|
|
20
19
|
|
|
21
|
-
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api
|
|
20
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api, UtDataInfo, \
|
|
21
|
+
get_validated_result_csv_path, get_validated_details_csv_path, exec_api
|
|
22
22
|
from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
|
|
23
|
-
from msprobe.pytorch.api_accuracy_checker.common.utils import
|
|
23
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
|
|
24
24
|
initialize_save_path, UtDataProcessor
|
|
25
25
|
from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
|
|
26
26
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
|
|
27
|
-
from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
|
|
28
|
-
from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
|
|
29
|
-
from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
|
|
30
27
|
from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
|
|
31
28
|
from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
|
|
32
29
|
from msprobe.core.common.file_check import FileOpen, FileChecker, \
|
|
33
|
-
change_mode,
|
|
30
|
+
change_mode, check_path_before_create, create_directory
|
|
34
31
|
from msprobe.pytorch.common.log import logger
|
|
32
|
+
from msprobe.core.common.utils import get_json_contents
|
|
33
|
+
from msprobe.pytorch.pt_config import parse_json_config
|
|
35
34
|
from msprobe.core.common.const import Const, FileCheckConst, CompareConst
|
|
35
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, ApiData, move2device_exec
|
|
36
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
|
|
37
|
+
|
|
36
38
|
|
|
37
39
|
current_time = time.strftime("%Y%m%d%H%M%S")
|
|
38
40
|
UT_ERROR_DATA_DIR = 'ut_error_data' + current_time
|
|
39
41
|
RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
|
|
40
42
|
DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
|
|
41
43
|
RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
|
|
42
|
-
'save_error_data', 'is_continue_run_ut', 'real_data_path'
|
|
44
|
+
'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
|
|
45
|
+
'black_list', 'error_data_path', 'online_config'])
|
|
46
|
+
|
|
47
|
+
OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
|
|
48
|
+
|
|
43
49
|
not_backward_list = ['repeat_interleave']
|
|
44
50
|
not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
|
|
45
51
|
not_raise_dtype_set = {'type_as'}
|
|
@@ -66,19 +72,6 @@ tqdm_params = {
|
|
|
66
72
|
}
|
|
67
73
|
|
|
68
74
|
|
|
69
|
-
def exec_api(api_type, api_name, args, kwargs):
|
|
70
|
-
if api_type == "Functional":
|
|
71
|
-
functional_api = FunctionalOPTemplate(api_name, str, False)
|
|
72
|
-
out = functional_api.forward(*args, **kwargs)
|
|
73
|
-
if api_type == "Tensor":
|
|
74
|
-
tensor_api = TensorOPTemplate(api_name, str, False)
|
|
75
|
-
out = tensor_api.forward(*args, **kwargs)
|
|
76
|
-
if api_type == "Torch":
|
|
77
|
-
torch_api = TorchOPTemplate(api_name, str, False)
|
|
78
|
-
out = torch_api.forward(*args, **kwargs)
|
|
79
|
-
return out
|
|
80
|
-
|
|
81
|
-
|
|
82
75
|
def deal_detach(arg, to_detach=True):
|
|
83
76
|
return arg.detach() if to_detach else arg
|
|
84
77
|
|
|
@@ -130,7 +123,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
|
|
|
130
123
|
elif isinstance(arg_in, torch.Tensor):
|
|
131
124
|
if need_backward and arg_in.requires_grad:
|
|
132
125
|
arg_in = deal_detach(raise_bench_data_dtype(
|
|
133
|
-
api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_()
|
|
126
|
+
api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_()
|
|
134
127
|
temp_arg_in = arg_in * 1
|
|
135
128
|
arg_in = temp_arg_in.type_as(arg_in)
|
|
136
129
|
arg_in.retain_grad()
|
|
@@ -173,32 +166,48 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
|
|
|
173
166
|
|
|
174
167
|
def run_ut(config):
|
|
175
168
|
logger.info("start UT test")
|
|
176
|
-
|
|
177
|
-
|
|
169
|
+
if config.online_config.is_online:
|
|
170
|
+
logger.info(f"UT task result will be saved in {config.result_csv_path}".replace(".csv", "_rank*.csv"))
|
|
171
|
+
logger.info(f"UT task details will be saved in {config.details_csv_path}".replace(".csv", "_rank*.csv"))
|
|
172
|
+
else:
|
|
173
|
+
logger.info(f"UT task result will be saved in {config.result_csv_path}")
|
|
174
|
+
logger.info(f"UT task details will be saved in {config.details_csv_path}")
|
|
175
|
+
|
|
178
176
|
if config.save_error_data:
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
177
|
+
logger.info(f"UT task error_datas will be saved in {config.error_data_path}")
|
|
178
|
+
compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
|
|
179
|
+
|
|
180
|
+
if config.online_config.is_online:
|
|
181
|
+
run_api_online(config, compare)
|
|
182
|
+
else:
|
|
183
|
+
with FileOpen(config.result_csv_path, 'r') as file:
|
|
184
|
+
csv_reader = csv.reader(file)
|
|
185
|
+
next(csv_reader)
|
|
186
|
+
api_name_set = {row[0] for row in csv_reader}
|
|
187
|
+
run_api_offline(config, compare, api_name_set)
|
|
188
|
+
for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
|
|
189
|
+
change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
190
|
+
change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
191
|
+
logger.info(f"UT task result csv is saved in {result_csv_path}")
|
|
192
|
+
logger.info(f"UT task details csv is saved in {details_csv_path}")
|
|
193
|
+
compare.print_pretest_result()
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def run_api_offline(config, compare, api_name_set):
|
|
186
197
|
for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)):
|
|
187
198
|
if api_full_name in api_name_set:
|
|
188
199
|
continue
|
|
189
|
-
if is_unsupported_api(api_full_name):
|
|
200
|
+
if is_unsupported_api(api_full_name):
|
|
190
201
|
continue
|
|
202
|
+
[_, api_name, _] = api_full_name.split(Const.SEP)
|
|
191
203
|
try:
|
|
192
|
-
if
|
|
193
|
-
|
|
194
|
-
if api_name not in set(msCheckerConfig.white_list):
|
|
195
|
-
continue
|
|
204
|
+
if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
|
|
205
|
+
continue
|
|
196
206
|
data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict)
|
|
197
207
|
is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info)
|
|
198
208
|
if config.save_error_data:
|
|
199
|
-
do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success)
|
|
209
|
+
do_save_error_data(api_full_name, data_info, config.error_data_path, is_fwd_success, is_bwd_success)
|
|
200
210
|
except Exception as err:
|
|
201
|
-
[_, api_name, _] = api_full_name.split(Const.SEP)
|
|
202
211
|
if "expected scalar type Long" in str(err):
|
|
203
212
|
logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
|
|
204
213
|
f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
|
|
@@ -214,9 +223,71 @@ def run_ut(config):
|
|
|
214
223
|
else:
|
|
215
224
|
torch.npu.empty_cache()
|
|
216
225
|
gc.collect()
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def run_api_online(config, compare):
|
|
229
|
+
attl = init_attl(config.online_config)
|
|
230
|
+
dispatcher = ConsumerDispatcher(compare=compare)
|
|
231
|
+
dispatcher.start(handle_func=run_torch_api_online, config=config)
|
|
232
|
+
|
|
233
|
+
def tcp_communication_flow():
|
|
234
|
+
while True:
|
|
235
|
+
api_data = attl.recv()
|
|
236
|
+
if api_data == 'STOP_':
|
|
237
|
+
continue
|
|
238
|
+
if api_data == 'KILL_':
|
|
239
|
+
time.sleep(1)
|
|
240
|
+
logger.info("==========接收到STOP信号==========")
|
|
241
|
+
dispatcher.stop()
|
|
242
|
+
attl.stop_serve()
|
|
243
|
+
time.sleep(1)
|
|
244
|
+
break
|
|
245
|
+
if not isinstance(api_data, ApiData):
|
|
246
|
+
continue
|
|
247
|
+
api_full_name = api_data.name
|
|
248
|
+
[_, api_name, _] = api_full_name.split(Const.SEP)
|
|
249
|
+
if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
|
|
250
|
+
continue
|
|
251
|
+
dispatcher.update_consume_queue(api_data)
|
|
252
|
+
|
|
253
|
+
def shared_storage_communication_flow():
|
|
254
|
+
flag_num = -1
|
|
255
|
+
while True:
|
|
256
|
+
api_data = attl.download()
|
|
257
|
+
if api_data == "start":
|
|
258
|
+
if flag_num == -1:
|
|
259
|
+
flag_num += 1
|
|
260
|
+
flag_num += 1
|
|
261
|
+
if api_data == "end":
|
|
262
|
+
flag_num -= 1
|
|
263
|
+
if flag_num == 0:
|
|
264
|
+
dispatcher.stop()
|
|
265
|
+
break
|
|
266
|
+
if not isinstance(api_data, ApiData):
|
|
267
|
+
continue
|
|
268
|
+
api_full_name = api_data.name
|
|
269
|
+
[_, api_name, _] = api_full_name.split(Const.SEP)
|
|
270
|
+
if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
|
|
271
|
+
continue
|
|
272
|
+
dispatcher.update_consume_queue(api_data)
|
|
273
|
+
|
|
274
|
+
if config.online_config.nfs_path:
|
|
275
|
+
shared_storage_communication_flow()
|
|
276
|
+
else:
|
|
277
|
+
tcp_communication_flow()
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def blacklist_and_whitelist_filter(api_name, black_list, white_list):
|
|
281
|
+
"""
|
|
282
|
+
run api(api_name) if api_name not in black_list and in white_list.
|
|
283
|
+
If api is both in black_list and black_list, black_list first.
|
|
284
|
+
return: False for exec api, True for not exec
|
|
285
|
+
"""
|
|
286
|
+
if black_list and api_name in black_list:
|
|
287
|
+
return True
|
|
288
|
+
if white_list and api_name not in white_list:
|
|
289
|
+
return True
|
|
290
|
+
return False
|
|
220
291
|
|
|
221
292
|
|
|
222
293
|
def is_unsupported_api(api_name):
|
|
@@ -227,16 +298,16 @@ def is_unsupported_api(api_name):
|
|
|
227
298
|
return flag
|
|
228
299
|
|
|
229
300
|
|
|
230
|
-
def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success):
|
|
301
|
+
def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success, is_bwd_success):
|
|
231
302
|
if not is_fwd_success or not is_bwd_success:
|
|
232
|
-
processor = UtDataProcessor(
|
|
303
|
+
processor = UtDataProcessor(error_data_path)
|
|
233
304
|
for element in data_info.in_fwd_data_list:
|
|
234
305
|
processor.save_tensors_in_element(api_full_name + '.forward.input', element)
|
|
235
|
-
processor.save_tensors_in_element(api_full_name + '.forward.output.bench', data_info.
|
|
236
|
-
processor.save_tensors_in_element(api_full_name + '.forward.output.device', data_info.
|
|
306
|
+
processor.save_tensors_in_element(api_full_name + '.forward.output.bench', data_info.bench_output)
|
|
307
|
+
processor.save_tensors_in_element(api_full_name + '.forward.output.device', data_info.device_output)
|
|
237
308
|
processor.save_tensors_in_element(api_full_name + '.backward.input', data_info.grad_in)
|
|
238
|
-
processor.save_tensors_in_element(api_full_name + '.backward.output.bench', data_info.
|
|
239
|
-
processor.save_tensors_in_element(api_full_name + '.backward.output.device', data_info.
|
|
309
|
+
processor.save_tensors_in_element(api_full_name + '.backward.output.bench', data_info.bench_grad)
|
|
310
|
+
processor.save_tensors_in_element(api_full_name + '.backward.output.device', data_info.device_grad)
|
|
240
311
|
|
|
241
312
|
|
|
242
313
|
def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict):
|
|
@@ -273,7 +344,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
273
344
|
|
|
274
345
|
if need_backward:
|
|
275
346
|
if need_to_backward(grad_index, out):
|
|
276
|
-
backward_args = backward_content[api_full_name].get("
|
|
347
|
+
backward_args = backward_content[api_full_name].get("input")
|
|
277
348
|
grad = gen_args(backward_args, api_name, real_data_path=real_data_path)[0]
|
|
278
349
|
bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
|
|
279
350
|
bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
|
|
@@ -285,6 +356,20 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
285
356
|
return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
|
|
286
357
|
|
|
287
358
|
|
|
359
|
+
def run_torch_api_online(api_full_name, api_data, backward_content):
|
|
360
|
+
in_fwd_data_list = []
|
|
361
|
+
[api_type, api_name, _] = api_full_name.split(Const.SEP)
|
|
362
|
+
args, kwargs, out = api_data.args, api_data.kwargs, api_data.result
|
|
363
|
+
in_fwd_data_list.append(args)
|
|
364
|
+
in_fwd_data_list.append(kwargs)
|
|
365
|
+
if kwargs.get("device"):
|
|
366
|
+
del kwargs["device"]
|
|
367
|
+
|
|
368
|
+
device_out = exec_api(api_type, api_name, args, kwargs)
|
|
369
|
+
device_out = move2device_exec(device_out, "cpu")
|
|
370
|
+
return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
|
|
371
|
+
|
|
372
|
+
|
|
288
373
|
def get_api_info(api_info_dict, api_name, real_data_path):
|
|
289
374
|
convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
|
|
290
375
|
need_grad = True
|
|
@@ -314,45 +399,31 @@ def run_backward(args, grad, grad_index, out):
|
|
|
314
399
|
return grad_out
|
|
315
400
|
|
|
316
401
|
|
|
317
|
-
def initialize_save_error_data():
|
|
318
|
-
error_data_path = msCheckerConfig.error_data_path
|
|
402
|
+
def initialize_save_error_data(error_data_path):
|
|
319
403
|
check_path_before_create(error_data_path)
|
|
320
404
|
create_directory(error_data_path)
|
|
321
|
-
error_data_path_checker = FileChecker(
|
|
405
|
+
error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
|
|
322
406
|
ability=FileCheckConst.WRITE_ABLE)
|
|
323
407
|
error_data_path = error_data_path_checker.common_check()
|
|
324
|
-
initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
|
|
408
|
+
error_data_path =initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
|
|
409
|
+
return error_data_path
|
|
325
410
|
|
|
326
411
|
|
|
327
|
-
def
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
pattern = r"^accuracy_checking_result_\d{14}\.csv$"
|
|
336
|
-
if not re.match(pattern, result_csv_name):
|
|
337
|
-
raise ValueError("When continue run ut, please do not modify the result csv name.")
|
|
338
|
-
return validated_result_csv_path
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
def get_validated_details_csv_path(validated_result_csv_path):
|
|
342
|
-
result_csv_name = os.path.basename(validated_result_csv_path)
|
|
343
|
-
details_csv_name = result_csv_name.replace('result', 'details')
|
|
344
|
-
details_csv_path = os.path.join(os.path.dirname(validated_result_csv_path), details_csv_name)
|
|
345
|
-
details_csv_path_checker = FileChecker(details_csv_path, FileCheckConst.FILE,
|
|
346
|
-
ability=FileCheckConst.READ_WRITE_ABLE, file_type=FileCheckConst.CSV_SUFFIX)
|
|
347
|
-
validated_details_csv_path = details_csv_path_checker.common_check()
|
|
348
|
-
return validated_details_csv_path
|
|
412
|
+
def init_attl(config):
|
|
413
|
+
"""config: OnlineConfig"""
|
|
414
|
+
attl = ATTL('gpu', ATTLConfig(is_benchmark_device=True,
|
|
415
|
+
connect_ip=config.host,
|
|
416
|
+
connect_port=config.port,
|
|
417
|
+
nfs_path=config.nfs_path,
|
|
418
|
+
tls_path=config.tls_path))
|
|
419
|
+
return attl
|
|
349
420
|
|
|
350
421
|
|
|
351
422
|
def _run_ut_parser(parser):
|
|
352
423
|
parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
|
|
353
|
-
help="<
|
|
424
|
+
help="<Optional> The api param tool result file: generate from api param tool, "
|
|
354
425
|
"a json file.",
|
|
355
|
-
required=
|
|
426
|
+
required=False)
|
|
356
427
|
parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
|
|
357
428
|
help="<optional> The ut task result out path.",
|
|
358
429
|
required=False)
|
|
@@ -378,12 +449,10 @@ def _run_ut_parser(parser):
|
|
|
378
449
|
help="<optional> The path of accuracy_checking_result_{timestamp}.csv, "
|
|
379
450
|
"when run ut is interrupted, enter the file path to continue run ut.",
|
|
380
451
|
required=False)
|
|
381
|
-
parser.add_argument("-real_data_path", dest="real_data_path", nargs="?", const="", default="", type=str,
|
|
382
|
-
help="<optional> In real data mode, the root directory for storing real data "
|
|
383
|
-
"must be configured.",
|
|
384
|
-
required=False)
|
|
385
452
|
parser.add_argument("-f", "--filter_api", dest="filter_api", action="store_true",
|
|
386
453
|
help="<optional> Whether to filter the api in the api_info_file.", required=False)
|
|
454
|
+
parser.add_argument("-config", "--config_path", dest="config_path", default="", type=str,
|
|
455
|
+
help="<optional> The path of config.json", required=False)
|
|
387
456
|
|
|
388
457
|
|
|
389
458
|
def preprocess_forward_content(forward_content):
|
|
@@ -397,9 +466,9 @@ def preprocess_forward_content(forward_content):
|
|
|
397
466
|
if key not in arg_cache:
|
|
398
467
|
filtered_new_args = [
|
|
399
468
|
{k: v for k, v in arg.items() if k not in ['Max', 'Min']}
|
|
400
|
-
for arg in value['
|
|
469
|
+
for arg in value['input_args'] if isinstance(arg, dict)
|
|
401
470
|
]
|
|
402
|
-
arg_cache[key] = (filtered_new_args, value['
|
|
471
|
+
arg_cache[key] = (filtered_new_args, value['input_kwargs'])
|
|
403
472
|
|
|
404
473
|
filtered_new_args, new_kwargs = arg_cache[key]
|
|
405
474
|
|
|
@@ -444,50 +513,69 @@ def run_ut_command(args):
|
|
|
444
513
|
except Exception as error:
|
|
445
514
|
logger.error(f"Set device id failed. device id is: {args.device_id}")
|
|
446
515
|
raise NotImplementedError from error
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
516
|
+
|
|
517
|
+
# 在线预检场景下,不需要外出输出api信息,forward_content, backward_content, real_data_path设置为None
|
|
518
|
+
# 离线场景下,forward_content, backward_content, real_data_path从api_info_file中解析
|
|
519
|
+
forward_content, backward_content, real_data_path = None, None, None
|
|
520
|
+
if args.api_info_file:
|
|
521
|
+
api_info_file_checker = FileChecker(file_path = args.api_info_file, path_type = FileCheckConst.FILE,
|
|
522
|
+
ability = FileCheckConst.READ_ABLE, file_type = FileCheckConst.JSON_SUFFIX)
|
|
523
|
+
checked_api_info = api_info_file_checker.common_check()
|
|
524
|
+
forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info)
|
|
525
|
+
if args.filter_api:
|
|
526
|
+
logger.info("Start filtering the api in the forward_input_file.")
|
|
527
|
+
forward_content = preprocess_forward_content(forward_content)
|
|
528
|
+
logger.info("Finish filtering the api in the forward_input_file.")
|
|
529
|
+
|
|
450
530
|
out_path = os.path.realpath(args.out_path) if args.out_path else "./"
|
|
451
531
|
check_path_before_create(out_path)
|
|
452
532
|
create_directory(out_path)
|
|
453
533
|
out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
|
|
454
534
|
out_path = out_path_checker.common_check()
|
|
455
535
|
save_error_data = args.save_error_data
|
|
456
|
-
forward_content, backward_content, real_data_path = parse_json_info_forward_backward(api_info)
|
|
457
|
-
if args.filter_api:
|
|
458
|
-
logger.info("Start filtering the api in the forward_input_file.")
|
|
459
|
-
forward_content = preprocess_forward_content(forward_content)
|
|
460
|
-
logger.info("Finish filtering the api in the forward_input_file.")
|
|
461
536
|
|
|
462
537
|
result_csv_path = os.path.join(out_path, RESULT_FILE_NAME)
|
|
463
538
|
details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME)
|
|
464
539
|
if args.result_csv_path:
|
|
465
540
|
result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
|
|
466
541
|
details_csv_path = get_validated_details_csv_path(result_csv_path)
|
|
542
|
+
white_list = msCheckerConfig.white_list
|
|
543
|
+
black_list = msCheckerConfig.black_list
|
|
544
|
+
error_data_path = msCheckerConfig.error_data_path
|
|
545
|
+
is_online = msCheckerConfig.is_online
|
|
546
|
+
nfs_path = msCheckerConfig.nfs_path
|
|
547
|
+
host = msCheckerConfig.host
|
|
548
|
+
port = msCheckerConfig.port
|
|
549
|
+
rank_list = msCheckerConfig.rank_list
|
|
550
|
+
tls_path = msCheckerConfig.tls_path
|
|
551
|
+
if args.config_path:
|
|
552
|
+
config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
|
|
553
|
+
FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
|
|
554
|
+
checked_config_path = config_path_checker.common_check()
|
|
555
|
+
_, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
|
|
556
|
+
white_list = task_config.white_list
|
|
557
|
+
black_list = task_config.black_list
|
|
558
|
+
error_data_path = task_config.error_data_path
|
|
559
|
+
is_online = task_config.is_online
|
|
560
|
+
nfs_path = task_config.nfs_path
|
|
561
|
+
host = task_config.host
|
|
562
|
+
port = task_config.port
|
|
563
|
+
rank_list = task_config.rank_list
|
|
564
|
+
tls_path = task_config.tls_path
|
|
565
|
+
|
|
467
566
|
if save_error_data:
|
|
468
567
|
if args.result_csv_path:
|
|
469
568
|
time_info = result_csv_path.split('.')[0].split('_')[-1]
|
|
470
569
|
global UT_ERROR_DATA_DIR
|
|
471
570
|
UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
|
|
472
|
-
initialize_save_error_data()
|
|
571
|
+
error_data_path = initialize_save_error_data(error_data_path)
|
|
572
|
+
online_config = OnlineConfig(is_online, nfs_path, host, port, rank_list, tls_path)
|
|
473
573
|
run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data,
|
|
474
|
-
args.result_csv_path, real_data_path)
|
|
574
|
+
args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path,
|
|
575
|
+
online_config)
|
|
475
576
|
run_ut(run_ut_config)
|
|
476
577
|
|
|
477
578
|
|
|
478
|
-
class UtDataInfo:
|
|
479
|
-
def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list,
|
|
480
|
-
backward_message, rank=0):
|
|
481
|
-
self.bench_grad = bench_grad
|
|
482
|
-
self.device_grad = device_grad
|
|
483
|
-
self.device_output = device_output
|
|
484
|
-
self.bench_output = bench_output
|
|
485
|
-
self.grad_in = grad_in
|
|
486
|
-
self.in_fwd_data_list = in_fwd_data_list
|
|
487
|
-
self.backward_message = backward_message
|
|
488
|
-
self.rank = rank
|
|
489
|
-
|
|
490
|
-
|
|
491
579
|
if __name__ == '__main__':
|
|
492
580
|
_run_ut()
|
|
493
581
|
logger.info("UT task completed.")
|
|
@@ -1,7 +1,74 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
|
|
4
|
+
from msprobe.core.common.const import FileCheckConst
|
|
5
|
+
from msprobe.core.common.file_check import FileChecker
|
|
6
|
+
from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
|
|
7
|
+
from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
|
|
8
|
+
from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
|
|
9
|
+
from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
|
|
10
|
+
from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
|
|
11
|
+
|
|
1
12
|
hf_32_standard_api = ["conv1d", "conv2d"]
|
|
2
13
|
|
|
3
14
|
|
|
4
15
|
class Backward_Message:
|
|
5
16
|
MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
|
|
6
17
|
UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, skip backward."
|
|
7
|
-
NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
|
|
18
|
+
NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class UtDataInfo:
|
|
22
|
+
def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list,
|
|
23
|
+
backward_message, rank=0):
|
|
24
|
+
self.bench_grad = bench_grad
|
|
25
|
+
self.device_grad = device_grad
|
|
26
|
+
self.device_output = device_output
|
|
27
|
+
self.bench_output = bench_output
|
|
28
|
+
self.grad_in = grad_in
|
|
29
|
+
self.in_fwd_data_list = in_fwd_data_list
|
|
30
|
+
self.backward_message = backward_message
|
|
31
|
+
self.rank = rank
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_validated_result_csv_path(result_csv_path, mode):
|
|
35
|
+
if mode not in ['result', 'detail']:
|
|
36
|
+
raise ValueError("The csv mode must be result or detail")
|
|
37
|
+
result_csv_path_checker = FileChecker(result_csv_path, FileCheckConst.FILE, ability=FileCheckConst.READ_WRITE_ABLE,
|
|
38
|
+
file_type=FileCheckConst.CSV_SUFFIX)
|
|
39
|
+
validated_result_csv_path = result_csv_path_checker.common_check()
|
|
40
|
+
if mode == 'result':
|
|
41
|
+
result_csv_name = os.path.basename(validated_result_csv_path)
|
|
42
|
+
pattern = r"^accuracy_checking_result_\d{14}\.csv$"
|
|
43
|
+
if not re.match(pattern, result_csv_name):
|
|
44
|
+
raise ValueError("When continue run ut, please do not modify the result csv name.")
|
|
45
|
+
return validated_result_csv_path
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_validated_details_csv_path(validated_result_csv_path):
|
|
49
|
+
result_csv_name = os.path.basename(validated_result_csv_path)
|
|
50
|
+
details_csv_name = result_csv_name.replace('result', 'details')
|
|
51
|
+
details_csv_path = os.path.join(os.path.dirname(validated_result_csv_path), details_csv_name)
|
|
52
|
+
details_csv_path_checker = FileChecker(details_csv_path, FileCheckConst.FILE,
|
|
53
|
+
ability=FileCheckConst.READ_WRITE_ABLE, file_type=FileCheckConst.CSV_SUFFIX)
|
|
54
|
+
validated_details_csv_path = details_csv_path_checker.common_check()
|
|
55
|
+
return validated_details_csv_path
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def exec_api(api_type, api_name, args, kwargs):
|
|
59
|
+
if api_type == "Functional":
|
|
60
|
+
functional_api = FunctionalOPTemplate(api_name, str, False)
|
|
61
|
+
out = functional_api.forward(*args, **kwargs)
|
|
62
|
+
if api_type == "Tensor":
|
|
63
|
+
tensor_api = TensorOPTemplate(api_name, str, False)
|
|
64
|
+
out = tensor_api.forward(*args, **kwargs)
|
|
65
|
+
if api_type == "Torch":
|
|
66
|
+
torch_api = TorchOPTemplate(api_name, str, False)
|
|
67
|
+
out = torch_api.forward(*args, **kwargs)
|
|
68
|
+
if api_type == "Aten":
|
|
69
|
+
torch_api = AtenOPTemplate(api_name, None, False)
|
|
70
|
+
out = torch_api.forward(*args, **kwargs)
|
|
71
|
+
if api_type == "NPU":
|
|
72
|
+
torch_api = NpuOPTemplate(api_name, None, False)
|
|
73
|
+
out = torch_api.forward(*args, **kwargs)
|
|
74
|
+
return out
|
|
File without changes
|