mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py

```diff
@@ -14,8 +14,10 @@
 # limitations under the License.
 
 import os
+from dataclasses import dataclass
+from typing import Any, Optional
 from tqdm import tqdm
-
+import numpy as np
 from msprobe.core.common.const import Const, CompareConst
 from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, load_json, load_yaml
 from msprobe.core.common.utils import add_time_as_suffix
@@ -28,6 +30,9 @@ from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_jso
 from msprobe.mindspore.common.const import MsCompareConst
 from msprobe.mindspore.common.log import logger
 from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer
+from msprobe.core.data_dump.data_collector import build_data_collector
+from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
+from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
 
 cur_path = os.path.dirname(os.path.realpath(__file__))
 yaml_path = os.path.join(cur_path, MsCompareConst.SUPPORTED_API_LIST_FILE)
@@ -59,13 +64,129 @@ class ProcessResultPacket:
         self.err_msg = err_msg
 
 
+@dataclass
+class Config:
+    execution_mode: str
+    dump_path: str
+    task: str
+    level: str
+    scope: Optional[Any]
+    list: Optional[Any]
+    framework: str
+    data_mode: str
+    file_format: str
+    dump_tensor_data_dir: str
+    async_dump: bool
+    summary_mode: Optional[Any] = None
+
+
 class ApiAccuracyChecker:
     def __init__(self, args):
         self.api_infos = dict()
         self.data_manager = DataManager(args.out_path, args.result_csv_path)  # instantiate DataManager at initialization
+        self.save_error_data = args.save_error_data
+        if self.save_error_data:
+            config, dump_path_aggregation = self.init_save_error_data(args)
+            self.data_collector = build_data_collector(config)
+            self.data_collector.update_dump_paths(dump_path_aggregation)
 
     @staticmethod
-    def
+    def init_save_error_data(args):
+        config = Config(
+            execution_mode="pynative",
+            dump_path=f"{args.out_path}",
+            dump_tensor_data_dir=f"{args.out_path}",
+            task="tensor",  # task type: simulates saving tensor data
+            level="L1",  # level
+            scope=None,  # scope (None)
+            list=None,  # API list (None)
+            framework=Const.MS_FRAMEWORK,  # framework type
+            data_mode="all",
+            file_format="npy",
+            async_dump=False
+        )
+
+        dump_dir = f"{args.out_path}"
+        dump_data_dir = os.path.join(dump_dir, "error_data")
+        create_directory(dump_data_dir)
+        dump_path_aggregation = DumpPathAggregation()
+        dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
+        dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
+        dump_path_aggregation.dump_error_info_path = os.path.join(dump_dir, "dump_error_info.log")
+        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
+        return config, dump_path_aggregation
+
+    @staticmethod
+    def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD):
+        """
+        Args:
+            api_info: ApiInfo
+            forward_or_backward: str
+        Returns:
+            ApiInputAggregation
+        """
+        forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
+        kwargs = api_info.get_kwargs()
+        if forward_or_backward == Const.FORWARD:
+            gradient_inputs = None
+        else:
+            gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
+        return ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
+
+    @staticmethod
+    def is_api_checkable(api_name_str):
+        '''
+        Args:
+            api_name_str: str, e.g. "MintFunctional.relu.0.forward", key in data field of api_info.json
+        Returns:
+            is_checkable: bool
+        Description:
+            tell whether this api is checkable based on the key in "data" dict in api_info.json
+        '''
+        api_name_str_list = api_name_str.split(Const.SEP)
+        if len(api_name_str_list) < MsCompareConst.API_NAME_STR_LENGTH:
+            return False
+        api_type_str = api_name_str_list[0]
+        real_api_str = Const.SEP.join(api_name_str_list[1:-2])
+        api_list = load_yaml(yaml_path)
+        supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY)
+        supported_fusion_api_list = MsCompareConst.SUPPORTED_FUSION_LIST
+        if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL) \
+                and global_context.get_framework() == Const.MS_FRAMEWORK:
+            return True
+        if api_type_str in MsCompareConst.MT_VALID_API_TYPES \
+                and global_context.get_framework() == Const.MT_FRAMEWORK:
+            return True
+        if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list \
+                and global_context.get_framework() == Const.MS_FRAMEWORK:
+            return True
+        if api_type_str == MsCompareConst.FUNCTIONAL_API and real_api_str in supported_fusion_api_list \
+                and global_context.get_framework() == Const.MS_FRAMEWORK:
+            return True
+        return False
+
+    def post_forward_hook(self, api_or_module_name, primitive_instance, args, kwargs, output):
+        self.data_collector.update_api_or_module_name(api_or_module_name)
+        module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
+        self.data_collector.forward_data_collect_only_tensor(
+            api_or_module_name,
+            primitive_instance,
+            os.getpid(),
+            module_input_output
+        )
+
+    def backward_hook(self, api_or_module_name, module, grad_input, grad_output):
+        self.data_collector.update_api_or_module_name(api_or_module_name)
+
+        module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
+        self.data_collector.backward_data_collect_only_tensor(
+            api_or_module_name,
+            module,
+            os.getpid(),
+            module_input_output
+        )
+
+    def run_and_compare_helper(self, api_info, api_name_str, api_input_aggregation, forward_or_backward):
         """
         Args:
             api_info: ApiInfo
@@ -83,13 +204,22 @@ class ApiAccuracyChecker:
         """
         # get output
        if global_context.get_is_constructed():
-
-
-
+            if forward_or_backward == Const.FORWARD:
+                tested_outputs, inputs, kwargs, forward_result_tuple = api_runner(api_input_aggregation, api_name_str,
+                                                                                  forward_or_backward,
+                                                                                  global_context.get_framework())
+            elif forward_or_backward == Const.BACKWARD:
+                tested_outputs, gradient_inputs, backward_result_tuple = api_runner(api_input_aggregation, api_name_str,
+                                                                                    forward_or_backward,
+                                                                                    global_context.get_framework())
+            else:
+                tested_outputs = api_runner(api_input_aggregation, api_name_str,
+                                            forward_or_backward, global_context.get_framework())
        else:
            tested_outputs = api_info.get_compute_element_list(forward_or_backward, Const.OUTPUT)
 
        bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)
+
        tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward)
        bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward)
        if len(tested_outputs) != len(bench_outputs):
@@ -114,64 +244,26 @@ class ApiAccuracyChecker:
                compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
            status = CompareConst.PASS
            err_msg = ""
+
        else:
            status = CompareConst.ERROR
            err_msg = (compare_result_dict.get(CompareConst.COSINE).err_msg +
                       compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg)
+        if forward_or_backward == Const.FORWARD and self.save_error_data \
+                and global_context.get_is_constructed():
+            api_name_str_backward = f"{api_name_str}{Const.SEP}{Const.FORWARD}"
+            self.post_forward_hook(api_name_str_backward, None, inputs, kwargs, forward_result_tuple)
+
+        if forward_or_backward == Const.BACKWARD and self.save_error_data \
+                and global_context.get_is_constructed():
+            api_name_str_backward = f"{api_name_str}{Const.SEP}{Const.BACKWARD}"
+            self.backward_hook(api_name_str_backward, None, gradient_inputs, backward_result_tuple)
+
        basic_info_status = \
            BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
        output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
        return output_list
 
-    @staticmethod
-    def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD):
-        """
-        Args:
-            api_info: ApiInfo
-            forward_or_backward: str
-        Returns:
-            ApiInputAggregation
-        """
-        forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
-        kwargs = api_info.get_kwargs()
-        if forward_or_backward == Const.FORWARD:
-            gradient_inputs = None
-        else:
-            gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
-        return ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
-
-    @staticmethod
-    def is_api_checkable(api_name_str):
-        '''
-        Args:
-            api_name_str: str, e.g. "MintFunctional.relu.0.forward", key in data field of api_info.json
-        Returns:
-            is_checkable: bool
-        Description:
-            tell whether this api is checkable based on the key in "data" dict in api_info.json
-        '''
-        api_name_str_list = api_name_str.split(Const.SEP)
-        if len(api_name_str_list) < MsCompareConst.API_NAME_STR_LENGTH:
-            return False
-        api_type_str = api_name_str_list[0]
-        real_api_str = Const.SEP.join(api_name_str_list[1:-2])
-        api_list = load_yaml(yaml_path)
-        supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY)
-        supported_fusion_api_list = MsCompareConst.SUPPORTED_FUSION_LIST
-        if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL) \
-                and global_context.get_framework() == Const.MS_FRAMEWORK:
-            return True
-        if api_type_str in MsCompareConst.MT_VALID_API_TYPES \
-                and global_context.get_framework() == Const.MT_FRAMEWORK:
-            return True
-        if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list \
-                and global_context.get_framework() == Const.MS_FRAMEWORK:
-            return True
-        if api_type_str == MsCompareConst.FUNCTIONAL_API and real_api_str in supported_fusion_api_list \
-                and global_context.get_framework() == Const.MS_FRAMEWORK:
-            return True
-        return False
-
     def parse(self, api_info_path):
 
         api_info_dict = load_json(api_info_path)
@@ -183,9 +275,9 @@ class ApiAccuracyChecker:
                                                        MsCompareConst.TENSOR_TASK))
        try:
            framework = check_and_get_from_json_dict(api_info_dict, MsCompareConst.FRAMEWORK,
-
-
-
+                                                     "framework field in api_info.json", accepted_type=str,
+                                                     accepted_value=(Const.MS_FRAMEWORK,
+                                                                     Const.MT_FRAMEWORK))
        except Exception as e:
            framework = Const.MS_FRAMEWORK
            logger.warning(f"JSON parsing error in framework field: {e}")
@@ -301,4 +393,4 @@ class ApiAccuracyChecker:
        elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
            self.data_manager.record_exception_skip(api_name_str, Const.BACKWARD, process_result_packet.err_msg)
 
-        self.data_manager.save_results(api_name_str)
+        self.data_manager.save_results(api_name_str)
```
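The `is_api_checkable` hunk above keys off names like "MintFunctional.relu.0.forward". As a rough standalone sketch of that decomposition (assuming `Const.SEP` is "." and `MsCompareConst.API_NAME_STR_LENGTH` is 4, neither of which is shown in this diff):

```python
# Hypothetical sketch of how is_api_checkable splits an api_info.json key.
# Assumes Const.SEP == "." and MsCompareConst.API_NAME_STR_LENGTH == 4.
def split_api_key(api_name_str: str):
    parts = api_name_str.split(".")
    if len(parts) < 4:                 # mirrors the length guard in the hunk
        return None
    api_type = parts[0]                # e.g. "MintFunctional", "Mint", "Tensor", "Functional"
    real_api = ".".join(parts[1:-2])   # e.g. "relu"
    return api_type, real_api

print(split_api_key("MintFunctional.relu.0.forward"))  # ('MintFunctional', 'relu')
```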
msprobe/mindspore/api_accuracy_checker/api_runner.py

```diff
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import numpy as np
 import mindspore
 from mindspore import ops
 from msprobe.core.common.const import Const
@@ -38,7 +40,6 @@ else:
     import torch
 
 
-
 class ApiInputAggregation:
     def __init__(self, inputs, kwargs, gradient_inputs) -> None:
         """
@@ -148,13 +149,13 @@ class ApiRunner:
         Args:
             api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Functional"]
             api_sub_name: str, e.g. "relu"
-            api_platform: str: Union["
+            api_platform: str: Union["mindspore", "pytorch"]
 
         Return:
             api_instance: function object
 
         Description:
-            get mindspore.mint/torch api
+            get mindspore.mint/torch api function
             mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
             mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
         """
@@ -189,6 +190,8 @@ class ApiRunner:
            forward_result = api_instance(*inputs, **kwargs)  # can be single tensor or tuple
            forward_result_tuple = convert_to_tuple(forward_result)
            res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple]
+            if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK:
+                return res_compute_element_list, inputs, kwargs, forward_result_tuple
        else:
            if gradient_inputs is None:
                err_msg = f"ApiRunner.run_api failed: run backward api but gradient_inputs is missing"
@@ -206,6 +209,7 @@ class ApiRunner:
                backward_result = grad_func(*inputs, gradient_inputs)  # can be single tensor or tuple
                backward_result_tuple = convert_to_tuple(backward_result)
                res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple]
+                return res_compute_element_list, gradient_inputs, backward_result_tuple
            else:
                # set requires_grad
                requires_grad_index = []
```
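The forward and backward branches above now return the inputs, kwargs, and raw result tuples alongside the compute-element list, and the raw result is first normalised with `convert_to_tuple`. A minimal stand-in for that helper (the real implementation lives elsewhere in the package and may differ) behaves roughly like this:

```python
# Stand-in for convert_to_tuple as the hunk uses it: wrap a single result so the
# list comprehension over forward_result_tuple always iterates over a tuple.
def convert_to_tuple(result):
    return result if isinstance(result, tuple) else (result,)

print(convert_to_tuple(3.0))     # (3.0,)
print(convert_to_tuple((1, 2)))  # (1, 2)
```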
msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py

```diff
@@ -52,8 +52,14 @@ def softmax_grad(dp, softmax_res):
 
 
 def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
+    # check dimensions
+    if kv_tensor.dim() != 4:
+        raise ValueError(f"broadcast_kv: kv_tensor must be 4-dimensional (B, N_kv, S, D), but got {kv_tensor.shape}")
     if num_kv_heads == 0 or num_kv_heads > num_heads:
-        raise ValueError(
+        raise ValueError("broadcast_kv: num_kv_heads must be greater than 0 and must not exceed num_heads.")
+    if num_heads % num_kv_heads != 0:
+        raise ValueError(f"broadcast_kv: num_heads({num_heads}) must be divisible by num_kv_heads({num_kv_heads}).")
+
 
     factor = num_heads // num_kv_heads
     kv_shape = kv_tensor.shape
@@ -68,6 +74,13 @@ def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
 
 
 def calculate_qk(q, k, attn_mask, pse, scalar_value):
+    # basic shape checks
+    if q.dim() < 4 or k.dim() < 4:
+        raise ValueError(f"calculate_qk: q and k must be at least 4-dimensional, q={q.dim()}, k={k.dim()}")
+    # check head_dim consistency
+    if q.size(-1) != k.size(-1):
+        raise ValueError(f"calculate_qk: q.head_dim({q.size(-1)}) != k.head_dim({k.size(-1)})")
+
     if k.dim() != 4:
         raise ValueError(f"k tensor dimension must be 4, but got {k.dim()} dimensions (shape: {k.shape})")
 
@@ -95,6 +108,10 @@ def fusion_attention_forward(forward_params):
     scalar_value = forward_params.scalar_value
     keep_prob = forward_params.keep_prob
 
+    # intercept keep_prob == 0 to prevent division by zero
+    if keep_prob == 0:
+        raise ValueError("fusion_attention_forward: keep_prob must not be 0, to avoid a division-by-zero error.")
+
     qk = calculate_qk(q, k, attn_mask, pse, scalar_value)
     softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
     if drop_mask is None or len(drop_mask.shape) == 0:
@@ -115,6 +132,11 @@ def fusion_attention_backward(backward_params):
     pse = backward_params.pse
     scalar_value = backward_params.scalar_value
     keep_prob = backward_params.keep_prob
+
+    # intercept keep_prob == 0 to prevent division by zero
+    if keep_prob == 0:
+        raise ValueError("fusion_attention_backward: keep_prob must not be 0, to avoid a division-by-zero error.")
+
     dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
     if drop_mask is None or len(drop_mask.shape) == 0:
         drop_res = softmax_res.permute(0, 1, 3, 2)
@@ -138,34 +160,45 @@ def parse_bsnd_args(query, key, head_num, input_layout):
 
     if input_layout == "TND":
         raise ValueError(f"input_layout {input_layout} does not supported for now.")
+
+    # prevent head_num from being 0
+    if n1 == 0:
+        raise ValueError("parse_bsnd_args: head_num (n1) must not be 0, to avoid a division-by-zero error.")
+
     try:
         if input_layout == "BSH":
             b, s1, h1 = query.shape
             _, s2, h2 = key.shape
             d = h1 // n1
+            # intercept d == 0
+            if d == 0:
+                raise ValueError("parse_bsnd_args: the computed head_dim d must not be 0.")
             n2 = h2 // d
         elif input_layout == "SBH":
             s1, b, h1 = query.shape
             s2, _, h2 = key.shape
             d = h1 // n1
+            if d == 0:
+                raise ValueError("parse_bsnd_args: the computed head_dim d must not be 0.")
             n2 = h2 // d
         elif input_layout == "BSND":
             b, s1, n1, d = query.shape
             _, s2, n2, _ = key.shape
+            if d == 0:
+                raise ValueError("parse_bsnd_args: head_dim d must not be 0.")
             h1 = n1 * d
             h2 = n2 * d
         elif input_layout == "BNSD":
             b, n1, s1, d = query.shape
             _, n2, s2, _ = key.shape
+            if d == 0:
+                raise ValueError("parse_bsnd_args: head_dim d must not be 0.")
             h1 = n1 * d
             h2 = n2 * d
     except Exception as e:
         raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
 
-
-        raise ValueError(f"Value d must be non-zero.")
-    _dtype = query.dtype
-    ret = (b, s1, s2, n1, n2, d, h1, h2, _dtype)
+    ret = (b, s1, s2, n1, n2, d, h1, h2, query.dtype)
     return ret
 
 
@@ -230,67 +263,6 @@ def convert_to_bnsd(_input, n, input_layout):
     return out.to(GTYPE)
 
 
-def convert_from_bsnd(_input, input_layout):
-    """
-    transform qkv from bsnd to input_layout.
-    B: batch_size
-    S: sequence_length
-    N: num_heads
-    D: head_dim
-    Args:
-        _input (torch.Tensor): tensor of shape (B,S,N,D)
-        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
-    Returns:
-        tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
-    """
-    if input_layout == "BSH":
-        # (B,S,N,D)=>(B,S,N*D)
-        out = rearrange(_input, 'b s n d -> b s (n d)').contiguous()
-    elif input_layout == "SBH":
-        # (B,S,N,D)=>(S,B,N*D)
-        out = rearrange(_input, 'b s n d -> s b (n d)').contiguous()
-    elif input_layout == "BNSD":
-        # (B,S,N,D)=>(B,N,S,D)
-        out = rearrange(_input, 'b s n d -> b n s d').contiguous()
-    elif input_layout == "TND":
-        raise ValueError(f"input_layout {input_layout} does not supported for now.")
-    else:
-        out = _input
-    return out
-
-
-def convert_to_bsnd(_input, n, input_layout):
-    """
-    transform qkv from input_layout to bsnd.
-    B: batch_size
-    S: sequence_length
-    N: num_heads
-    D: head_dim
-    Args:
-        _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
-        n (int): num_heads
-        input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND"
-    Returns:
-        tensor of shape (B,S,N,D)
-    """
-    if input_layout == "BSH":
-        # (B,S,N*D)=>(B,S,N,D)
-        out = rearrange(_input, 'b s (n d) -> b s n d', n=n)
-    elif input_layout == "SBH":
-        # (S,B,N*D)=>(B,S,N,D)
-        out = rearrange(_input, 's b (n d) -> b s n d', n=n)
-    elif input_layout == "BNSD":
-        # (B,N,S,D)=>(B,S,N,D)
-        out = rearrange(_input, 'b n s d -> b s n d', n=n)
-    elif input_layout == "TND":
-        raise ValueError(f"input_layout {input_layout} does not supported for now.")
-    else:
-        out = _input
-    if out.dim() != 4:
-        raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
-    return out
-
-
 def generate_attn_mask(*args):
     """
     # when sparse_mode is 2, 3 or 4, the small-operator-to-fused-operator path takes this optimization; seen the other way round, it has to be decomposed back into the original basic implementation
@@ -417,17 +389,20 @@ def get_input_layout(*args, **kwargs):
 
 def npu_fusion_attention_forward_patch(*args, **kwargs):
     if len(args) < 2:
-        raise RuntimeError("npu_fusion_attention_forward_patch: length of args should greater than or equal to 2.")
+        raise RuntimeError("npu_fusion_attention_forward_patch: length of args should be greater than or equal to 2.")
 
     # query, key, value, head_num, input_layout
     head_num = get_head_num(*args, **kwargs)
     input_layout = get_input_layout(*args, **kwargs)
 
     b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout)
+    # d has already been checked to be non-zero in parse_bsnd_args
     if n1 == n2 and s1 == s2:
         logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
     else:
         logger.debug(f"running case: BNSD = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    if n2 == 0:
+        raise ValueError("n2 must not be 0, to avoid a division-by-zero error.")
     if not (n1 % n2 == 0 and n1 >= n2):
         raise ValueError(f"N1 and N2 do not match, please check: n1 = {n1}, n2 = {n2}.")
 
@@ -436,7 +411,7 @@ def npu_fusion_attention_forward_patch(*args, **kwargs):
         "d": d, "h1": h1, "h2": h2, "dtype": dtype
     }
     new_kwargs = {
-        "keep_prob": 1,
+        "keep_prob": 1,  # note: if an external keep_prob of 0 is passed in, it is also caught in fusion_attention_forward
         "scalar_value": kwargs.get("scalar_value", 1 / (d ** 0.5)),
         "sparse_mode": kwargs.get("sparse_mode", 0),
         "prefix": kwargs.get("prefix"),
@@ -455,10 +430,13 @@ def npu_fusion_attention_backward_patch(*args, **kwargs):
         raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
 
     b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
+    # d has already been checked to be non-zero in parse_bsnd_args
     if n1 == n2 and s1 == s2:
         logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
     else:
         logger.info(f"running case: bnsd = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    if n2 == 0:
+        raise ValueError("n2 must not be 0, to avoid a division-by-zero error.")
     if not (n1 % n2 == 0 and n1 >= n2):
         raise ValueError(f"N1 and N2 do not match, please check: n1 = {n1}, n2 = {n2}.")
 
@@ -468,7 +446,7 @@ def npu_fusion_attention_backward_patch(*args, **kwargs):
     }
 
     new_kwargs = {
-        "keep_prob": 1,
+        "keep_prob": 1,  # as above, fusion_attention_backward intercepts keep_prob == 0
         "scalar_value_value": kwargs.get("scalar_value_value", 1 / (d ** 0.5)),
         "sparse_mode": kwargs.get("sparse_mode", 0),
         "prefix": kwargs.get("prefix"),
```
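The new guards in `parse_bsnd_args` protect the integer divisions `h1 // n1` and `h2 // d`. A toy sketch of the BSH branch (shape names only; not the package implementation) shows the two failure modes the checks intercept:

```python
# Toy version of the BSH-layout parsing guarded above: d = H // head_num,
# so head_num == 0 or H < head_num would otherwise divide by zero or yield d == 0.
def parse_bsh(shape, head_num):
    b, s, h = shape
    if head_num == 0:
        raise ValueError("head_num must not be 0")
    d = h // head_num
    if d == 0:
        raise ValueError("computed head_dim d must not be 0")
    return b, s, head_num, d

print(parse_bsh((2, 128, 1024), 8))  # (2, 128, 8, 128)
# parse_bsh((2, 128, 4), 8) raises: computed head_dim d must not be 0
```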
msprobe/mindspore/api_accuracy_checker/cmd_parser.py

```diff
@@ -39,6 +39,8 @@ def add_api_accuracy_checker_argument(parser):
                         help="<optional> The ut task result out path.")
     parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False,
                         help="<optional> the exit csv for continue")
+    parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
+                        help="<optional> Save compare failed api output.", required=False)
 
 
 def multi_add_api_accuracy_checker_argument(parser):
@@ -49,6 +51,8 @@ def multi_add_api_accuracy_checker_argument(parser):
                         help="<optional> The ut task result out path.")
     parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False,
                         help="<optional> the exit csv for continue")
+    parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
+                        help="<optional> Save compare failed api output.", required=False)
     # the following are multi-threading parameters
     parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int,
                         help="<optional> set device id to run ut, must be unique and in range 0-7",
```
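Both parsers gain the same optional flag. Because it is declared with `action="store_true"`, `save_error_data` defaults to False and flips to True only when the flag is present, as this standalone argparse sketch illustrates:

```python
# Standalone argparse sketch mirroring the -save_error_data flag added above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-save_error_data', dest="save_error_data", action="store_true",
                    help="<optional> Save compare failed api output.", required=False)

print(parser.parse_args([]).save_error_data)                    # False
print(parser.parse_args(['-save_error_data']).save_error_data)  # True
```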
msprobe/mindspore/api_accuracy_checker/data_manager.py

```diff
@@ -188,7 +188,7 @@ class DataManager:
 
     def record_exception_skip(self, api_name, forward_or_backward, err_msg):
         '''
-        record exception_skip
+        record exception_skip information into self.record_exception_skip.
         self.record_exception_skip: dict{str: dict{"forward": str/None, "backward": str/None}}
         string in key is api_name, string in value is err_msg
         '''
@@ -270,7 +270,7 @@ class DataManager:
             entry.backward_pass_status,
             overall_err_msg
         ]
-        # change row if this api has
+        # change row if this api has exception_skip information
         if api_name in self.results_exception_skip:
             if self.results_exception_skip[api_name][Const.FORWARD] is not None:
                 row[1] = CompareConst.SKIP
```