mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
|
@@ -2,13 +2,11 @@ import time
|
|
|
2
2
|
import os
|
|
3
3
|
import math
|
|
4
4
|
|
|
5
|
-
import numpy as np
|
|
6
5
|
import torch
|
|
7
|
-
|
|
8
|
-
from msprobe.core.common.utils import CompareException
|
|
6
|
+
|
|
7
|
+
from msprobe.core.common.utils import CompareException, load_yaml
|
|
9
8
|
from msprobe.core.common.const import Const
|
|
10
9
|
from msprobe.pytorch.common.log import logger
|
|
11
|
-
from msprobe.core.common.file_check import FileOpen
|
|
12
10
|
|
|
13
11
|
|
|
14
12
|
current_time = time.strftime("%Y%m%d%H%M%S")
|
|
@@ -22,17 +20,15 @@ BINARY_COMPARE_UNSUPPORT_LIST = BENCHMARK_COMPARE_SUPPORT_LIST + API_PRECISION_C
|
|
|
22
20
|
|
|
23
21
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
24
22
|
standard_yaml_path = os.path.join(cur_path, "api_precision_standard.yaml")
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
ThousandthStandardApi = Apis.get('ThousandthStandard')
|
|
23
|
+
apis = load_yaml(standard_yaml_path)
|
|
24
|
+
absolute_standard_api = apis.get('AbsoluteThreshStandard')
|
|
25
|
+
binary_standard_api = apis.get('BinaryCompareStandard')
|
|
26
|
+
ulp_standard_api = apis.get('ULPStandard')
|
|
27
|
+
thousandth_standard_api = apis.get('ThousandthStandard')
|
|
31
28
|
|
|
32
29
|
|
|
33
30
|
threshold_yaml_path = os.path.join(cur_path, "api_precision_threshold.yaml")
|
|
34
|
-
|
|
35
|
-
apis_threshold = yaml.safe_load(f)
|
|
31
|
+
apis_threshold = load_yaml(threshold_yaml_path)
|
|
36
32
|
|
|
37
33
|
|
|
38
34
|
DETAIL_TEST_ROWS = [[
|
|
@@ -21,10 +21,11 @@ import torch
|
|
|
21
21
|
import numpy
|
|
22
22
|
|
|
23
23
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
|
|
24
|
-
from msprobe.pytorch.api_accuracy_checker.common.utils import
|
|
25
|
-
|
|
24
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \
|
|
25
|
+
CompareException
|
|
26
|
+
from msprobe.core.common.file_check import FileChecker
|
|
26
27
|
from msprobe.pytorch.common.log import logger
|
|
27
|
-
from msprobe.core.common.const import Const
|
|
28
|
+
from msprobe.core.common.const import Const, FileCheckConst
|
|
28
29
|
|
|
29
30
|
TORCH_TYPE = ["torch.device", "torch.dtype"]
|
|
30
31
|
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
|
|
@@ -87,12 +88,13 @@ def gen_real_tensor(data_path, convert_type):
|
|
|
87
88
|
convert_type: convert ori_type to dist_type flag.
|
|
88
89
|
"""
|
|
89
90
|
data_path = os.path.realpath(data_path)
|
|
90
|
-
|
|
91
|
+
data_path_checker = FileChecker(data_path, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE)
|
|
92
|
+
data_path = data_path_checker.common_check()
|
|
91
93
|
if not data_path.endswith('.pt') and not data_path.endswith('.npy'):
|
|
92
94
|
error_info = f"The file: {data_path} is not a pt or numpy file."
|
|
93
95
|
raise CompareException(CompareException.INVALID_FILE_ERROR, error_info)
|
|
94
96
|
if data_path.endswith('.pt'):
|
|
95
|
-
data = torch.load(data_path
|
|
97
|
+
data = torch.load(data_path, map_location=torch.device('cpu'))
|
|
96
98
|
else:
|
|
97
99
|
data_np = numpy.load(data_path)
|
|
98
100
|
data = torch.from_numpy(data_np)
|
|
@@ -255,12 +257,13 @@ def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_p
|
|
|
255
257
|
return args_result
|
|
256
258
|
|
|
257
259
|
|
|
258
|
-
def gen_kwargs(api_info, convert_type=None, real_data_path=None):
|
|
260
|
+
def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
|
|
259
261
|
"""
|
|
260
262
|
Function Description:
|
|
261
263
|
Based on API basic information, generate input parameters: kwargs, for API forward running
|
|
262
264
|
Parameter:
|
|
263
265
|
api_info: API basic information. Dict
|
|
266
|
+
api_name: API name
|
|
264
267
|
convert_type: convert ori_type to dist_type flag.
|
|
265
268
|
real_data_path: the root directory for storing real data.
|
|
266
269
|
"""
|
|
@@ -268,11 +271,11 @@ def gen_kwargs(api_info, convert_type=None, real_data_path=None):
|
|
|
268
271
|
kwargs_params = api_info.get("input_kwargs")
|
|
269
272
|
for key, value in kwargs_params.items():
|
|
270
273
|
if isinstance(value, (list, tuple)):
|
|
271
|
-
kwargs_params[key] = gen_list_kwargs(value, convert_type, real_data_path)
|
|
274
|
+
kwargs_params[key] = gen_list_kwargs(value, api_name, convert_type, real_data_path)
|
|
272
275
|
elif value is None:
|
|
273
276
|
kwargs_params[key] = None
|
|
274
277
|
elif value.get('type') in TENSOR_DATA_LIST or value.get('type').startswith("numpy"):
|
|
275
|
-
kwargs_params[key] = gen_data(value, True, convert_type, real_data_path)
|
|
278
|
+
kwargs_params[key] = gen_data(value, api_name, True, convert_type, real_data_path)
|
|
276
279
|
elif value.get('type') in TORCH_TYPE:
|
|
277
280
|
gen_torch_kwargs(kwargs_params, key, value)
|
|
278
281
|
else:
|
|
@@ -285,18 +288,19 @@ def gen_torch_kwargs(kwargs_params, key, value):
|
|
|
285
288
|
kwargs_params[key] = eval(value.get('value'))
|
|
286
289
|
|
|
287
290
|
|
|
288
|
-
def gen_list_kwargs(kwargs_item_value, convert_type, real_data_path=None):
|
|
291
|
+
def gen_list_kwargs(kwargs_item_value, api_name, convert_type, real_data_path=None):
|
|
289
292
|
"""
|
|
290
293
|
Function Description:
|
|
291
294
|
When kwargs value is list, generate the list of kwargs result
|
|
292
295
|
Parameter:
|
|
293
296
|
kwargs_item_value: kwargs value before to generate. List
|
|
297
|
+
api_name: API name
|
|
294
298
|
convert_type: convert ori_type to dist_type flag.
|
|
295
299
|
"""
|
|
296
300
|
kwargs_item_result = []
|
|
297
301
|
for item in kwargs_item_value:
|
|
298
302
|
if item.get('type') in TENSOR_DATA_LIST:
|
|
299
|
-
item_value = gen_data(item, False, convert_type, real_data_path)
|
|
303
|
+
item_value = gen_data(item, api_name, False, convert_type, real_data_path)
|
|
300
304
|
elif item.get('type') == "torch.Size":
|
|
301
305
|
item_value = torch.Size(item.get('value'))
|
|
302
306
|
else:
|
|
@@ -319,7 +323,7 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d
|
|
|
319
323
|
if convert_type and convert_type not in Const.CONVERT:
|
|
320
324
|
error_info = f"convert_type params not support {convert_type}."
|
|
321
325
|
raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
|
|
322
|
-
kwargs_params = gen_kwargs(api_info, convert_type, real_data_path)
|
|
326
|
+
kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path)
|
|
323
327
|
if api_info.get("input_args"):
|
|
324
328
|
args_params = gen_args(api_info.get("input_args"), api_name, need_grad, convert_type, real_data_path)
|
|
325
329
|
else:
|
|
@@ -9,8 +9,9 @@ import threading
|
|
|
9
9
|
from collections import namedtuple
|
|
10
10
|
from itertools import cycle
|
|
11
11
|
from tqdm import tqdm
|
|
12
|
-
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser,
|
|
13
|
-
|
|
12
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, preprocess_forward_content
|
|
13
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path, \
|
|
14
|
+
get_validated_details_csv_path
|
|
14
15
|
from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
|
|
15
16
|
from msprobe.pytorch.common import parse_json_info_forward_backward
|
|
16
17
|
from msprobe.core.common.file_check import FileChecker, check_file_suffix, check_link, FileOpen, \
|
|
@@ -68,7 +69,7 @@ signal.signal(signal.SIGTERM, signal_handler)
|
|
|
68
69
|
|
|
69
70
|
ParallelUTConfig = namedtuple('ParallelUTConfig', ['api_files', 'out_path', 'num_splits',
|
|
70
71
|
'save_error_data_flag', 'jit_compile_flag', 'device_id',
|
|
71
|
-
'result_csv_path', 'total_items', '
|
|
72
|
+
'result_csv_path', 'total_items', 'config_path'])
|
|
72
73
|
|
|
73
74
|
|
|
74
75
|
def run_parallel_ut(config):
|
|
@@ -90,7 +91,7 @@ def run_parallel_ut(config):
|
|
|
90
91
|
*(['-j'] if config.jit_compile_flag else []),
|
|
91
92
|
*(['-save_error_data'] if config.save_error_data_flag else []),
|
|
92
93
|
'-csv_path', config.result_csv_path,
|
|
93
|
-
*(['-
|
|
94
|
+
*(['-config', config.config_path] if config.config_path else [])
|
|
94
95
|
]
|
|
95
96
|
return cmd
|
|
96
97
|
|
|
@@ -110,19 +111,14 @@ def run_parallel_ut(config):
|
|
|
110
111
|
|
|
111
112
|
def update_progress_bar(progress_bar, result_csv_path):
|
|
112
113
|
while any(process.poll() is None for process in processes):
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
progress_bar.update(completed_items - progress_bar.n)
|
|
117
|
-
except FileNotFoundError:
|
|
118
|
-
logger.warning(f"Result CSV file not found: {result_csv_path}.")
|
|
119
|
-
except Exception as e:
|
|
120
|
-
logger.error(f"An unexpected error occurred while reading result CSV: {e}")
|
|
114
|
+
with FileOpen(result_csv_path, 'r') as result_file:
|
|
115
|
+
completed_items = len(result_file.readlines()) - 1
|
|
116
|
+
progress_bar.update(completed_items - progress_bar.n)
|
|
121
117
|
time.sleep(1)
|
|
122
118
|
|
|
123
119
|
for api_info in config.api_files:
|
|
124
120
|
cmd = create_cmd(api_info, next(device_id_cycle))
|
|
125
|
-
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, bufsize=1)
|
|
121
|
+
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, bufsize=1, shell=False)
|
|
126
122
|
processes.append(process)
|
|
127
123
|
threading.Thread(target=read_process_output, args=(process,), daemon=True).start()
|
|
128
124
|
|
|
@@ -175,7 +171,7 @@ def prepare_config(args):
|
|
|
175
171
|
out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
|
|
176
172
|
out_path = out_path_checker.common_check()
|
|
177
173
|
split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
|
|
178
|
-
|
|
174
|
+
config_path = os.path.realpath(args.config_path) if args.config_path else None
|
|
179
175
|
result_csv_path = args.result_csv_path or os.path.join(out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
|
|
180
176
|
if not args.result_csv_path:
|
|
181
177
|
details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv")
|
|
@@ -187,7 +183,7 @@ def prepare_config(args):
|
|
|
187
183
|
logger.info(f"UT task details will be saved in {details_csv_path}")
|
|
188
184
|
return ParallelUTConfig(split_files, out_path, args.num_splits, args.save_error_data,
|
|
189
185
|
args.jit_compile, args.device_id, result_csv_path,
|
|
190
|
-
total_items,
|
|
186
|
+
total_items, config_path)
|
|
191
187
|
|
|
192
188
|
|
|
193
189
|
def main():
|
|
@@ -10,10 +10,14 @@ else:
|
|
|
10
10
|
is_gpu = False
|
|
11
11
|
import torch
|
|
12
12
|
from tqdm import tqdm
|
|
13
|
-
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import
|
|
14
|
-
from msprobe.pytorch.api_accuracy_checker.
|
|
13
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import generate_device_params, get_api_info
|
|
14
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api
|
|
15
|
+
from msprobe.core.common.utils import get_json_contents
|
|
15
16
|
from msprobe.core.common.file_check import check_link
|
|
16
17
|
from msprobe.pytorch.common.log import logger
|
|
18
|
+
from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
|
|
19
|
+
from msprobe.core.common.const import Const
|
|
20
|
+
|
|
17
21
|
|
|
18
22
|
def check_tensor_overflow(x):
|
|
19
23
|
if isinstance(x, torch.Tensor) and x.numel() != 0 and x.dtype != torch.bool:
|
|
@@ -52,12 +56,12 @@ def check_data_overflow(x):
|
|
|
52
56
|
|
|
53
57
|
def run_overflow_check(forward_file):
|
|
54
58
|
logger.info("start UT test")
|
|
55
|
-
forward_content =
|
|
59
|
+
forward_content, _, real_data_path = parse_json_info_forward_backward(forward_file)
|
|
56
60
|
for api_full_name, api_info_dict in tqdm(forward_content.items()):
|
|
57
61
|
try:
|
|
58
|
-
run_torch_api(api_full_name, api_info_dict)
|
|
62
|
+
run_torch_api(api_full_name, api_info_dict, real_data_path)
|
|
59
63
|
except Exception as err:
|
|
60
|
-
api_name = api_full_name.split(
|
|
64
|
+
_, api_name, _ = api_full_name.split(Const.SEP)
|
|
61
65
|
if "not implemented for 'Half'" in str(err):
|
|
62
66
|
logger.warning(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API "
|
|
63
67
|
f"'fp16_to_fp32' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
|
|
@@ -68,11 +72,10 @@ def run_overflow_check(forward_file):
|
|
|
68
72
|
logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
|
|
69
73
|
|
|
70
74
|
|
|
71
|
-
def run_torch_api(api_full_name, api_info_dict):
|
|
75
|
+
def run_torch_api(api_full_name, api_info_dict, real_data_path):
|
|
72
76
|
torch.npu.clear_npu_overflow_flag()
|
|
73
|
-
api_type = api_full_name.split(
|
|
74
|
-
|
|
75
|
-
args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path='')
|
|
77
|
+
api_type, api_name, _ = api_full_name.split(Const.SEP)
|
|
78
|
+
args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
|
|
76
79
|
if not need_grad:
|
|
77
80
|
logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
|
|
78
81
|
% api_full_name)
|
|
@@ -81,6 +84,10 @@ def run_torch_api(api_full_name, api_info_dict):
|
|
|
81
84
|
del kwargs["device"]
|
|
82
85
|
out = exec_api(api_type, api_name, args, kwargs)
|
|
83
86
|
npu_out = exec_api(api_type, api_name, npu_args, npu_kwargs)
|
|
87
|
+
if out is None and npu_out is None:
|
|
88
|
+
logger.warning("The %s overflow is a normal overflow, out and npu_out is None." % api_full_name)
|
|
89
|
+
return
|
|
90
|
+
|
|
84
91
|
cpu_overflow = check_data_overflow(out)
|
|
85
92
|
npu_overflow = torch_npu.npu.utils.npu_check_overflow(npu_out)
|
|
86
93
|
if cpu_overflow == npu_overflow:
|