mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -41,6 +41,7 @@ from msprobe.core.common.utils import CompareException
|
|
|
41
41
|
|
|
42
42
|
def split_json_file(input_file, num_splits, filter_api):
|
|
43
43
|
forward_data, backward_data, real_data_path = parse_json_info_forward_backward(input_file)
|
|
44
|
+
input_dir = os.path.dirname(os.path.abspath(input_file))
|
|
44
45
|
if filter_api:
|
|
45
46
|
forward_data = preprocess_forward_content(forward_data)
|
|
46
47
|
for data_name in list(forward_data.keys()):
|
|
@@ -71,7 +72,7 @@ def split_json_file(input_file, num_splits, filter_api):
|
|
|
71
72
|
**backward_data
|
|
72
73
|
}
|
|
73
74
|
}
|
|
74
|
-
split_filename = f"temp_part{i}.json"
|
|
75
|
+
split_filename = os.path.join(input_dir, f"temp_part{i}.json")
|
|
75
76
|
save_json(split_filename, temp_data)
|
|
76
77
|
split_files.append(split_filename)
|
|
77
78
|
|
|
@@ -23,12 +23,14 @@ try:
|
|
|
23
23
|
import torch_npu
|
|
24
24
|
except ImportError:
|
|
25
25
|
is_gpu = True
|
|
26
|
+
current_device = "cuda"
|
|
26
27
|
else:
|
|
27
28
|
is_gpu = False
|
|
29
|
+
current_device = "npu"
|
|
28
30
|
import torch
|
|
29
31
|
from tqdm import tqdm
|
|
30
32
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import generate_device_params, get_api_info
|
|
31
|
-
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, is_unsupported_api
|
|
33
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, is_unsupported_api, ExecParams
|
|
32
34
|
from msprobe.core.common.file_utils import check_link, FileChecker
|
|
33
35
|
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
|
|
34
36
|
from msprobe.core.common.const import FileCheckConst, Const
|
|
@@ -61,19 +63,33 @@ def check_tensor_overflow(x):
|
|
|
61
63
|
return False
|
|
62
64
|
|
|
63
65
|
|
|
64
|
-
def check_data_overflow(x):
|
|
65
|
-
if isinstance(x, (tuple, list))
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
return False
|
|
66
|
+
def check_data_overflow(x, device):
|
|
67
|
+
if isinstance(x, (tuple, list)):
|
|
68
|
+
if not x:
|
|
69
|
+
return False
|
|
70
|
+
return any(check_data_overflow(item, device) for item in x)
|
|
70
71
|
else:
|
|
71
|
-
|
|
72
|
+
if device == Const.CPU_LOWERCASE:
|
|
73
|
+
return check_tensor_overflow(x)
|
|
74
|
+
else:
|
|
75
|
+
return torch_npu.npu.utils.npu_check_overflow(x)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def is_bool_output(x):
|
|
79
|
+
if isinstance(x, (tuple, list)):
|
|
80
|
+
if not x:
|
|
81
|
+
return False
|
|
82
|
+
return any(is_bool_output(item) for item in x)
|
|
83
|
+
else:
|
|
84
|
+
return isinstance(x, bool)
|
|
72
85
|
|
|
73
86
|
|
|
74
87
|
def run_overflow_check(forward_file):
|
|
75
88
|
logger.info("start UT test")
|
|
76
89
|
forward_content, _, real_data_path = parse_json_info_forward_backward(forward_file)
|
|
90
|
+
if real_data_path:
|
|
91
|
+
dump_path = os.path.dirname(forward_file)
|
|
92
|
+
real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
|
|
77
93
|
for api_full_name, api_info_dict in tqdm(forward_content.items()):
|
|
78
94
|
if is_unsupported_api(api_full_name, is_overflow_check=True):
|
|
79
95
|
continue
|
|
@@ -87,6 +103,9 @@ def run_overflow_check(forward_file):
|
|
|
87
103
|
elif "expected scalar type Long" in str(err):
|
|
88
104
|
logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
|
|
89
105
|
"'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
|
|
106
|
+
elif "could not create a primitive descriptor for a matmul primitive" in str(err):
|
|
107
|
+
logger.warning(f"API {api_name} not support matmul primitive in CPU due to pytorch bug, "
|
|
108
|
+
"so it will be skipped.")
|
|
90
109
|
else:
|
|
91
110
|
logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
|
|
92
111
|
|
|
@@ -98,17 +117,26 @@ def run_torch_api(api_full_name, api_info_dict, real_data_path):
|
|
|
98
117
|
if not need_grad:
|
|
99
118
|
logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
|
|
100
119
|
% api_full_name)
|
|
120
|
+
device_info_kwargs = kwargs.get(Const.DEVICE)
|
|
121
|
+
if device_info_kwargs and device_info_kwargs.get(Const.VALUE):
|
|
122
|
+
kwargs[Const.DEVICE] = current_device
|
|
101
123
|
npu_args, npu_kwargs = generate_device_params(args, kwargs, False, api_name)
|
|
102
|
-
if kwargs.get(
|
|
103
|
-
del kwargs[
|
|
104
|
-
|
|
105
|
-
|
|
124
|
+
if kwargs.get(Const.DEVICE):
|
|
125
|
+
del kwargs[Const.DEVICE]
|
|
126
|
+
cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, args, kwargs, False, None)
|
|
127
|
+
device_exec_params = ExecParams(api_type, api_name, Const.NPU_LOWERCASE, npu_args, npu_kwargs, False, None)
|
|
128
|
+
out = exec_api(cpu_exec_params)
|
|
129
|
+
npu_out = exec_api(device_exec_params)
|
|
106
130
|
if out is None and npu_out is None:
|
|
107
131
|
logger.warning("The %s overflow is a normal overflow, out and npu_out is None." % api_full_name)
|
|
108
132
|
return
|
|
133
|
+
if is_bool_output(out) or is_bool_output(npu_out):
|
|
134
|
+
logger.warning("The output of %s is bool type.This dtype not support overflow, so it will be skipped."
|
|
135
|
+
% api_full_name)
|
|
136
|
+
return
|
|
109
137
|
|
|
110
|
-
cpu_overflow = check_data_overflow(out)
|
|
111
|
-
npu_overflow =
|
|
138
|
+
cpu_overflow = check_data_overflow(out, Const.CPU_LOWERCASE)
|
|
139
|
+
npu_overflow = check_data_overflow(npu_out, Const.NPU_LOWERCASE)
|
|
112
140
|
if cpu_overflow == npu_overflow:
|
|
113
141
|
logger.warning("The %s overflow is a normal overflow." % api_full_name)
|
|
114
142
|
else:
|
|
@@ -31,6 +31,7 @@ except ImportError:
|
|
|
31
31
|
else:
|
|
32
32
|
is_gpu = False
|
|
33
33
|
current_device = "npu"
|
|
34
|
+
|
|
34
35
|
import torch
|
|
35
36
|
from tqdm import tqdm
|
|
36
37
|
|
|
@@ -48,10 +49,12 @@ from msprobe.core.common.file_utils import FileChecker, change_mode, \
|
|
|
48
49
|
from msprobe.pytorch.common.log import logger
|
|
49
50
|
from msprobe.pytorch.pt_config import parse_json_config
|
|
50
51
|
from msprobe.core.common.const import Const, FileCheckConst, CompareConst
|
|
51
|
-
from msprobe.core.common.utils import safe_get_value
|
|
52
|
+
from msprobe.core.common.utils import safe_get_value, CompareException
|
|
53
|
+
from msprobe.pytorch.common.utils import seed_all
|
|
52
54
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
|
|
53
55
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
|
|
54
|
-
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params
|
|
56
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params, \
|
|
57
|
+
ExecParams
|
|
55
58
|
|
|
56
59
|
|
|
57
60
|
current_time = time.strftime("%Y%m%d%H%M%S")
|
|
@@ -61,6 +64,7 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
|
|
|
61
64
|
|
|
62
65
|
|
|
63
66
|
not_backward_list = ['repeat_interleave']
|
|
67
|
+
unsupported_backward_list = ['masked_select']
|
|
64
68
|
|
|
65
69
|
|
|
66
70
|
tqdm_params = {
|
|
@@ -237,7 +241,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
237
241
|
in_fwd_data_list = []
|
|
238
242
|
backward_message = ''
|
|
239
243
|
api_type, api_name = extract_basic_api_segments(api_full_name)
|
|
240
|
-
args, kwargs,
|
|
244
|
+
args, kwargs, output_dtype = get_api_info(api_info_dict, api_name, real_data_path)
|
|
245
|
+
need_grad = check_need_grad(api_info_dict)
|
|
241
246
|
in_fwd_data_list.append(args)
|
|
242
247
|
in_fwd_data_list.append(kwargs)
|
|
243
248
|
need_backward = api_full_name in backward_content
|
|
@@ -248,14 +253,30 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
248
253
|
need_grad = False
|
|
249
254
|
logger.info("%s %s" % (api_full_name, BackwardMessage.NO_BACKWARD_RESULT_MESSAGE))
|
|
250
255
|
backward_message += BackwardMessage.NO_BACKWARD_RESULT_MESSAGE
|
|
256
|
+
if api_name in unsupported_backward_list:
|
|
257
|
+
need_grad = False
|
|
258
|
+
logger.info("%s %s" % (api_full_name, BackwardMessage.UNSUPPORT_API_MESSAGE))
|
|
259
|
+
backward_message += BackwardMessage.UNSUPPORT_API_MESSAGE
|
|
251
260
|
need_backward = need_backward and need_grad
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
261
|
+
|
|
262
|
+
device_info_kwargs = kwargs.get(Const.DEVICE)
|
|
263
|
+
if device_info_kwargs and device_info_kwargs.get(Const.VALUE):
|
|
264
|
+
kwargs[Const.DEVICE] = current_device
|
|
255
265
|
device_args, device_kwargs = generate_device_params(args, kwargs, need_backward, api_name)
|
|
266
|
+
if kwargs.get(Const.DEVICE):
|
|
267
|
+
del kwargs[Const.DEVICE]
|
|
268
|
+
cpu_params = generate_cpu_params(args, kwargs, need_backward, api_name)
|
|
269
|
+
cpu_args, cpu_kwargs = cpu_params.cpu_args, cpu_params.cpu_kwargs
|
|
270
|
+
autocast_dtype, is_autocast = cpu_params.autocast_dtype, cpu_params.is_autocast
|
|
271
|
+
if not is_autocast and output_dtype:
|
|
272
|
+
is_autocast = autocast_dtype != output_dtype
|
|
273
|
+
autocast_dtype = output_dtype
|
|
256
274
|
bench_grad_out, device_grad_out = None, None
|
|
257
|
-
|
|
258
|
-
|
|
275
|
+
cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs, False, autocast_dtype)
|
|
276
|
+
out = exec_api(cpu_exec_params)
|
|
277
|
+
device_exec_params = ExecParams(api_type, api_name, current_device, device_args, device_kwargs, is_autocast,
|
|
278
|
+
autocast_dtype)
|
|
279
|
+
device_out = exec_api(device_exec_params)
|
|
259
280
|
current_path = os.path.dirname(os.path.realpath(__file__))
|
|
260
281
|
ut_setting_path = os.path.join(current_path, "torch_ut_setting.json")
|
|
261
282
|
api_setting_dict = get_json_contents(ut_setting_path)
|
|
@@ -273,7 +294,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
273
294
|
}
|
|
274
295
|
grad = gen_args(backward_args, api_name, func_options)
|
|
275
296
|
grad = safe_get_value(grad, 0, "grad")
|
|
276
|
-
|
|
297
|
+
grad_params = generate_cpu_params(grad, {}, False, api_name)
|
|
298
|
+
bench_grad = grad_params.cpu_args
|
|
277
299
|
bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
|
|
278
300
|
device_grad = grad.clone().detach().to(current_device)
|
|
279
301
|
device_grad_out = run_backward(device_args, device_grad, grad_index, device_out)
|
|
@@ -300,13 +322,18 @@ def run_torch_api_online(api_full_name, api_data, backward_content):
|
|
|
300
322
|
return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
|
|
301
323
|
|
|
302
324
|
|
|
303
|
-
def
|
|
304
|
-
convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
|
|
325
|
+
def check_need_grad(api_info_dict):
|
|
305
326
|
need_grad = True
|
|
306
|
-
if api_info_dict.get(
|
|
327
|
+
if api_info_dict.get(Const.INPUT_KWARGS) and "out" in api_info_dict.get(Const.INPUT_KWARGS):
|
|
307
328
|
need_grad = False
|
|
308
|
-
|
|
309
|
-
|
|
329
|
+
return need_grad
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def get_api_info(api_info_dict, api_name, real_data_path):
|
|
333
|
+
convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
|
|
334
|
+
need_grad = check_need_grad(api_info_dict)
|
|
335
|
+
args, kwargs, output_dtype = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path)
|
|
336
|
+
return args, kwargs, output_dtype
|
|
310
337
|
|
|
311
338
|
|
|
312
339
|
def need_to_backward(grad_index, out):
|
|
@@ -323,15 +350,25 @@ def run_backward(args, grad, grad_index, out):
|
|
|
323
350
|
out[grad_index].backward(grad)
|
|
324
351
|
else:
|
|
325
352
|
out.backward(grad)
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
if isinstance(arg, torch.Tensor):
|
|
329
|
-
args_grad.append(arg.grad)
|
|
330
|
-
grad_out = args_grad
|
|
353
|
+
|
|
354
|
+
grad_out = extract_tensors_grad(args)
|
|
331
355
|
|
|
332
356
|
return grad_out
|
|
333
357
|
|
|
334
358
|
|
|
359
|
+
def extract_tensors_grad(args, depth=0):
|
|
360
|
+
if depth > Const.MAX_DEPTH:
|
|
361
|
+
logger.error("The depth of arg_in is too large, please check the arg_in.")
|
|
362
|
+
raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
|
|
363
|
+
grads = []
|
|
364
|
+
for arg in args:
|
|
365
|
+
if isinstance(arg, torch.Tensor):
|
|
366
|
+
grads.append(arg.grad)
|
|
367
|
+
elif isinstance(arg, (list, tuple)):
|
|
368
|
+
grads.extend(extract_tensors_grad(arg, depth+1))
|
|
369
|
+
return grads
|
|
370
|
+
|
|
371
|
+
|
|
335
372
|
def initialize_save_error_data(error_data_path):
|
|
336
373
|
create_directory(error_data_path)
|
|
337
374
|
error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
|
|
@@ -479,6 +516,10 @@ def run_ut_command(args):
|
|
|
479
516
|
|
|
480
517
|
if not is_gpu:
|
|
481
518
|
torch.npu.set_compile_mode(jit_compile=args.jit_compile)
|
|
519
|
+
if args.jit_compile:
|
|
520
|
+
torch.npu.config.allow_internal_format = True
|
|
521
|
+
else:
|
|
522
|
+
torch.npu.config.allow_internal_format = False
|
|
482
523
|
used_device = current_device + ":" + str(args.device_id[0])
|
|
483
524
|
try:
|
|
484
525
|
if is_gpu:
|
|
@@ -497,6 +538,9 @@ def run_ut_command(args):
|
|
|
497
538
|
ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
498
539
|
checked_api_info = api_info_file_checker.common_check()
|
|
499
540
|
forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info)
|
|
541
|
+
if real_data_path:
|
|
542
|
+
dump_path = os.path.dirname(checked_api_info)
|
|
543
|
+
real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
|
|
500
544
|
if args.filter_api:
|
|
501
545
|
logger.info("Start filtering the api in the api_info_file.")
|
|
502
546
|
forward_content = preprocess_forward_content(forward_content)
|
|
@@ -538,5 +582,6 @@ def run_ut_command(args):
|
|
|
538
582
|
|
|
539
583
|
|
|
540
584
|
if __name__ == '__main__':
|
|
585
|
+
seed_all()
|
|
541
586
|
_run_ut()
|
|
542
587
|
logger.info("UT task completed.")
|
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
# limitations under the License.
|
|
17
17
|
|
|
18
18
|
import os
|
|
19
|
+
from collections import namedtuple
|
|
19
20
|
import re
|
|
20
21
|
import torch
|
|
21
22
|
|
|
@@ -23,8 +24,10 @@ try:
|
|
|
23
24
|
import torch_npu
|
|
24
25
|
except ImportError:
|
|
25
26
|
current_device = "cuda"
|
|
27
|
+
from torch.cuda.amp import autocast
|
|
26
28
|
else:
|
|
27
29
|
current_device = "npu"
|
|
30
|
+
from torch_npu.npu.amp import autocast
|
|
28
31
|
|
|
29
32
|
from msprobe.core.common.const import FileCheckConst, Const, CompareConst
|
|
30
33
|
from msprobe.core.common.file_utils import FileChecker
|
|
@@ -47,11 +50,17 @@ PRECISION_MAPPING = {
|
|
|
47
50
|
}
|
|
48
51
|
|
|
49
52
|
|
|
53
|
+
CpuParams = namedtuple("CpuArgs", ["cpu_args", "cpu_kwargs", "autocast_dtype", "is_autocast"])
|
|
54
|
+
ExecParams = namedtuple("ExecParams", ["api_type", "api_name", "device", "args", "kwargs",
|
|
55
|
+
"is_autocast", "autocast_dtype"])
|
|
56
|
+
|
|
57
|
+
|
|
50
58
|
class BackwardMessage:
|
|
51
59
|
MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
|
|
52
60
|
UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, " \
|
|
53
61
|
"skip backward."
|
|
54
62
|
NO_BACKWARD_RESULT_MESSAGE = "This API does not have backward input data, skip backward."
|
|
63
|
+
UNSUPPORT_API_MESSAGE = "This API does not support backward ut, skip backward."
|
|
55
64
|
|
|
56
65
|
|
|
57
66
|
class UtDataInfo:
|
|
@@ -91,7 +100,15 @@ def get_validated_details_csv_path(validated_result_csv_path):
|
|
|
91
100
|
return validated_details_csv_path
|
|
92
101
|
|
|
93
102
|
|
|
94
|
-
def exec_api(
|
|
103
|
+
def exec_api(exec_params):
|
|
104
|
+
api_type = exec_params.api_type
|
|
105
|
+
api_name = exec_params.api_name
|
|
106
|
+
device = exec_params.device
|
|
107
|
+
args = exec_params.args
|
|
108
|
+
kwargs = exec_params.kwargs
|
|
109
|
+
is_autocast = exec_params.is_autocast
|
|
110
|
+
autocast_dtype = exec_params.autocast_dtype
|
|
111
|
+
|
|
95
112
|
if api_type == "Functional":
|
|
96
113
|
torch_api = FunctionalOPTemplate(api_name, str, False)
|
|
97
114
|
if api_type == "Tensor":
|
|
@@ -102,7 +119,11 @@ def exec_api(api_type, api_name, device, args, kwargs):
|
|
|
102
119
|
torch_api = AtenOPTemplate(api_name, None, False)
|
|
103
120
|
if api_type == "NPU":
|
|
104
121
|
torch_api = NpuOPTemplate(api_name, None, False, device)
|
|
105
|
-
|
|
122
|
+
if is_autocast:
|
|
123
|
+
with autocast(dtype=autocast_dtype):
|
|
124
|
+
out = torch_api.forward(*args, **kwargs)
|
|
125
|
+
else:
|
|
126
|
+
out = torch_api.forward(*args, **kwargs)
|
|
106
127
|
return out
|
|
107
128
|
|
|
108
129
|
|
|
@@ -196,19 +217,28 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
|
|
|
196
217
|
return set()
|
|
197
218
|
|
|
198
219
|
raise_dtype = None
|
|
220
|
+
autocast_dtype = None
|
|
221
|
+
is_autocast = False
|
|
199
222
|
need_raise_dtypes = recursive_find_dtypes(input_args)
|
|
200
223
|
need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
|
|
201
224
|
if len(need_raise_dtypes) == 1:
|
|
202
|
-
|
|
225
|
+
origin_dtype = need_raise_dtypes.pop()
|
|
226
|
+
raise_dtype = PRECISION_MAPPING.get(origin_dtype, torch.float32)
|
|
227
|
+
autocast_dtype = origin_dtype
|
|
228
|
+
|
|
203
229
|
elif len(need_raise_dtypes) >= 2:
|
|
204
230
|
raise_dtype = torch.float32
|
|
231
|
+
need_raise_dtypes.discard(torch.float32)
|
|
232
|
+
autocast_dtype = need_raise_dtypes.pop()
|
|
233
|
+
is_autocast = True
|
|
205
234
|
|
|
206
235
|
raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
|
|
207
236
|
is_detach = api_name not in not_detach_set
|
|
208
237
|
cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
|
|
209
238
|
cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for
|
|
210
239
|
key, value in input_kwargs.items()}
|
|
211
|
-
|
|
240
|
+
cpu_params = CpuParams(cpu_args, cpu_kwargs, autocast_dtype, is_autocast)
|
|
241
|
+
return cpu_params
|
|
212
242
|
|
|
213
243
|
|
|
214
244
|
def record_skip_info(api_full_name, compare, compare_alg_results):
|
|
@@ -24,7 +24,7 @@ from msprobe.core.common.const import Const, CompareConst
|
|
|
24
24
|
from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import online_api_precision_compare
|
|
25
25
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import DETAIL_TEST_ROWS, thousandth_standard_api, \
|
|
26
26
|
binary_standard_api, absolute_standard_api
|
|
27
|
-
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api
|
|
27
|
+
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api, ExecParams
|
|
28
28
|
from msprobe.pytorch.common.log import logger
|
|
29
29
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device
|
|
30
30
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params
|
|
@@ -92,8 +92,10 @@ def online_precision_compare(api_data, device, common_config, api_precision_csv_
|
|
|
92
92
|
|
|
93
93
|
try:
|
|
94
94
|
# NPU vs CPU
|
|
95
|
-
|
|
96
|
-
|
|
95
|
+
cpu_params = generate_cpu_params(npu_args, npu_kwargs, False, api_name)
|
|
96
|
+
cpu_args, cpu_kwargs = cpu_params.cpu_args, cpu_params.cpu_kwargs
|
|
97
|
+
cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs, False, None)
|
|
98
|
+
cpu_out = exec_api(cpu_exec_params)
|
|
97
99
|
npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank)
|
|
98
100
|
npu_detail = compare.compare_output(api_full_name, npu_data_info, True)
|
|
99
101
|
npu_data = pd.DataFrame(npu_detail, columns=DETAIL_TEST_ROWS[-1])
|
|
@@ -30,6 +30,7 @@
|
|
|
30
30
|
numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
|
|
31
31
|
"""
|
|
32
32
|
|
|
33
|
+
from collections import namedtuple
|
|
33
34
|
import torch
|
|
34
35
|
import numpy as np
|
|
35
36
|
from einops import rearrange
|
|
@@ -54,6 +55,14 @@ GTYPE = torch.float64 # arm host必须选择float64,x86环境选择float32即
|
|
|
54
55
|
SOFTMAX_BUILD_MODE = "QKV" # "MAX_SUM"
|
|
55
56
|
|
|
56
57
|
|
|
58
|
+
FaForwardParams = namedtuple("FaForwardParams",
|
|
59
|
+
["q", "k", "v", "drop_mask", "atten_mask", "pse", "scale", "keep_prob"])
|
|
60
|
+
FaBackwardParams = namedtuple("FaBackwardParams",
|
|
61
|
+
["dx", "q", "k", "v", "softmax_res", "drop_mask", "pse", "scale", "keep_prob"])
|
|
62
|
+
RebuildSoftmaxParams = namedtuple("RebuildSoftmaxParams",
|
|
63
|
+
["q", "k", "atten_mask", "pse", "scale", "softmax_max", "softmax_sum"])
|
|
64
|
+
|
|
65
|
+
|
|
57
66
|
def softmax_forward(x):
|
|
58
67
|
x_max = torch.max(x, dim=-1, keepdims=True)[0]
|
|
59
68
|
x_sub = x.sub(x_max)
|
|
@@ -99,7 +108,15 @@ def calculate_qk(q, k, atten_mask, pse, scale):
|
|
|
99
108
|
return qk
|
|
100
109
|
|
|
101
110
|
|
|
102
|
-
def fusion_attention_forward(
|
|
111
|
+
def fusion_attention_forward(forward_params):
|
|
112
|
+
q = forward_params.q
|
|
113
|
+
k = forward_params.k
|
|
114
|
+
v = forward_params.v
|
|
115
|
+
drop_mask = forward_params.drop_mask
|
|
116
|
+
atten_mask = forward_params.atten_mask
|
|
117
|
+
pse = forward_params.pse
|
|
118
|
+
scale = forward_params.scale
|
|
119
|
+
keep_prob = forward_params.keep_prob
|
|
103
120
|
qk = calculate_qk(q, k, atten_mask, pse, scale)
|
|
104
121
|
softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
|
|
105
122
|
if drop_mask is None or len(drop_mask.shape) == 0:
|
|
@@ -110,7 +127,16 @@ def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_pr
|
|
|
110
127
|
return y, softmax_max, softmax_sum
|
|
111
128
|
|
|
112
129
|
|
|
113
|
-
def fusion_attention_backward(
|
|
130
|
+
def fusion_attention_backward(backward_params):
|
|
131
|
+
dx = backward_params.dx
|
|
132
|
+
q = backward_params.q
|
|
133
|
+
k = backward_params.k
|
|
134
|
+
v = backward_params.v
|
|
135
|
+
softmax_res = backward_params.softmax_res
|
|
136
|
+
drop_mask = backward_params.drop_mask
|
|
137
|
+
pse = backward_params.pse
|
|
138
|
+
scale = backward_params.scale
|
|
139
|
+
keep_prob = backward_params.keep_prob
|
|
114
140
|
dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
|
|
115
141
|
if drop_mask is None or len(drop_mask.shape) == 0:
|
|
116
142
|
drop_res = softmax_res.permute(0, 1, 3, 2)
|
|
@@ -368,11 +394,18 @@ def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
|
|
|
368
394
|
return softmax_res
|
|
369
395
|
|
|
370
396
|
|
|
371
|
-
def rebuild_softmax_by_max_sum(
|
|
397
|
+
def rebuild_softmax_by_max_sum(softmax_params):
|
|
372
398
|
"""
|
|
373
399
|
attention = softmax(QK^T/sqrt(d))V
|
|
374
400
|
softmax(x_i) = e^(x_i - x_max_i) / x_sum_i)
|
|
375
401
|
"""
|
|
402
|
+
q = softmax_params.q
|
|
403
|
+
k = softmax_params.k
|
|
404
|
+
atten_mask = softmax_params.atten_mask
|
|
405
|
+
pse = softmax_params.pse
|
|
406
|
+
scale = softmax_params.scale
|
|
407
|
+
softmax_max = softmax_params.softmax_max
|
|
408
|
+
softmax_sum = softmax_params.softmax_sum
|
|
376
409
|
logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
|
|
377
410
|
qk = calculate_qk(q, k, atten_mask, pse, scale)
|
|
378
411
|
if softmax_max.shape[-1] == 0:
|
|
@@ -502,10 +535,8 @@ def npu_fusion_attention(*args, **kwargs):
|
|
|
502
535
|
key = convert_to_bnsd(key, n2, input_layout)
|
|
503
536
|
value = convert_to_bnsd(value, n2, input_layout)
|
|
504
537
|
k_new, v_new = generate_kv(key, value, n1, n2)
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
pse=pse, scale=scale,
|
|
508
|
-
keep_prob=keep_prob)
|
|
538
|
+
forward_params = FaForwardParams(query, k_new, v_new, None, atten_mask, pse, scale, keep_prob)
|
|
539
|
+
out_golden, softmax_max, softmax_sum = fusion_attention_forward(forward_params)
|
|
509
540
|
if out_golden.dim() == 5:
|
|
510
541
|
out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3),
|
|
511
542
|
out_golden.size(4))
|
|
@@ -546,9 +577,10 @@ def npu_fusion_attention_grad(*args, **kwargs):
|
|
|
546
577
|
if SOFTMAX_BUILD_MODE == "QKV":
|
|
547
578
|
softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
|
|
548
579
|
else:
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
580
|
+
softmax_params = RebuildSoftmaxParams(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
|
|
581
|
+
softmax_res = rebuild_softmax_by_max_sum(softmax_params)
|
|
582
|
+
backward_params = FaBackwardParams(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
|
|
583
|
+
dq, dk, dv = fusion_attention_backward(backward_params)
|
|
552
584
|
|
|
553
585
|
# N不等长适配by cdy
|
|
554
586
|
if not (n1 == n2):
|
|
@@ -24,7 +24,8 @@ def parse_json_info_forward_backward(json_path):
|
|
|
24
24
|
real_data_path = dump_json.get("dump_data_dir")
|
|
25
25
|
dump_data = dump_json.get("data")
|
|
26
26
|
if dump_data is None:
|
|
27
|
-
raise ParseJsonException(ParseJsonException.InvalidDumpJson,
|
|
27
|
+
raise ParseJsonException(ParseJsonException.InvalidDumpJson,
|
|
28
|
+
"something wrong with dump, no data found in dump.json")
|
|
28
29
|
if not dump_data:
|
|
29
30
|
logger.warning("data field is empty, no overflow data found.")
|
|
30
31
|
|
msprobe/pytorch/common/utils.py
CHANGED
|
@@ -105,8 +105,49 @@ def get_rank_if_initialized():
|
|
|
105
105
|
raise DistributedNotInitializedError("torch distributed environment is not initialized")
|
|
106
106
|
|
|
107
107
|
|
|
108
|
-
def
|
|
109
|
-
|
|
108
|
+
def remove_dropout():
|
|
109
|
+
if torch.__version__ > "1.8":
|
|
110
|
+
logger.info_on_rank_0("For precision comparison, the probability p in the dropout method is set to 0.")
|
|
111
|
+
import torch.nn.functional as F
|
|
112
|
+
from torch import _VF
|
|
113
|
+
from torch.overrides import has_torch_function_unary, handle_torch_function
|
|
114
|
+
|
|
115
|
+
def function_dropout(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
116
|
+
inplace: bool = False) -> torch.Tensor:
|
|
117
|
+
if has_torch_function_unary(input_tensor):
|
|
118
|
+
return handle_torch_function(
|
|
119
|
+
function_dropout, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
|
|
120
|
+
if p < 0.0 or p > 1.0:
|
|
121
|
+
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
122
|
+
return _VF.dropout_(input_tensor, 0., training) if inplace else _VF.dropout(input_tensor, 0., training)
|
|
123
|
+
|
|
124
|
+
def function_dropout2d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
125
|
+
inplace: bool = False) -> torch.Tensor:
|
|
126
|
+
if has_torch_function_unary(input_tensor):
|
|
127
|
+
return handle_torch_function(
|
|
128
|
+
function_dropout2d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
|
|
129
|
+
if p < 0.0 or p > 1.0:
|
|
130
|
+
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
131
|
+
return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
|
|
132
|
+
0., training)
|
|
133
|
+
|
|
134
|
+
def function_dropout3d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
135
|
+
inplace: bool = False) -> torch.Tensor:
|
|
136
|
+
if has_torch_function_unary(input_tensor):
|
|
137
|
+
return handle_torch_function(
|
|
138
|
+
function_dropout3d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
|
|
139
|
+
if p < 0.0 or p > 1.0:
|
|
140
|
+
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
141
|
+
return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
|
|
142
|
+
0., training)
|
|
143
|
+
|
|
144
|
+
F.dropout = function_dropout
|
|
145
|
+
F.dropout2d = function_dropout2d
|
|
146
|
+
F.dropout3d = function_dropout3d
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def seed_all(seed=1234, mode=False, rm_dropout=True):
|
|
150
|
+
check_seed_all(seed, mode, rm_dropout)
|
|
110
151
|
try:
|
|
111
152
|
random.seed(seed)
|
|
112
153
|
os.environ['PYTHONHASHSEED'] = str(seed)
|
|
@@ -126,6 +167,8 @@ def seed_all(seed=1234, mode=False):
|
|
|
126
167
|
else:
|
|
127
168
|
torch_npu.npu.manual_seed_all(seed)
|
|
128
169
|
torch_npu.npu.manual_seed(seed)
|
|
170
|
+
if rm_dropout:
|
|
171
|
+
remove_dropout()
|
|
129
172
|
except Exception as e:
|
|
130
173
|
logger.error(f"There is an unexpected error while determinating randomness. {e}")
|
|
131
174
|
|
|
@@ -14,52 +14,40 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
-
|
|
18
|
-
check_configuration_param, set_dump_path, get_dump_mode
|
|
19
|
-
from msprobe.core.common.file_utils import create_directory
|
|
17
|
+
|
|
20
18
|
from msprobe.core.common.exceptions import FileCheckException
|
|
19
|
+
from msprobe.core.common.file_utils import create_directory
|
|
20
|
+
from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
|
|
21
|
+
set_dump_path
|
|
22
|
+
from msprobe.core.compare.acc_compare import ModeConfig
|
|
23
|
+
from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json, set_stack_json_path
|
|
21
24
|
from msprobe.pytorch.common.log import logger
|
|
22
|
-
from msprobe.pytorch.compare.pt_compare import PTComparator
|
|
23
|
-
from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
|
|
25
|
+
from msprobe.pytorch.compare.pt_compare import PTComparator, compare
|
|
24
26
|
|
|
25
27
|
|
|
26
28
|
def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
|
|
27
|
-
if kwargs.get(
|
|
29
|
+
if kwargs.get("suffix"):
|
|
28
30
|
logger.error("Argument 'suffix' is not supported for compare_distributed.")
|
|
29
31
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
30
|
-
|
|
31
|
-
auto_analyze = kwargs.get('auto_analyze', True)
|
|
32
|
-
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
33
|
-
is_print_compare_log = kwargs.get('is_print_compare_log', True)
|
|
32
|
+
is_print_compare_log = kwargs.get("is_print_compare_log", True)
|
|
34
33
|
# get the ranks and match by order
|
|
35
34
|
npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
|
|
36
35
|
bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
|
|
37
36
|
if len(npu_ranks) != len(bench_ranks):
|
|
38
|
-
logger.error(
|
|
39
|
-
|
|
40
|
-
|
|
37
|
+
logger.error(
|
|
38
|
+
"The number of ranks in the two runs are different. "
|
|
39
|
+
"Unable to match the ranks. "
|
|
40
|
+
"Please use another folder to compare or use compare() api and manually match the ranks.")
|
|
41
41
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
42
42
|
for nr, br in zip(npu_ranks, bench_ranks):
|
|
43
43
|
npu_data_dir = os.path.join(npu_dump_dir, nr)
|
|
44
44
|
bench_data_dir = os.path.join(bench_dump_dir, br)
|
|
45
45
|
npu_path = extract_json(npu_data_dir, stack_json=False)
|
|
46
46
|
bench_path = extract_json(bench_data_dir, stack_json=False)
|
|
47
|
-
stack_path = extract_json(npu_data_dir, stack_json=True)
|
|
48
47
|
|
|
49
48
|
dump_result_param = {
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
'is_print_compare_log': is_print_compare_log
|
|
49
|
+
"npu_json_path": npu_path,
|
|
50
|
+
"bench_json_path": bench_path,
|
|
51
|
+
"is_print_compare_log": is_print_compare_log
|
|
54
52
|
}
|
|
55
|
-
|
|
56
|
-
set_dump_path(dump_result_param)
|
|
57
|
-
dump_mode = get_dump_mode(dump_result_param)
|
|
58
|
-
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, is_print_compare_log)
|
|
59
|
-
create_directory(output_path)
|
|
60
|
-
check_compare_param(dump_result_param, output_path, dump_mode)
|
|
61
|
-
except (CompareException, FileCheckException) as error:
|
|
62
|
-
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
63
|
-
raise CompareException(error.code) from error
|
|
64
|
-
pt_comparator = PTComparator()
|
|
65
|
-
pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', dump_mode=dump_mode, **kwargs)
|
|
53
|
+
compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs)
|