mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
- mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +16 -21
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +185 -11
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +4 -0
- msprobe/core/common/utils.py +42 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +381 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +199 -69
- msprobe/core/data_dump/data_collector.py +100 -25
- msprobe/core/data_dump/data_processor/base.py +130 -28
- msprobe/core/data_dump/data_processor/factory.py +8 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
- msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
- msprobe/core/data_dump/json_writer.py +54 -8
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +121 -17
- msprobe/docs/02.config_introduction.md +18 -16
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +107 -58
- msprobe/docs/06.data_dump_MindSpore.md +95 -34
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +310 -220
- msprobe/docs/21.visualization_PyTorch.md +125 -35
- msprobe/docs/22.visualization_MindSpore.md +149 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +525 -0
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +11 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +68 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +28 -2
- msprobe/mindspore/debugger/precision_debugger.py +100 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +13 -3
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +347 -107
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -7
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +116 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +42 -17
- msprobe/pytorch/debugger/precision_debugger.py +56 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +98 -28
- msprobe/pytorch/monitor/csv2tb.py +164 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +543 -318
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +76 -56
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +84 -48
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +264 -115
- msprobe/visualization/builder/graph_builder.py +93 -10
- msprobe/visualization/builder/msprobe_adapter.py +30 -6
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +15 -19
- msprobe/visualization/graph/distributed_analyzer.py +395 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +100 -27
- msprobe/visualization/utils.py +24 -31
- mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
- msprobe/pytorch/functional/module_dump.py +0 -84
- msprobe/pytorch/module_processer.py +0 -150
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py

@@ -41,6 +41,7 @@ from msprobe.core.common.utils import CompareException

 def split_json_file(input_file, num_splits, filter_api):
     forward_data, backward_data, real_data_path = parse_json_info_forward_backward(input_file)
+    input_dir = os.path.dirname(os.path.abspath(input_file))
     if filter_api:
         forward_data = preprocess_forward_content(forward_data)
         for data_name in list(forward_data.keys()):
@@ -71,7 +72,7 @@ def split_json_file(input_file, num_splits, filter_api):
                 **backward_data
             }
         }
-        split_filename = f"temp_part{i}.json"
+        split_filename = os.path.join(input_dir, f"temp_part{i}.json")
         save_json(split_filename, temp_data)
         split_files.append(split_filename)

msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py

@@ -23,12 +23,14 @@ try:
     import torch_npu
 except ImportError:
     is_gpu = True
+    current_device = "cuda"
 else:
     is_gpu = False
+    current_device = "npu"
 import torch
 from tqdm import tqdm
 from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import generate_device_params, get_api_info
-from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, is_unsupported_api
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, is_unsupported_api, ExecParams
 from msprobe.core.common.file_utils import check_link, FileChecker
 from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
 from msprobe.core.common.const import FileCheckConst, Const
@@ -61,19 +63,33 @@ def check_tensor_overflow(x):
         return False


-def check_data_overflow(x):
-    if isinstance(x, (tuple, list))
-
-
-
-        return False
+def check_data_overflow(x, device):
+    if isinstance(x, (tuple, list)):
+        if not x:
+            return False
+        return any(check_data_overflow(item, device) for item in x)
     else:
-
+        if device == Const.CPU_LOWERCASE:
+            return check_tensor_overflow(x)
+        else:
+            return torch_npu.npu.utils.npu_check_overflow(x)
+
+
+def is_bool_output(x):
+    if isinstance(x, (tuple, list)):
+        if not x:
+            return False
+        return any(is_bool_output(item) for item in x)
+    else:
+        return isinstance(x, bool)


 def run_overflow_check(forward_file):
     logger.info("start UT test")
     forward_content, _, real_data_path = parse_json_info_forward_backward(forward_file)
+    if real_data_path:
+        dump_path = os.path.dirname(forward_file)
+        real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
     for api_full_name, api_info_dict in tqdm(forward_content.items()):
         if is_unsupported_api(api_full_name, is_overflow_check=True):
             continue
@@ -87,6 +103,9 @@ def run_overflow_check(forward_file):
         elif "expected scalar type Long" in str(err):
             logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
                            "'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
+        elif "could not create a primitive descriptor for a matmul primitive" in str(err):
+            logger.warning(f"API {api_name} not support matmul primitive in CPU due to pytorch bug, "
+                           "so it will be skipped.")
         else:
             logger.error(f"Run {api_full_name} UT Error: %s" % str(err))

@@ -98,17 +117,26 @@ def run_torch_api(api_full_name, api_info_dict, real_data_path):
     if not need_grad:
         logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
                        % api_full_name)
+    device_info_kwargs = kwargs.get(Const.DEVICE)
+    if device_info_kwargs and device_info_kwargs.get(Const.VALUE):
+        kwargs[Const.DEVICE] = current_device
     npu_args, npu_kwargs = generate_device_params(args, kwargs, False, api_name)
-    if kwargs.get(
-        del kwargs[
-
-
+    if kwargs.get(Const.DEVICE):
+        del kwargs[Const.DEVICE]
+    cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, args, kwargs, False, None)
+    device_exec_params = ExecParams(api_type, api_name, Const.NPU_LOWERCASE, npu_args, npu_kwargs, False, None)
+    out = exec_api(cpu_exec_params)
+    npu_out = exec_api(device_exec_params)
     if out is None and npu_out is None:
         logger.warning("The %s overflow is a normal overflow, out and npu_out is None." % api_full_name)
         return
+    if is_bool_output(out) or is_bool_output(npu_out):
+        logger.warning("The output of %s is bool type.This dtype not support overflow, so it will be skipped."
+                       % api_full_name)
+        return

-    cpu_overflow = check_data_overflow(out)
-    npu_overflow =
+    cpu_overflow = check_data_overflow(out, Const.CPU_LOWERCASE)
+    npu_overflow = check_data_overflow(npu_out, Const.NPU_LOWERCASE)
     if cpu_overflow == npu_overflow:
         logger.warning("The %s overflow is a normal overflow." % api_full_name)
     else:
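For reference, the reworked `check_data_overflow` now walks nested tuple/list outputs recursively and dispatches per device: the CPU branch reuses `check_tensor_overflow`, the NPU branch calls `torch_npu.npu.utils.npu_check_overflow`. Below is a minimal standalone sketch of the same recursion, CPU-only and independent of msprobe; `has_inf_or_nan` and `check_nested_overflow` are illustrative stand-ins, not package APIs.

```python
import torch

def has_inf_or_nan(t):
    # Stand-in for the CPU branch: flag a tensor whose values overflowed.
    if not isinstance(t, torch.Tensor) or not t.is_floating_point():
        return False
    return bool(torch.isinf(t).any() or torch.isnan(t).any())

def check_nested_overflow(x):
    # Same recursion shape as the new check_data_overflow: descend into
    # tuples/lists and report True if any leaf tensor overflowed.
    if isinstance(x, (tuple, list)):
        if not x:
            return False
        return any(check_nested_overflow(item) for item in x)
    return has_inf_or_nan(x)

outputs = (torch.tensor([1.0, 2.0]), [torch.tensor([float("inf")]), torch.tensor([3.0])])
print(check_nested_overflow(outputs))  # True: one leaf contains inf
```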
msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py

@@ -31,6 +31,7 @@ except ImportError:
 else:
     is_gpu = False
     current_device = "npu"
+
 import torch
 from tqdm import tqdm

@@ -48,10 +49,12 @@ from msprobe.core.common.file_utils import FileChecker, change_mode, \
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.pt_config import parse_json_config
 from msprobe.core.common.const import Const, FileCheckConst, CompareConst
-from msprobe.core.common.utils import safe_get_value
+from msprobe.core.common.utils import safe_get_value, CompareException
+from msprobe.pytorch.common.utils import seed_all
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
-from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params, \
+    ExecParams


 current_time = time.strftime("%Y%m%d%H%M%S")
@@ -61,6 +64,7 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"


 not_backward_list = ['repeat_interleave']
+unsupported_backward_list = ['masked_select']


 tqdm_params = {
@@ -237,7 +241,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
     in_fwd_data_list = []
     backward_message = ''
     api_type, api_name = extract_basic_api_segments(api_full_name)
-    args, kwargs,
+    args, kwargs, output_dtype = get_api_info(api_info_dict, api_name, real_data_path)
+    need_grad = check_need_grad(api_info_dict)
     in_fwd_data_list.append(args)
     in_fwd_data_list.append(kwargs)
     need_backward = api_full_name in backward_content
@@ -248,14 +253,30 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
         need_grad = False
         logger.info("%s %s" % (api_full_name, BackwardMessage.NO_BACKWARD_RESULT_MESSAGE))
         backward_message += BackwardMessage.NO_BACKWARD_RESULT_MESSAGE
+    if api_name in unsupported_backward_list:
+        need_grad = False
+        logger.info("%s %s" % (api_full_name, BackwardMessage.UNSUPPORT_API_MESSAGE))
+        backward_message += BackwardMessage.UNSUPPORT_API_MESSAGE
     need_backward = need_backward and need_grad
-
-
-
+
+    device_info_kwargs = kwargs.get(Const.DEVICE)
+    if device_info_kwargs and device_info_kwargs.get(Const.VALUE):
+        kwargs[Const.DEVICE] = current_device
     device_args, device_kwargs = generate_device_params(args, kwargs, need_backward, api_name)
+    if kwargs.get(Const.DEVICE):
+        del kwargs[Const.DEVICE]
+    cpu_params = generate_cpu_params(args, kwargs, need_backward, api_name)
+    cpu_args, cpu_kwargs = cpu_params.cpu_args, cpu_params.cpu_kwargs
+    autocast_dtype, is_autocast = cpu_params.autocast_dtype, cpu_params.is_autocast
+    if not is_autocast and output_dtype:
+        is_autocast = autocast_dtype != output_dtype
+        autocast_dtype = output_dtype
     bench_grad_out, device_grad_out = None, None
-
-
+    cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs, False, autocast_dtype)
+    out = exec_api(cpu_exec_params)
+    device_exec_params = ExecParams(api_type, api_name, current_device, device_args, device_kwargs, is_autocast,
+                                    autocast_dtype)
+    device_out = exec_api(device_exec_params)
     current_path = os.path.dirname(os.path.realpath(__file__))
     ut_setting_path = os.path.join(current_path, "torch_ut_setting.json")
     api_setting_dict = get_json_contents(ut_setting_path)
@@ -273,7 +294,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
         }
         grad = gen_args(backward_args, api_name, func_options)
         grad = safe_get_value(grad, 0, "grad")
-
+        grad_params = generate_cpu_params(grad, {}, False, api_name)
+        bench_grad = grad_params.cpu_args
         bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
         device_grad = grad.clone().detach().to(current_device)
         device_grad_out = run_backward(device_args, device_grad, grad_index, device_out)
@@ -300,13 +322,18 @@ def run_torch_api_online(api_full_name, api_data, backward_content):
     return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)


-def
-    convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
+def check_need_grad(api_info_dict):
     need_grad = True
-    if api_info_dict.get(
+    if api_info_dict.get(Const.INPUT_KWARGS) and "out" in api_info_dict.get(Const.INPUT_KWARGS):
         need_grad = False
-
-
+    return need_grad
+
+
+def get_api_info(api_info_dict, api_name, real_data_path):
+    convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
+    need_grad = check_need_grad(api_info_dict)
+    args, kwargs, output_dtype = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path)
+    return args, kwargs, output_dtype


 def need_to_backward(grad_index, out):
@@ -323,15 +350,25 @@ def run_backward(args, grad, grad_index, out):
         out[grad_index].backward(grad)
     else:
         out.backward(grad)
-
-
-        if isinstance(arg, torch.Tensor):
-            args_grad.append(arg.grad)
-    grad_out = args_grad
+
+    grad_out = extract_tensors_grad(args)

     return grad_out


+def extract_tensors_grad(args, depth=0):
+    if depth > Const.MAX_DEPTH:
+        logger.error("The depth of arg_in is too large, please check the arg_in.")
+        raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
+    grads = []
+    for arg in args:
+        if isinstance(arg, torch.Tensor):
+            grads.append(arg.grad)
+        elif isinstance(arg, (list, tuple)):
+            grads.extend(extract_tensors_grad(arg, depth+1))
+    return grads
+
+
 def initialize_save_error_data(error_data_path):
     create_directory(error_data_path)
     error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
@@ -479,6 +516,10 @@ def run_ut_command(args):

     if not is_gpu:
         torch.npu.set_compile_mode(jit_compile=args.jit_compile)
+        if args.jit_compile:
+            torch.npu.config.allow_internal_format = True
+        else:
+            torch.npu.config.allow_internal_format = False
     used_device = current_device + ":" + str(args.device_id[0])
     try:
         if is_gpu:
@@ -497,6 +538,9 @@ def run_ut_command(args):
                                             ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
         checked_api_info = api_info_file_checker.common_check()
         forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info)
+        if real_data_path:
+            dump_path = os.path.dirname(checked_api_info)
+            real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
         if args.filter_api:
             logger.info("Start filtering the api in the api_info_file.")
             forward_content = preprocess_forward_content(forward_content)
@@ -538,5 +582,6 @@ def run_ut_command(args):


 if __name__ == '__main__':
+    seed_all()
     _run_ut()
     logger.info("UT task completed.")
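The new `extract_tensors_grad` helper replaces the old flat loop over `args`: it walks arbitrarily nested list/tuple arguments, collects `.grad` from every tensor leaf, and refuses to recurse past `Const.MAX_DEPTH`. Below is a standalone sketch of the same pattern; `collect_grads` and the depth limit of 10 are illustrative, not the package's own names or constants.

```python
import torch

MAX_DEPTH = 10  # illustrative; run_ut.py uses Const.MAX_DEPTH

def collect_grads(args, depth=0):
    # Mirrors extract_tensors_grad: descend into nested lists/tuples and
    # gather .grad from every tensor leaf, guarding against deep nesting.
    if depth > MAX_DEPTH:
        raise RecursionError("argument nesting is too deep")
    grads = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            grads.append(arg.grad)
        elif isinstance(arg, (list, tuple)):
            grads.extend(collect_grads(arg, depth + 1))
    return grads

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
(a * b).sum().backward()
print(collect_grads([a, [b, 1.5]]))  # two gradient tensors; the float is skipped
```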
msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py

@@ -16,6 +16,7 @@
 # limitations under the License.

 import os
+from collections import namedtuple
 import re
 import torch

@@ -23,8 +24,10 @@ try:
     import torch_npu
 except ImportError:
     current_device = "cuda"
+    from torch.cuda.amp import autocast
 else:
     current_device = "npu"
+    from torch_npu.npu.amp import autocast

 from msprobe.core.common.const import FileCheckConst, Const, CompareConst
 from msprobe.core.common.file_utils import FileChecker
@@ -47,11 +50,17 @@ PRECISION_MAPPING = {
 }


+CpuParams = namedtuple("CpuArgs", ["cpu_args", "cpu_kwargs", "autocast_dtype", "is_autocast"])
+ExecParams = namedtuple("ExecParams", ["api_type", "api_name", "device", "args", "kwargs",
+                                       "is_autocast", "autocast_dtype"])
+
+
 class BackwardMessage:
     MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
     UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, " \
                                  "skip backward."
     NO_BACKWARD_RESULT_MESSAGE = "This API does not have backward input data, skip backward."
+    UNSUPPORT_API_MESSAGE = "This API does not support backward ut, skip backward."


 class UtDataInfo:
@@ -91,7 +100,15 @@ def get_validated_details_csv_path(validated_result_csv_path):
     return validated_details_csv_path


-def exec_api(
+def exec_api(exec_params):
+    api_type = exec_params.api_type
+    api_name = exec_params.api_name
+    device = exec_params.device
+    args = exec_params.args
+    kwargs = exec_params.kwargs
+    is_autocast = exec_params.is_autocast
+    autocast_dtype = exec_params.autocast_dtype
+
     if api_type == "Functional":
         torch_api = FunctionalOPTemplate(api_name, str, False)
     if api_type == "Tensor":
@@ -102,7 +119,11 @@ def exec_api(api_type, api_name, device, args, kwargs):
         torch_api = AtenOPTemplate(api_name, None, False)
     if api_type == "NPU":
         torch_api = NpuOPTemplate(api_name, None, False, device)
-
+    if is_autocast:
+        with autocast(dtype=autocast_dtype):
+            out = torch_api.forward(*args, **kwargs)
+    else:
+        out = torch_api.forward(*args, **kwargs)
     return out


@@ -196,19 +217,28 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
         return set()

     raise_dtype = None
+    autocast_dtype = None
+    is_autocast = False
     need_raise_dtypes = recursive_find_dtypes(input_args)
     need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
     if len(need_raise_dtypes) == 1:
-
+        origin_dtype = need_raise_dtypes.pop()
+        raise_dtype = PRECISION_MAPPING.get(origin_dtype, torch.float32)
+        autocast_dtype = origin_dtype
+
     elif len(need_raise_dtypes) >= 2:
         raise_dtype = torch.float32
+        need_raise_dtypes.discard(torch.float32)
+        autocast_dtype = need_raise_dtypes.pop()
+        is_autocast = True

     raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
     is_detach = api_name not in not_detach_set
     cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
     cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for
                   key, value in input_kwargs.items()}
-
+    cpu_params = CpuParams(cpu_args, cpu_kwargs, autocast_dtype, is_autocast)
+    return cpu_params


 def record_skip_info(api_full_name, compare, compare_alg_results):
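This refactor replaces `exec_api`'s positional parameters with an `ExecParams` namedtuple and has `generate_cpu_params` return a `CpuParams` record carrying the autocast decision. A minimal sketch of the same bundling pattern, assuming CPU autocast as a stand-in for the `torch.cuda.amp` / `torch_npu.npu.amp` autocast imported above; `exec_op` and the relu example are illustrative, not msprobe APIs.

```python
from collections import namedtuple
import torch

# Same bundling idea as the new ExecParams/CpuParams namedtuples: one immutable
# record instead of a long positional argument list.
ExecParams = namedtuple("ExecParams", ["api_type", "api_name", "device", "args", "kwargs",
                                       "is_autocast", "autocast_dtype"])

def exec_op(params):
    # Illustrative executor: dispatch a torch.nn.functional op by name and,
    # when requested, run it under CPU autocast at the recorded dtype.
    fn = getattr(torch.nn.functional, params.api_name)
    if params.is_autocast:
        with torch.autocast(device_type="cpu", dtype=params.autocast_dtype):
            return fn(*params.args, **params.kwargs)
    return fn(*params.args, **params.kwargs)

params = ExecParams("Functional", "relu", "cpu", (torch.randn(4),), {}, True, torch.bfloat16)
print(exec_op(params))
```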
msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py

@@ -24,7 +24,7 @@ from msprobe.core.common.const import Const, CompareConst
 from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import online_api_precision_compare
 from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import DETAIL_TEST_ROWS, thousandth_standard_api, \
     binary_standard_api, absolute_standard_api
-from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api, ExecParams
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device
 from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params
@@ -92,8 +92,10 @@ def online_precision_compare(api_data, device, common_config, api_precision_csv_

     try:
         # NPU vs CPU
-
-
+        cpu_params = generate_cpu_params(npu_args, npu_kwargs, False, api_name)
+        cpu_args, cpu_kwargs = cpu_params.cpu_args, cpu_params.cpu_kwargs
+        cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs, False, None)
+        cpu_out = exec_api(cpu_exec_params)
         npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank)
         npu_detail = compare.compare_output(api_full_name, npu_data_info, True)
         npu_data = pd.DataFrame(npu_detail, columns=DETAIL_TEST_ROWS[-1])
msprobe/pytorch/bench_functions/apply_adam.py (new file)

@@ -0,0 +1,215 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import namedtuple
+import torch
+
+
+VarParams = namedtuple('VarParams', ['var', 'lr_t', 'm_t', 'beta1_broad', 'grad', 'epsilon', 'v_t'])
+
+
+def _output_m_compute(m, beta1_broad, grad):
+    """
+    _output_m_compute
+    do compute m_t = m + (beta1 - 1) * (m - grad)
+    """
+    input_dtype = m.dtype
+
+    sneg_one = torch.ones((1), dtype=input_dtype) * -1
+    sneg_one = sneg_one.to(beta1_broad.device)
+
+    # `formula; beta1 -1`
+    vsub_beta1_1 = torch.add(beta1_broad, sneg_one)
+
+    # `formula; m - grad`
+    vsub_m_grad = torch.sub(m, grad)
+
+    # `formula; (beta1 - 1) * (m - grad)`
+    vmul_m = torch.mul(vsub_beta1_1, vsub_m_grad)
+
+    # `formula; m_t = m + (beta1 - 1) * (m - grad)`
+    m_t = torch.add(m, vmul_m)
+
+    return m_t
+
+
+def _output_v_compute(v, beta2, grad):
+    """
+    _output_v_compute
+    do compute v_t = v + (1 - beta2)*(grad*grad -v)
+    """
+    input_dtype = v.dtype
+
+    sneg_one = torch.ones((1), dtype=input_dtype) * -1
+
+    # `formula; broadcast beta2 to vector`
+    beta2_tensor = torch.tensor(beta2, dtype=input_dtype)
+    beta2_broad = beta2_tensor.expand_as(v)
+
+    # `formula; beta2 - 1`
+    vsub_beta2_1 = torch.add(beta2_broad, sneg_one)
+    vsub_beta2_1 = vsub_beta2_1.to(v.device)
+
+    # `formula; grad * grad`
+    vmul_grad_grad = torch.mul(grad, grad)
+
+    # `formula; (v - grad*grad)`
+    vsub_v_grad = torch.sub(v, vmul_grad_grad)
+
+    # `formula; (beta2 -1) * (v - grad * grad)`
+    vmul_grad = torch.mul(vsub_beta2_1, vsub_v_grad)
+
+    # `formula; v_t = v + (beta2 - 1) * (v - grad * grad)`
+    v_t = torch.add(v, vmul_grad)
+
+    return v_t
+
+
+def _inner_lr_compute(lr, beta2_power, beta1_power, compute_shape_tensor):
+    """
+    _inner_lr_compute
+    `formula; lr_t = learning_rate * (sqrt(1-beta2_power)) / (1 - beta1_power)`
+    """
+
+    input_dtype = compute_shape_tensor.dtype
+
+    s_one = torch.ones((1), dtype=input_dtype)
+
+    s_neg_one = torch.ones((1), dtype=input_dtype) * -1
+
+    # `formula; (1 - beta2_power)`
+    v_neg_beta2_power = torch.mul(beta2_power, s_neg_one)
+    v_add_beta2_power = torch.add(v_neg_beta2_power, s_one)
+
+    # `formula; sqrt(1 - beta2_power)`
+    v_sqrt_beta2_power = torch.sqrt(v_add_beta2_power)
+
+    # `formula; (1 - beta1_power)`
+    v_neg_beta1_power = torch.mul(beta1_power, s_neg_one)
+    v_add_beta1_power = torch.add(v_neg_beta1_power, s_one)
+
+    # `formula; learning_rate * (sqrt(1-beta2_power)`
+    res = torch.mul(lr, v_sqrt_beta2_power)
+
+    # `formula; learning_rate*(sqrt(1-beta2_power))/(1-beta1_power)`
+    res = torch.div(res, v_add_beta1_power)
+    return res.expand_as(compute_shape_tensor)
+
+
+def _inner_eps_add_sqrt_vt_compute(epsilon, v_t):
+    """
+    (epsilon + sqrt(v_t) )
+    """
+    # `formula; sqrt(v_t)`
+    sqrt_vt = torch.sqrt(v_t)
+
+    # `formula; broadcast epsilon to vector`
+    input_dtype = v_t.dtype
+    epsilon_tensor = torch.tensor(epsilon, dtype=input_dtype)
+    epsilon_broad = epsilon_tensor.expand_as(v_t)
+    epsilon_broad = epsilon_broad.to(sqrt_vt.device)
+
+    # `formula; epsilon + sqrt(v_t)`
+    v_add_sqrt_v = torch.add(sqrt_vt, epsilon_broad)
+
+    return v_add_sqrt_v
+
+
+def _output_var_t_compute_use_nesterov(varparams):
+    """
+    _output_var_t_compute_use_nesterov
+    `formula; var_t = var - lr_t * (m_t * beta1 + (1 - beta1) * grad) / (epsilon + sqrt(v_t))`
+    `formula; var_t = var - lr_t * (m_t * beta1 + (1 - beta1) * grad) / (epsilon + sqrt(v_t))`
+    """
+    var = varparams.var
+    lr_t = varparams.lr_t
+    m_t = varparams.m_t
+    beta1_broad = varparams.beta1_broad
+    grad = varparams.grad
+    epsilon = varparams.epsilon
+    v_t = varparams.v_t
+
+    input_dtype = var.dtype
+
+    s_one = torch.ones((1), dtype=input_dtype)
+
+    s_neg_one = torch.ones((1), dtype=input_dtype) * -1
+
+    # `formula; m_t * beta1`
+    v_muls_mt_beta1 = torch.mul(m_t, beta1_broad)
+
+    # `formula; 1 -beta1`
+    v_neg_beta1 = torch.mul(beta1_broad, s_neg_one)
+    vsub_1_beta1 = torch.add(v_neg_beta1, s_one)
+
+    # `formula; (1-beta1)* grad`
+    v_mul_grad = torch.mul(vsub_1_beta1, grad)
+
+    # `formula; (m_t*beta1 + (1 - beta1)*grad)`
+    v_div_left = torch.add(v_muls_mt_beta1, v_mul_grad)
+
+    # `formula; lr_t * (m_t*beta1 + (1 - beta1) * grad)`
+    # broadcast lr_t to vector
+
+    lrt_broad = lr_t.expand_as(var)
+    v_mul_left = torch.mul(lrt_broad, v_div_left)
+
+    # `formula; (epsilon + sqrt(v_t))`
+    v_add_sqrt_v = _inner_eps_add_sqrt_vt_compute(epsilon, v_t)
+
+    # `formula; lr_t * (m_t*beta1 + (1-beta1)*grad / (epsilon + sqrt(v_t))`
+    v_div_res = torch.div(v_mul_left, v_add_sqrt_v)
+
+    # `formula; var - lr_t * (m_t*beta1 + (1-beta1)*grad) / (epsilon + sqrt(v_t))`
+    v_t = torch.sub(var, v_div_res)
+
+    return v_t
+
+
+def _output_var_t_compute(var, lr_t, m_t, epsilon, v_t):
+    """
+    _output_var_t_compute
+    `var_t = var - lr_t * m_t / (epsilon + sqrt(v_t))`
+    """
+    # `formula; lr_t * m_t`
+    lr_t = lr_t.to(m_t.device)
+    v_mul_left = torch.mul(lr_t, m_t)
+
+    # `formula; (epsilon + sqrt(v_t))`
+    v_add_sqrt_v = _inner_eps_add_sqrt_vt_compute(epsilon, v_t)
+
+    # `formula; lr_t * m_t /(epsilon + sqrt(v_t))`
+    v_div_res = torch.div(v_mul_left, v_add_sqrt_v)
+
+    # `formula; var - lr_t * m_t / (epsilon + sqrt(v_t))`
+    v_t = torch.sub(var, v_div_res)
+
+    return v_t
+
+
+def npu_apply_adam(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, use_locking, use_nesterov, out):
+    var, m, v = out
+    input_dtype = m.dtype
+    beta1_tensor = torch.tensor(beta1, dtype=input_dtype).to(m.device)
+    beta1_broad = beta1_tensor.expand_as(m)
+    m_t = _output_m_compute(m, beta1_broad, grad)
+    v_t = _output_v_compute(v, beta2, grad)
+    lr_t = _inner_lr_compute(lr, beta2_power, beta1_power, grad)
+    if use_nesterov:
+        var_params = VarParams(var, lr_t, m_t, beta1_broad, grad, epsilon, v_t)
+        var_t = _output_var_t_compute_use_nesterov(var_params)
+    else:
+        var_t = _output_var_t_compute(var, lr_t, m_t, epsilon, v_t)
+    return var_t, m_t, v_t
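The docstrings above spell out the update rules this benchmark function reproduces. In conventional notation, with g the gradient and beta1_power/beta2_power written as accumulated powers of the betas, they read:

```latex
\begin{aligned}
m_t &= m + (\beta_1 - 1)(m - g) = \beta_1 m + (1 - \beta_1)\, g \\
v_t &= v + (1 - \beta_2)(g^2 - v) = \beta_2 v + (1 - \beta_2)\, g^2 \\
\mathrm{lr}_t &= \mathrm{lr} \cdot \frac{\sqrt{1 - \beta_2^{\,\mathrm{power}}}}{1 - \beta_1^{\,\mathrm{power}}} \\
\mathrm{var}_t &=
\begin{cases}
\mathrm{var} - \mathrm{lr}_t \cdot m_t \,/\, (\epsilon + \sqrt{v_t}) & \text{default} \\
\mathrm{var} - \mathrm{lr}_t \cdot \bigl(\beta_1 m_t + (1 - \beta_1)\, g\bigr) \,/\, (\epsilon + \sqrt{v_t}) & \text{use\_nesterov}
\end{cases}
\end{aligned}
```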
msprobe/pytorch/bench_functions/group_norm_silu.py (new file)

@@ -0,0 +1,27 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def npu_group_norm_silu(x, gama, beta, group, eps):
+    if len(x.shape) != 4:
+        raise ValueError("x shape should be (N, C, H, W)")
+    res = torch.ops.aten.native_group_norm(x, gama, beta, x.shape[0], x.shape[1], x.shape[2] * x.shape[3], group, eps)
+    res = list(res)
+    if not res:
+        raise ValueError("run native_group_norm failed")
+    res[0] = torch.nn.functional.silu(res[0])
+    return res
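`npu_group_norm_silu` composes `aten.native_group_norm` with a SiLU on the normalized output. A quick sanity-check sketch of that composition against `torch.nn.functional.group_norm` followed by `silu`; the input shape and parameters here are arbitrary illustrations, not taken from the package.

```python
import torch
import torch.nn.functional as F

# Illustrative check (not part of the package): native_group_norm + silu
# should match group_norm followed by silu on a 4-D input.
x = torch.randn(2, 6, 4, 4)
gamma, beta = torch.ones(6), torch.zeros(6)
group, eps = 3, 1e-5

out = torch.ops.aten.native_group_norm(
    x, gamma, beta, x.shape[0], x.shape[1], x.shape[2] * x.shape[3], group, eps)
expected = F.silu(F.group_norm(x, group, gamma, beta, eps))
print(torch.allclose(F.silu(out[0]), expected, atol=1e-6))  # True
```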
msprobe/pytorch/bench_functions/mish.py (new file)

@@ -0,0 +1,21 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def npu_mish(x):
+    mish = torch.nn.Mish()
+    return mish(x)
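`npu_mish` simply wraps `torch.nn.Mish`, which computes Mish(x) = x * tanh(softplus(x)). A short illustrative check of that identity:

```python
import torch
import torch.nn.functional as F

# Mish(x) = x * tanh(softplus(x)); torch.nn.Mish computes exactly this.
x = torch.randn(5)
print(torch.allclose(torch.nn.Mish()(x), x * torch.tanh(F.softplus(x))))  # True
```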