mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py

@@ -23,16 +23,19 @@ try:
     import torch_npu
 except ImportError:
     is_gpu = True
+    current_device = "cuda"
 else:
     is_gpu = False
+    current_device = "npu"
 import torch
 from tqdm import tqdm
 from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import generate_device_params, get_api_info
-from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api
-from msprobe.core.common.file_utils import check_link
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, is_unsupported_api, ExecParams
+from msprobe.core.common.file_utils import check_link, FileChecker
+from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
+from msprobe.core.common.const import FileCheckConst, Const
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
-from msprobe.core.common.const import Const


 def check_tensor_overflow(x):
@@ -60,52 +63,80 @@ def check_tensor_overflow(x):
     return False


-def check_data_overflow(x):
-    if isinstance(x, (tuple, list))
-
-
-
-        return False
+def check_data_overflow(x, device):
+    if isinstance(x, (tuple, list)):
+        if not x:
+            return False
+        return any(check_data_overflow(item, device) for item in x)
     else:
-
+        if device == Const.CPU_LOWERCASE:
+            return check_tensor_overflow(x)
+        else:
+            return torch_npu.npu.utils.npu_check_overflow(x)
+
+
+def is_bool_output(x):
+    if isinstance(x, (tuple, list)):
+        if not x:
+            return False
+        return any(is_bool_output(item) for item in x)
+    else:
+        return isinstance(x, bool)


 def run_overflow_check(forward_file):
     logger.info("start UT test")
     forward_content, _, real_data_path = parse_json_info_forward_backward(forward_file)
+    if real_data_path:
+        dump_path = os.path.dirname(forward_file)
+        real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
     for api_full_name, api_info_dict in tqdm(forward_content.items()):
+        if is_unsupported_api(api_full_name, is_overflow_check=True):
+            continue
         try:
             run_torch_api(api_full_name, api_info_dict, real_data_path)
         except Exception as err:
             _, api_name, _ = api_full_name.split(Const.SEP)
             if "not implemented for 'Half'" in str(err):
-                logger.warning(f"API {api_name} not support half tensor in CPU
-
+                logger.warning(f"API {api_name} not support half tensor in CPU. This API does not support overflow "
+                               "check, so it will be skipped.")
             elif "expected scalar type Long" in str(err):
                 logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
-
+                               "'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
+            elif "could not create a primitive descriptor for a matmul primitive" in str(err):
+                logger.warning(f"API {api_name} not support matmul primitive in CPU due to pytorch bug, "
+                               "so it will be skipped.")
             else:
                 logger.error(f"Run {api_full_name} UT Error: %s" % str(err))


 def run_torch_api(api_full_name, api_info_dict, real_data_path):
     torch.npu.clear_npu_overflow_flag()
-    api_type, api_name
+    api_type, api_name = extract_basic_api_segments(api_full_name)
     args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
     if not need_grad:
         logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
                        % api_full_name)
+    device_info_kwargs = kwargs.get(Const.DEVICE)
+    if device_info_kwargs and device_info_kwargs.get(Const.VALUE):
+        kwargs[Const.DEVICE] = current_device
     npu_args, npu_kwargs = generate_device_params(args, kwargs, False, api_name)
-    if kwargs.get(
-    del kwargs[
-
-
+    if kwargs.get(Const.DEVICE):
+        del kwargs[Const.DEVICE]
+    cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, args, kwargs, False, None)
+    device_exec_params = ExecParams(api_type, api_name, Const.NPU_LOWERCASE, npu_args, npu_kwargs, False, None)
+    out = exec_api(cpu_exec_params)
+    npu_out = exec_api(device_exec_params)
     if out is None and npu_out is None:
         logger.warning("The %s overflow is a normal overflow, out and npu_out is None." % api_full_name)
         return
+    if is_bool_output(out) or is_bool_output(npu_out):
+        logger.warning("The output of %s is bool type.This dtype not support overflow, so it will be skipped."
+                       % api_full_name)
+        return

-    cpu_overflow = check_data_overflow(out)
-    npu_overflow =
+    cpu_overflow = check_data_overflow(out, Const.CPU_LOWERCASE)
+    npu_overflow = check_data_overflow(npu_out, Const.NPU_LOWERCASE)
     if cpu_overflow == npu_overflow:
         logger.warning("The %s overflow is a normal overflow." % api_full_name)
     else:
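As a quick aside on the hunk above: `check_data_overflow` now recurses into nested tuples/lists and dispatches by device, and `is_bool_output` screens boolean results out of the overflow comparison. Below is a minimal, self-contained sketch of that recursion; `nonfinite` and the literal `"cpu"` are illustrative stand-ins for the package's `check_tensor_overflow` and `Const.CPU_LOWERCASE`, not its actual API.

```python
import torch

def nonfinite(t):
    # stand-in for msprobe's check_tensor_overflow: flag inf/nan values on CPU
    return isinstance(t, torch.Tensor) and not torch.isfinite(t).all().item()

def check_data_overflow(x, device="cpu"):
    # walk nested tuples/lists, mirroring the 1.2.1 implementation above
    if isinstance(x, (tuple, list)):
        if not x:
            return False
        return any(check_data_overflow(item, device) for item in x)
    # the real code calls torch_npu.npu.utils.npu_check_overflow(x) for NPU outputs
    return nonfinite(x) if device == "cpu" else False

outputs = (torch.tensor([1.0, 2.0]), [torch.tensor([float("inf")])])
print(check_data_overflow(outputs))  # True: the nested tensor contains inf
```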
@@ -135,8 +166,9 @@ def _run_overflow_check(parser=None):
 def _run_overflow_check_command(args):
     torch.npu.set_compile_mode(jit_compile=args.jit_compile)
     npu_device = "npu:" + str(args.device_id)
-
-
+    api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
+                                        ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
+    api_info = api_info_file_checker.common_check()
     try:
         torch.npu.set_device(npu_device)
     except Exception as error:
msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py

@@ -17,7 +17,7 @@

 import argparse
 import os
-import
+import re
 import sys
 import time
 import gc
@@ -31,39 +31,40 @@ except ImportError:
 else:
     is_gpu = False
     current_device = "npu"
+
 import torch
 from tqdm import tqdm

 from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import BackwardMessage, UtDataInfo, \
-    get_validated_result_csv_path, get_validated_details_csv_path, exec_api, record_skip_info
+    get_validated_result_csv_path, get_validated_details_csv_path, exec_api, record_skip_info, is_unsupported_api
 from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
 from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
     initialize_save_path, UtDataProcessor, extract_basic_api_segments, ApiData
 from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
 from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
-from msprobe.pytorch.api_accuracy_checker.common.config import
+from msprobe.pytorch.api_accuracy_checker.common.config import CheckerConfig
 from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
-from msprobe.core.common.file_utils import FileChecker, change_mode,
-    create_directory, get_json_contents, read_csv
+from msprobe.core.common.file_utils import FileChecker, change_mode, \
+    create_directory, get_json_contents, read_csv, check_file_or_directory_path, check_crt_valid
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.pt_config import parse_json_config
 from msprobe.core.common.const import Const, FileCheckConst, CompareConst
+from msprobe.core.common.utils import safe_get_value, CompareException
+from msprobe.pytorch.common.utils import seed_all
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
-from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params, \
+    ExecParams


 current_time = time.strftime("%Y%m%d%H%M%S")
 UT_ERROR_DATA_DIR = 'ut_error_data' + current_time
 RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
 DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
-RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
-                                         'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
-                                         'black_list', 'error_data_path', 'online_config'])

-OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])

 not_backward_list = ['repeat_interleave']
+unsupported_backward_list = ['masked_select']


 tqdm_params = {
@@ -99,7 +100,11 @@ def run_ut(config):
         run_api_online(config, compare)
     else:
         csv_df = read_csv(config.result_csv_path)
-
+        try:
+            api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
+        except IndexError:
+            logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
+            api_name_set = set()
         run_api_offline(config, compare, api_name_set)
     for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
         change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -140,7 +145,7 @@ def run_api_offline(config, compare, api_name_set):
         except Exception as err:
             if "expected scalar type Long" in str(err):
                 logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
-
+                               "'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
             else:
                 logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
             compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
@@ -220,14 +225,6 @@ def blacklist_and_whitelist_filter(api_name, black_list, white_list):
     return False


-def is_unsupported_api(api_name):
-    split_name = api_name.split(Const.SEP)[0]
-    flag = split_name == Const.DISTRIBUTED
-    if flag:
-        logger.info(f"{split_name} api is not supported for run ut. SKIP.")
-    return flag
-
-
 def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success, is_bwd_success):
     if not is_fwd_success or not is_bwd_success:
         processor = UtDataProcessor(error_data_path)
@@ -244,7 +241,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
     in_fwd_data_list = []
     backward_message = ''
     api_type, api_name = extract_basic_api_segments(api_full_name)
-    args, kwargs,
+    args, kwargs, output_dtype = get_api_info(api_info_dict, api_name, real_data_path)
+    need_grad = check_need_grad(api_info_dict)
     in_fwd_data_list.append(args)
     in_fwd_data_list.append(kwargs)
     need_backward = api_full_name in backward_content
@@ -253,16 +251,32 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
         backward_message += BackwardMessage.UNSUPPORT_BACKWARD_MESSAGE
     if api_name in not_backward_list:
         need_grad = False
-        logger.
+        logger.info("%s %s" % (api_full_name, BackwardMessage.NO_BACKWARD_RESULT_MESSAGE))
         backward_message += BackwardMessage.NO_BACKWARD_RESULT_MESSAGE
+    if api_name in unsupported_backward_list:
+        need_grad = False
+        logger.info("%s %s" % (api_full_name, BackwardMessage.UNSUPPORT_API_MESSAGE))
+        backward_message += BackwardMessage.UNSUPPORT_API_MESSAGE
     need_backward = need_backward and need_grad
-
-
-
+
+    device_info_kwargs = kwargs.get(Const.DEVICE)
+    if device_info_kwargs and device_info_kwargs.get(Const.VALUE):
+        kwargs[Const.DEVICE] = current_device
     device_args, device_kwargs = generate_device_params(args, kwargs, need_backward, api_name)
+    if kwargs.get(Const.DEVICE):
+        del kwargs[Const.DEVICE]
+    cpu_params = generate_cpu_params(args, kwargs, need_backward, api_name)
+    cpu_args, cpu_kwargs = cpu_params.cpu_args, cpu_params.cpu_kwargs
+    autocast_dtype, is_autocast = cpu_params.autocast_dtype, cpu_params.is_autocast
+    if not is_autocast and output_dtype:
+        is_autocast = autocast_dtype != output_dtype
+        autocast_dtype = output_dtype
     bench_grad_out, device_grad_out = None, None
-
-
+    cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs, False, autocast_dtype)
+    out = exec_api(cpu_exec_params)
+    device_exec_params = ExecParams(api_type, api_name, current_device, device_args, device_kwargs, is_autocast,
+                                    autocast_dtype)
+    device_out = exec_api(device_exec_params)
     current_path = os.path.dirname(os.path.realpath(__file__))
     ut_setting_path = os.path.join(current_path, "torch_ut_setting.json")
     api_setting_dict = get_json_contents(ut_setting_path)
@@ -278,16 +292,18 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
             func_options = {
                 'real_data_path': real_data_path
            }
-            grad = gen_args(backward_args, api_name, func_options)
-
+            grad = gen_args(backward_args, api_name, func_options)
+            grad = safe_get_value(grad, 0, "grad")
+            grad_params = generate_cpu_params(grad, {}, False, api_name)
+            bench_grad = grad_params.cpu_args
             bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
             device_grad = grad.clone().detach().to(current_device)
             device_grad_out = run_backward(device_args, device_grad, grad_index, device_out)
         else:
             backward_message += BackwardMessage.MULTIPLE_BACKWARD_MESSAGE
     if api_name == "npu_fusion_attention":
-        out = out
-        device_out = device_out
+        out = safe_get_value(out, 0, "out")
+        device_out = safe_get_value(device_out, 0, "device_out")

     return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)

@@ -306,13 +322,18 @@ def run_torch_api_online(api_full_name, api_data, backward_content):
     return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)


-def
-    convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
+def check_need_grad(api_info_dict):
     need_grad = True
-    if api_info_dict.get(
+    if api_info_dict.get(Const.INPUT_KWARGS) and "out" in api_info_dict.get(Const.INPUT_KWARGS):
         need_grad = False
-
-
+    return need_grad
+
+
+def get_api_info(api_info_dict, api_name, real_data_path):
+    convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
+    need_grad = check_need_grad(api_info_dict)
+    args, kwargs, output_dtype = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path)
+    return args, kwargs, output_dtype


 def need_to_backward(grad_index, out):
@@ -323,20 +344,32 @@ def need_to_backward(grad_index, out):

 def run_backward(args, grad, grad_index, out):
     if grad_index is not None:
+        if grad_index >= len(out):
+            logger.error(f"Run backward error when grad_index is {grad_index}")
+            raise IndexError(f"Run backward error when grad_index is {grad_index}")
         out[grad_index].backward(grad)
     else:
         out.backward(grad)
-
-
-        if isinstance(arg, torch.Tensor):
-            args_grad.append(arg.grad)
-    grad_out = args_grad
+
+    grad_out = extract_tensors_grad(args)

     return grad_out


+def extract_tensors_grad(args, depth=0):
+    if depth > Const.MAX_DEPTH:
+        logger.error("The depth of arg_in is too large, please check the arg_in.")
+        raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
+    grads = []
+    for arg in args:
+        if isinstance(arg, torch.Tensor):
+            grads.append(arg.grad)
+        elif isinstance(arg, (list, tuple)):
+            grads.extend(extract_tensors_grad(arg, depth+1))
+    return grads
+
+
 def initialize_save_error_data(error_data_path):
-    check_path_before_create(error_data_path)
     create_directory(error_data_path)
     error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
                                           ability=FileCheckConst.WRITE_ABLE)
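The new `extract_tensors_grad` replaces the old flat loop over `args`: it gathers `.grad` from tensors nested arbitrarily deep while capping recursion depth. A standalone sketch of the same idea follows; `MAX_DEPTH = 10` and the plain `RecursionError` are assumptions standing in for `Const.MAX_DEPTH` and `CompareException`.

```python
import torch

MAX_DEPTH = 10  # illustrative cap; msprobe uses Const.MAX_DEPTH

def extract_tensors_grad(args, depth=0):
    # recursively collect gradients from tensors inside nested lists/tuples
    if depth > MAX_DEPTH:
        raise RecursionError("argument nesting is too deep")
    grads = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            grads.append(arg.grad)
        elif isinstance(arg, (list, tuple)):
            grads.extend(extract_tensors_grad(arg, depth + 1))
    return grads

a = torch.ones(2, requires_grad=True)
b = torch.ones(2, requires_grad=True)
((a * 2).sum() + (b * 3).sum()).backward()
print(extract_tensors_grad([a, [b]]))  # [tensor([2., 2.]), tensor([3., 3.])]
```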
@@ -438,9 +471,55 @@ def _run_ut(parser=None):
     run_ut_command(args)


+def checked_online_config(online_config):
+    if not online_config.is_online:
+        return
+    if not isinstance(online_config.is_online, bool):
+        raise ValueError("is_online must be bool type")
+    # rank_list
+    if not isinstance(online_config.rank_list, list):
+        raise ValueError("rank_list must be a list")
+    if online_config.rank_list and not all(isinstance(rank, int) for rank in online_config.rank_list):
+        raise ValueError("All elements in rank_list must be integers")
+
+    # nfs_path
+    if online_config.nfs_path:
+        check_file_or_directory_path(online_config.nfs_path, isdir=True)
+        return
+    # tls_path
+    if online_config.tls_path:
+        check_file_or_directory_path(online_config.tls_path, isdir=True)
+        check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
+        check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
+        check_crt_valid(os.path.join(online_config.tls_path, "server.crt"))
+
+    # host and port
+    if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
+        raise Exception(f"host: {online_config.host} is invalid.")
+    if not isinstance(online_config.port, int) or not (0 < online_config.port <= 65535):
+        raise Exception(f"port: {online_config.port} is invalid, port range 0-65535.")
+
+
 def run_ut_command(args):
+    if args.config_path:
+        config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
+                                          FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
+        checked_config_path = config_path_checker.common_check()
+        _, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
+        checker_config = CheckerConfig(task_config)
+    else:
+        checker_config = CheckerConfig()
+
+    if not checker_config.is_online and not args.api_info_file:
+        logger.error("Please provide api_info_file for offline run ut.")
+        raise Exception("Please provide api_info_file for offline run ut.")
+
     if not is_gpu:
         torch.npu.set_compile_mode(jit_compile=args.jit_compile)
+        if args.jit_compile:
+            torch.npu.config.allow_internal_format = True
+        else:
+            torch.npu.config.allow_internal_format = False
     used_device = current_device + ":" + str(args.device_id[0])
     try:
         if is_gpu:
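The new `checked_online_config` validates the online-run settings up front: `rank_list` must be a list of ints, `nfs_path`/`tls_path` must be existing directories (with `server.key`/`server.crt` present and a valid certificate when TLS is used), `host` must match an IPv4 pattern, and `port` must fall in 1-65535. A rough, self-contained illustration of the host/port/rank checks is below; the `OnlineConfig` namedtuple and the regex are assumptions standing in for the config object and `Const.ipv4_pattern`.

```python
import re
from collections import namedtuple

OnlineConfig = namedtuple("OnlineConfig", ["is_online", "nfs_path", "host", "port", "rank_list", "tls_path"])
IPV4 = r"^(\d{1,3}\.){3}\d{1,3}$"  # simplified stand-in for Const.ipv4_pattern

def validate(cfg):
    # subset of checked_online_config: type and range checks only
    if not cfg.is_online:
        return
    if not isinstance(cfg.rank_list, list) or not all(isinstance(r, int) for r in cfg.rank_list):
        raise ValueError("rank_list must be a list of integers")
    if not isinstance(cfg.host, str) or not re.match(IPV4, cfg.host):
        raise ValueError(f"host {cfg.host!r} is invalid")
    if not isinstance(cfg.port, int) or not 0 < cfg.port <= 65535:
        raise ValueError(f"port {cfg.port} is out of range")

validate(OnlineConfig(True, "", "127.0.0.1", 8080, [0, 1], ""))  # passes
# validate(OnlineConfig(True, "", "localhost", 8080, [0], ""))   # would raise: host is not IPv4
```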
@@ -459,13 +538,15 @@ def run_ut_command(args):
                                             ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
         checked_api_info = api_info_file_checker.common_check()
         forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info)
+        if real_data_path:
+            dump_path = os.path.dirname(checked_api_info)
+            real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
         if args.filter_api:
             logger.info("Start filtering the api in the api_info_file.")
             forward_content = preprocess_forward_content(forward_content)
             logger.info("Finish filtering the api in the api_info_file.")

-    out_path =
-    check_path_before_create(out_path)
+    out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
     create_directory(out_path)
     out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
     out_path = out_path_checker.common_check()
@@ -476,43 +557,31 @@ def run_ut_command(args):
     if args.result_csv_path:
         result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
         details_csv_path = get_validated_details_csv_path(result_csv_path)
-    white_list = msCheckerConfig.white_list
-    black_list = msCheckerConfig.black_list
-    error_data_path = msCheckerConfig.error_data_path
-    is_online = msCheckerConfig.is_online
-    nfs_path = msCheckerConfig.nfs_path
-    host = msCheckerConfig.host
-    port = msCheckerConfig.port
-    rank_list = msCheckerConfig.rank_list
-    tls_path = msCheckerConfig.tls_path
-    if args.config_path:
-        config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
-                                          FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
-        checked_config_path = config_path_checker.common_check()
-        _, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
-        white_list = task_config.white_list
-        black_list = task_config.black_list
-        error_data_path = task_config.error_data_path
-        is_online = task_config.is_online
-        nfs_path = task_config.nfs_path
-        host = task_config.host
-        port = task_config.port
-        rank_list = task_config.rank_list
-        tls_path = task_config.tls_path

+    error_data_path = checker_config.error_data_path
     if save_error_data:
         if args.result_csv_path:
             time_info = result_csv_path.split('.')[0].split('_')[-1]
             global UT_ERROR_DATA_DIR
             UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
         error_data_path = initialize_save_error_data(error_data_path)
-    online_config =
-
-
-
+    online_config = checker_config.get_online_config()
+    checked_online_config(online_config)
+    config_params = {
+        'forward_content': forward_content,
+        'backward_content': backward_content,
+        'result_csv_path': result_csv_path,
+        'details_csv_path': details_csv_path,
+        'save_error_data': save_error_data,
+        'is_continue_run_ut': args.result_csv_path,
+        'real_data_path': real_data_path,
+        'error_data_path': error_data_path
+    }
+    run_ut_config = checker_config.get_run_ut_config(**config_params)
     run_ut(run_ut_config)


 if __name__ == '__main__':
+    seed_all()
     _run_ut()
     logger.info("UT task completed.")
msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py

@@ -16,6 +16,7 @@
 # limitations under the License.

 import os
+from collections import namedtuple
 import re
 import torch

@@ -23,8 +24,10 @@ try:
     import torch_npu
 except ImportError:
     current_device = "cuda"
+    from torch.cuda.amp import autocast
 else:
     current_device = "npu"
+    from torch_npu.npu.amp import autocast

 from msprobe.core.common.const import FileCheckConst, Const, CompareConst
 from msprobe.core.common.file_utils import FileChecker
@@ -47,11 +50,17 @@ PRECISION_MAPPING = {
 }


+CpuParams = namedtuple("CpuArgs", ["cpu_args", "cpu_kwargs", "autocast_dtype", "is_autocast"])
+ExecParams = namedtuple("ExecParams", ["api_type", "api_name", "device", "args", "kwargs",
+                                       "is_autocast", "autocast_dtype"])
+
+
 class BackwardMessage:
     MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
     UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, " \
                                  "skip backward."
-    NO_BACKWARD_RESULT_MESSAGE = "
+    NO_BACKWARD_RESULT_MESSAGE = "This API does not have backward input data, skip backward."
+    UNSUPPORT_API_MESSAGE = "This API does not support backward ut, skip backward."


 class UtDataInfo:
@@ -91,7 +100,15 @@ def get_validated_details_csv_path(validated_result_csv_path):
     return validated_details_csv_path


-def exec_api(
+def exec_api(exec_params):
+    api_type = exec_params.api_type
+    api_name = exec_params.api_name
+    device = exec_params.device
+    args = exec_params.args
+    kwargs = exec_params.kwargs
+    is_autocast = exec_params.is_autocast
+    autocast_dtype = exec_params.autocast_dtype
+
     if api_type == "Functional":
         torch_api = FunctionalOPTemplate(api_name, str, False)
     if api_type == "Tensor":
@@ -102,7 +119,11 @@ def exec_api(api_type, api_name, device, args, kwargs):
         torch_api = AtenOPTemplate(api_name, None, False)
     if api_type == "NPU":
         torch_api = NpuOPTemplate(api_name, None, False, device)
-
+    if is_autocast:
+        with autocast(dtype=autocast_dtype):
+            out = torch_api.forward(*args, **kwargs)
+    else:
+        out = torch_api.forward(*args, **kwargs)
     return out

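With the two hunks above, `exec_api` takes a single `ExecParams` namedtuple and optionally wraps the call in `autocast`, so CPU and device executions differ only in the params they are handed. Below is a hedged sketch of the calling pattern; `torch.add` stands in for the wrapped OP templates (`FunctionalOPTemplate` and friends), so this is not the package's actual dispatch code.

```python
from collections import namedtuple
import torch

ExecParams = namedtuple("ExecParams", ["api_type", "api_name", "device", "args", "kwargs",
                                       "is_autocast", "autocast_dtype"])

def exec_api(p):
    # simplified dispatcher: resolve the API by name and run it, with optional autocast
    func = getattr(torch, p.api_name)
    if p.is_autocast:
        with torch.autocast(device_type=p.device, dtype=p.autocast_dtype):
            return func(*p.args, **p.kwargs)
    return func(*p.args, **p.kwargs)

cpu_params = ExecParams("Functional", "add", "cpu",
                        (torch.ones(2), torch.ones(2)), {}, False, None)
print(exec_api(cpu_params))  # tensor([2., 2.])
```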
@@ -186,28 +207,48 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
             logger.error("The depth of arg_in is too large, please check the arg_in.")
             raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
         if isinstance(arg_in, (list, tuple)):
-            return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs, depth=depth+1) for
+            return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs, depth=depth+1) for
+                                      arg in arg_in))
         elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
             return set([arg_in.dtype])
         elif isinstance(arg_in, dict) and check_kwargs:
-            return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True, depth=depth+1) for
+            return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True, depth=depth+1) for
+                                      v in arg_in.values()))
         return set()

     raise_dtype = None
+    autocast_dtype = None
+    is_autocast = False
     need_raise_dtypes = recursive_find_dtypes(input_args)
     need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
     if len(need_raise_dtypes) == 1:
-
+        origin_dtype = need_raise_dtypes.pop()
+        raise_dtype = PRECISION_MAPPING.get(origin_dtype, torch.float32)
+        autocast_dtype = origin_dtype
+
     elif len(need_raise_dtypes) >= 2:
         raise_dtype = torch.float32
+        need_raise_dtypes.discard(torch.float32)
+        autocast_dtype = need_raise_dtypes.pop()
+        is_autocast = True

     raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
     is_detach = api_name not in not_detach_set
     cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
-    cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for
-
+    cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for
+                  key, value in input_kwargs.items()}
+    cpu_params = CpuParams(cpu_args, cpu_kwargs, autocast_dtype, is_autocast)
+    return cpu_params


 def record_skip_info(api_full_name, compare, compare_alg_results):
     result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [compare_alg_results], None, 0)
     compare.record_results(result_info)
+
+
+def is_unsupported_api(api_name, is_overflow_check=False):
+    split_name = api_name.split(Const.SEP)[0]
+    flag = (split_name == Const.DISTRIBUTED) or (is_overflow_check and split_name == Const.NPU)
+    if flag:
+        logger.info(f"{split_name} api is not supported for run ut. SKIP.")
+    return flag
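Finally, the tail of `generate_cpu_params` above encodes the new autocast decision: a single low-precision input dtype is raised via `PRECISION_MAPPING` and remembered as `autocast_dtype`; mixed dtypes raise the CPU copy to float32 and mark the device run `is_autocast` with the remaining non-float32 dtype. A condensed sketch of just that decision follows; the two-entry `PRECISION_MAPPING` here is an assumption about the module-level constant, not its full contents.

```python
import torch

# assumed subset of run_ut_utils.PRECISION_MAPPING
PRECISION_MAPPING = {torch.float16: torch.float32, torch.bfloat16: torch.float32}

def autocast_decision(input_dtypes):
    # mirrors the branch on len(need_raise_dtypes) in generate_cpu_params
    raise_dtype, autocast_dtype, is_autocast = None, None, False
    dtypes = set(input_dtypes)
    if len(dtypes) == 1:
        origin = dtypes.pop()
        raise_dtype = PRECISION_MAPPING.get(origin, torch.float32)
        autocast_dtype = origin
    elif len(dtypes) >= 2:
        raise_dtype = torch.float32
        dtypes.discard(torch.float32)
        autocast_dtype = dtypes.pop()
        is_autocast = True
    return raise_dtype, autocast_dtype, is_autocast

print(autocast_decision([torch.float16]))                 # (torch.float32, torch.float16, False)
print(autocast_decision([torch.float16, torch.float32]))  # (torch.float32, torch.float16, True)
```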