mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
|
@@ -51,8 +51,6 @@ from msprobe.pytorch.pt_config import parse_json_config
|
|
|
51
51
|
from msprobe.core.common.const import Const, FileCheckConst, CompareConst
|
|
52
52
|
from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
|
|
53
53
|
from msprobe.pytorch.common.utils import seed_all
|
|
54
|
-
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
|
|
55
|
-
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
|
|
56
54
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params, \
|
|
57
55
|
ExecParams
|
|
58
56
|
|
|
@@ -90,27 +88,22 @@ seed_all()
|
|
|
90
88
|
|
|
91
89
|
def run_ut(config):
|
|
92
90
|
logger.info("start UT test")
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
else:
|
|
97
|
-
logger.info(f"UT task result will be saved in {config.result_csv_path}")
|
|
98
|
-
logger.info(f"UT task details will be saved in {config.details_csv_path}")
|
|
91
|
+
|
|
92
|
+
logger.info(f"UT task result will be saved in {config.result_csv_path}")
|
|
93
|
+
logger.info(f"UT task details will be saved in {config.details_csv_path}")
|
|
99
94
|
|
|
100
95
|
if config.save_error_data:
|
|
101
96
|
logger.info(f"UT task error_data will be saved in {config.error_data_path}")
|
|
102
97
|
compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
|
|
103
98
|
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
api_name_set = set()
|
|
113
|
-
run_api_offline(config, compare, api_name_set)
|
|
99
|
+
|
|
100
|
+
csv_df = read_csv(config.result_csv_path)
|
|
101
|
+
try:
|
|
102
|
+
api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
|
|
103
|
+
except IndexError:
|
|
104
|
+
logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
|
|
105
|
+
api_name_set = set()
|
|
106
|
+
run_api_offline(config, compare, api_name_set)
|
|
114
107
|
for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
|
|
115
108
|
change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
116
109
|
change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
@@ -164,60 +157,6 @@ def run_api_offline(config, compare, api_name_set):
|
|
|
164
157
|
gc.collect()
|
|
165
158
|
|
|
166
159
|
|
|
167
|
-
def run_api_online(config, compare):
|
|
168
|
-
attl = init_attl(config.online_config)
|
|
169
|
-
dispatcher = ConsumerDispatcher(compare=compare)
|
|
170
|
-
dispatcher.start(handle_func=run_torch_api_online, config=config)
|
|
171
|
-
|
|
172
|
-
def tcp_communication_flow():
|
|
173
|
-
while True:
|
|
174
|
-
api_data = attl.recv()
|
|
175
|
-
if api_data == 'STOP_':
|
|
176
|
-
continue
|
|
177
|
-
if api_data == 'KILL_':
|
|
178
|
-
time.sleep(1)
|
|
179
|
-
logger.info("==========接收到STOP信号==========")
|
|
180
|
-
dispatcher.stop()
|
|
181
|
-
attl.stop_serve()
|
|
182
|
-
time.sleep(1)
|
|
183
|
-
break
|
|
184
|
-
if not isinstance(api_data, ApiData):
|
|
185
|
-
continue
|
|
186
|
-
api_full_name = api_data.name
|
|
187
|
-
_, api_name = extract_basic_api_segments(api_full_name)
|
|
188
|
-
if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
|
|
189
|
-
continue
|
|
190
|
-
if api_data.rank in config.online_config.rank_list:
|
|
191
|
-
dispatcher.update_consume_queue(api_data)
|
|
192
|
-
|
|
193
|
-
def shared_storage_communication_flow():
|
|
194
|
-
flag_num = -1
|
|
195
|
-
while True:
|
|
196
|
-
api_data = attl.download()
|
|
197
|
-
if api_data == "start":
|
|
198
|
-
if flag_num == -1:
|
|
199
|
-
flag_num += 1
|
|
200
|
-
flag_num += 1
|
|
201
|
-
if api_data == "end":
|
|
202
|
-
flag_num -= 1
|
|
203
|
-
if flag_num == 0:
|
|
204
|
-
dispatcher.stop()
|
|
205
|
-
break
|
|
206
|
-
if not isinstance(api_data, ApiData):
|
|
207
|
-
continue
|
|
208
|
-
api_full_name = api_data.name
|
|
209
|
-
_, api_name = extract_basic_api_segments(api_full_name)
|
|
210
|
-
if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
|
|
211
|
-
continue
|
|
212
|
-
if api_data.rank in config.online_config.rank_list:
|
|
213
|
-
dispatcher.update_consume_queue(api_data)
|
|
214
|
-
|
|
215
|
-
if config.online_config.nfs_path:
|
|
216
|
-
shared_storage_communication_flow()
|
|
217
|
-
else:
|
|
218
|
-
tcp_communication_flow()
|
|
219
|
-
|
|
220
|
-
|
|
221
160
|
def blacklist_and_whitelist_filter(api_name, black_list, white_list):
|
|
222
161
|
"""
|
|
223
162
|
run api(api_name) if api_name not in black_list and in white_list.
|
|
@@ -315,21 +254,6 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
315
254
|
return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
|
|
316
255
|
|
|
317
256
|
|
|
318
|
-
def run_torch_api_online(api_full_name, api_data, backward_content):
|
|
319
|
-
in_fwd_data_list = []
|
|
320
|
-
api_type, api_name = extract_basic_api_segments(api_full_name)
|
|
321
|
-
args, kwargs, out = api_data.args, api_data.kwargs, api_data.result
|
|
322
|
-
in_fwd_data_list.append(args)
|
|
323
|
-
in_fwd_data_list.append(kwargs)
|
|
324
|
-
if kwargs.get("device"):
|
|
325
|
-
del kwargs["device"]
|
|
326
|
-
|
|
327
|
-
device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None)
|
|
328
|
-
device_out = exec_api(device_exec_params)
|
|
329
|
-
device_out = move2device_exec(device_out, "cpu")
|
|
330
|
-
return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
|
|
331
|
-
|
|
332
|
-
|
|
333
257
|
def check_need_grad(api_info_dict):
|
|
334
258
|
need_grad = True
|
|
335
259
|
if api_info_dict.get(Const.INPUT_KWARGS) and "out" in api_info_dict.get(Const.INPUT_KWARGS):
|
|
@@ -389,16 +313,6 @@ def initialize_save_error_data(error_data_path):
|
|
|
389
313
|
return error_data_path
|
|
390
314
|
|
|
391
315
|
|
|
392
|
-
def init_attl(config):
|
|
393
|
-
"""config: OnlineConfig"""
|
|
394
|
-
attl = ATTL('gpu', ATTLConfig(is_benchmark_device=True,
|
|
395
|
-
connect_ip=config.host,
|
|
396
|
-
connect_port=config.port,
|
|
397
|
-
nfs_path=config.nfs_path,
|
|
398
|
-
tls_path=config.tls_path))
|
|
399
|
-
return attl
|
|
400
|
-
|
|
401
|
-
|
|
402
316
|
def _run_ut_parser(parser):
|
|
403
317
|
parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
|
|
404
318
|
help="<Optional> The api param tool result file: generate from api param tool, "
|
|
@@ -481,38 +395,6 @@ def _run_ut(parser=None):
|
|
|
481
395
|
_run_ut_parser(parser)
|
|
482
396
|
args = parser.parse_args(sys.argv[1:])
|
|
483
397
|
run_ut_command(args)
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
def checked_online_config(online_config):
|
|
487
|
-
if not online_config.is_online:
|
|
488
|
-
return
|
|
489
|
-
if not isinstance(online_config.is_online, bool):
|
|
490
|
-
raise ValueError("is_online must be bool type")
|
|
491
|
-
# rank_list
|
|
492
|
-
if not isinstance(online_config.rank_list, list):
|
|
493
|
-
raise ValueError("rank_list must be a list")
|
|
494
|
-
if online_config.rank_list and not all(isinstance(rank, int) for rank in online_config.rank_list):
|
|
495
|
-
raise ValueError("All elements in rank_list must be integers")
|
|
496
|
-
|
|
497
|
-
# nfs_path
|
|
498
|
-
if online_config.nfs_path:
|
|
499
|
-
check_file_or_directory_path(online_config.nfs_path, isdir=True)
|
|
500
|
-
return
|
|
501
|
-
# tls_path
|
|
502
|
-
if online_config.tls_path:
|
|
503
|
-
check_file_or_directory_path(online_config.tls_path, isdir=True)
|
|
504
|
-
check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
|
|
505
|
-
check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
|
|
506
|
-
check_file_or_directory_path(os.path.join(online_config.tls_path, "ca.crt"))
|
|
507
|
-
crl_path = os.path.join(online_config.tls_path, "crl.pem")
|
|
508
|
-
if os.path.exists(crl_path):
|
|
509
|
-
check_file_or_directory_path(crl_path)
|
|
510
|
-
|
|
511
|
-
# host and port
|
|
512
|
-
if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
|
|
513
|
-
raise Exception(f"host: {online_config.host} is invalid.")
|
|
514
|
-
if not isinstance(online_config.port, int) or not (0 < online_config.port <= 65535):
|
|
515
|
-
raise Exception(f"port: {online_config.port} is invalid, port range 0-65535.")
|
|
516
398
|
|
|
517
399
|
|
|
518
400
|
def run_ut_command(args):
|
|
@@ -525,7 +407,7 @@ def run_ut_command(args):
|
|
|
525
407
|
else:
|
|
526
408
|
checker_config = CheckerConfig()
|
|
527
409
|
|
|
528
|
-
if not
|
|
410
|
+
if not args.api_info_file:
|
|
529
411
|
logger.error("Please provide api_info_file for offline run ut.")
|
|
530
412
|
raise Exception("Please provide api_info_file for offline run ut.")
|
|
531
413
|
|
|
@@ -588,8 +470,6 @@ def run_ut_command(args):
|
|
|
588
470
|
global UT_ERROR_DATA_DIR
|
|
589
471
|
UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
|
|
590
472
|
error_data_path = initialize_save_error_data(error_data_path)
|
|
591
|
-
online_config = checker_config.get_online_config()
|
|
592
|
-
checked_online_config(online_config)
|
|
593
473
|
config_params = {
|
|
594
474
|
'forward_content': forward_content,
|
|
595
475
|
'backward_content': backward_content,
|
msprobe/pytorch/common/utils.py
CHANGED
|
@@ -150,7 +150,7 @@ def remove_dropout():
|
|
|
150
150
|
F.dropout3d = function_dropout3d
|
|
151
151
|
|
|
152
152
|
|
|
153
|
-
def seed_all(seed=1234, mode=False, rm_dropout=
|
|
153
|
+
def seed_all(seed=1234, mode=False, rm_dropout=False):
|
|
154
154
|
check_seed_all(seed, mode, rm_dropout)
|
|
155
155
|
try:
|
|
156
156
|
random.seed(seed)
|
|
@@ -388,26 +388,6 @@ def load_pkl(pt_path):
|
|
|
388
388
|
return pt
|
|
389
389
|
|
|
390
390
|
|
|
391
|
-
def save_api_data(api_data):
|
|
392
|
-
"""Save data to io stream"""
|
|
393
|
-
try:
|
|
394
|
-
io_buff = io.BytesIO()
|
|
395
|
-
torch.save(api_data, io_buff)
|
|
396
|
-
except Exception as e:
|
|
397
|
-
raise RuntimeError("save api_data to io_buff failed") from e
|
|
398
|
-
return io_buff
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
def load_api_data(api_data_bytes):
|
|
402
|
-
"""Load data from bytes stream"""
|
|
403
|
-
try:
|
|
404
|
-
buffer = io.BytesIO(api_data_bytes)
|
|
405
|
-
buffer = torch.load(buffer, map_location="cpu", weights_only=False)
|
|
406
|
-
except Exception as e:
|
|
407
|
-
raise RuntimeError("load api_data from bytes failed") from e
|
|
408
|
-
return buffer
|
|
409
|
-
|
|
410
|
-
|
|
411
391
|
def is_recomputation():
|
|
412
392
|
"""Check if the current operation is in the re-computation phase.
|
|
413
393
|
|
|
@@ -31,8 +31,16 @@ def compare(input_param, output_path, **kwargs):
|
|
|
31
31
|
raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
|
|
32
32
|
config = setup_comparison(input_param, output_path, **kwargs)
|
|
33
33
|
|
|
34
|
-
|
|
35
|
-
|
|
34
|
+
config_dict = {
|
|
35
|
+
'stack_mode': config.stack_mode,
|
|
36
|
+
'auto_analyze': config.auto_analyze,
|
|
37
|
+
'fuzzy_match': config.fuzzy_match,
|
|
38
|
+
'highlight': config.highlight,
|
|
39
|
+
'dump_mode': config.dump_mode,
|
|
40
|
+
'first_diff_analyze': config.first_diff_analyze,
|
|
41
|
+
'compared_file_type': config.compared_file_type
|
|
42
|
+
}
|
|
43
|
+
mode_config = ModeConfig(**config_dict)
|
|
36
44
|
mapping_config = MappingConfig(data_mapping=config.data_mapping)
|
|
37
45
|
pt_comparator = Comparator(read_real_data, mode_config, mapping_config)
|
|
38
46
|
pt_comparator.compare_core(input_param, output_path, suffix=config.suffix)
|
|
@@ -13,21 +13,9 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import torch
|
|
17
16
|
|
|
18
|
-
from msprobe.pytorch.
|
|
17
|
+
from msprobe.pytorch.compare.distributed_compare import compare_distributed
|
|
19
18
|
|
|
20
19
|
|
|
21
|
-
def
|
|
22
|
-
|
|
23
|
-
all_api_registered = api_register.all_api_registered
|
|
24
|
-
if all_api_registered:
|
|
25
|
-
api_register.restore_all_api()
|
|
26
|
-
result = original_script(*args, **kwargs)
|
|
27
|
-
if all_api_registered:
|
|
28
|
-
api_register.register_all_api()
|
|
29
|
-
return result
|
|
30
|
-
|
|
31
|
-
original_script = torch.jit.script
|
|
32
|
-
api_register = get_api_register()
|
|
33
|
-
torch.jit.script = patched_script
|
|
20
|
+
def pt_diff_analyze(npu_dump_dir, bench_dump_dir, output_path, first_diff_analyze):
|
|
21
|
+
compare_distributed(npu_dump_dir, bench_dump_dir, output_path, first_diff_analyze=first_diff_analyze)
|
msprobe/pytorch/compare/utils.py
CHANGED
|
@@ -35,7 +35,8 @@ def read_pt_data(dir_path, file_name):
|
|
|
35
35
|
data_value = load_pt(data_path, to_cpu=True).detach()
|
|
36
36
|
except RuntimeError as e:
|
|
37
37
|
# 这里捕获 load_pt 中抛出的异常
|
|
38
|
-
|
|
38
|
+
data_path_file_name = os.path.basename(data_path)
|
|
39
|
+
logger.error(f"Failed to load the .pt file at {data_path_file_name}.")
|
|
39
40
|
raise CompareException(CompareException.INVALID_FILE_ERROR) from e
|
|
40
41
|
except AttributeError as e:
|
|
41
42
|
# 这里捕获 detach 方法抛出的异常
|
|
@@ -34,6 +34,7 @@ class DebuggerConfig:
|
|
|
34
34
|
self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
|
|
35
35
|
self.framework = Const.PT_FRAMEWORK
|
|
36
36
|
self.async_dump = common_config.async_dump if common_config.async_dump else False
|
|
37
|
+
self.precision = common_config.precision if common_config.precision else Const.DUMP_PRECISION_LOW
|
|
37
38
|
|
|
38
39
|
if self.task == Const.FREE_BENCHMARK:
|
|
39
40
|
self.fuzz_device = task_config.fuzz_device
|
|
@@ -47,16 +48,6 @@ class DebuggerConfig:
|
|
|
47
48
|
"max_sample": task_config.max_sample
|
|
48
49
|
}
|
|
49
50
|
|
|
50
|
-
self.online_run_ut = False
|
|
51
|
-
if self.task == Const.TENSOR:
|
|
52
|
-
# dump api tensor and collaborate with online run_ut
|
|
53
|
-
self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
|
|
54
|
-
self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
|
|
55
|
-
self.tls_path = task_config.tls_path if task_config.tls_path else ""
|
|
56
|
-
self.host = task_config.host if task_config.host else ""
|
|
57
|
-
self.port = task_config.port if task_config.port else -1
|
|
58
|
-
self.online_run_ut_recompute = task_config.online_run_ut_recompute \
|
|
59
|
-
if isinstance(task_config.online_run_ut_recompute, bool) else False
|
|
60
51
|
|
|
61
52
|
self.check()
|
|
62
53
|
self._check_statistics_config(task_config)
|
|
@@ -65,7 +56,7 @@ class DebuggerConfig:
|
|
|
65
56
|
self.is_backward_kernel_dump = False
|
|
66
57
|
self._check_and_adjust_config_with_l2()
|
|
67
58
|
|
|
68
|
-
def
|
|
59
|
+
def check(self):
|
|
69
60
|
if self.task and self.task not in Const.TASK_LIST:
|
|
70
61
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
71
62
|
f"The task <{self.task}> is not in the {Const.TASK_LIST}.")
|
|
@@ -78,22 +69,26 @@ class DebuggerConfig:
|
|
|
78
69
|
if not isinstance(self.async_dump, bool):
|
|
79
70
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
80
71
|
f"The parameters async_dump should be bool.")
|
|
81
|
-
if self.async_dump and self.task == Const.TENSOR:
|
|
82
|
-
if self.level == Const.LEVEL_DEBUG:
|
|
83
|
-
self.list = [] # async_dump + debug level case ignore list
|
|
84
|
-
if not self.list and self.level != Const.LEVEL_DEBUG:
|
|
85
|
-
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
86
|
-
f"The parameters async_dump is true in tensor task, the parameters list cannot be "
|
|
87
|
-
f"empty.")
|
|
88
72
|
if self.task == Const.STRUCTURE and self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
|
|
89
73
|
logger.warning_on_rank_0(
|
|
90
74
|
f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
|
|
91
75
|
f"If not, the default level is {Const.LEVEL_MIX}."
|
|
92
76
|
)
|
|
93
77
|
self.level = Const.LEVEL_MIX
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
78
|
+
if self.async_dump:
|
|
79
|
+
if self.task == Const.TENSOR:
|
|
80
|
+
if self.level == Const.LEVEL_DEBUG:
|
|
81
|
+
self.list = [] # async_dump + debug level case ignore list
|
|
82
|
+
if not self.list and self.level != Const.LEVEL_DEBUG:
|
|
83
|
+
raise MsprobeException(
|
|
84
|
+
MsprobeException.INVALID_PARAM_ERROR,
|
|
85
|
+
f"The parameters async_dump is true in tensor task, the parameters list cannot be empty."
|
|
86
|
+
)
|
|
87
|
+
if self.summary_mode == Const.MD5:
|
|
88
|
+
raise MsprobeException(
|
|
89
|
+
MsprobeException.INVALID_PARAM_ERROR,
|
|
90
|
+
f"The parameters async_dump is true, the parameters summary_mode cannot be md5."
|
|
91
|
+
)
|
|
97
92
|
return True
|
|
98
93
|
|
|
99
94
|
def check_model(self, instance, start_model, token_range=None):
|
|
@@ -102,7 +97,7 @@ class DebuggerConfig:
|
|
|
102
97
|
if token_range and not instance.model:
|
|
103
98
|
error_info = "The 'model' parameter must be provided when token_range is not None"
|
|
104
99
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, error_info)
|
|
105
|
-
|
|
100
|
+
|
|
106
101
|
if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX] and token_range is None:
|
|
107
102
|
return
|
|
108
103
|
|
|
@@ -123,7 +118,7 @@ class DebuggerConfig:
|
|
|
123
118
|
break
|
|
124
119
|
if error_model is not None:
|
|
125
120
|
error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
|
|
126
|
-
|
|
121
|
+
f"type, currently there is an unsupported {type(error_model)} type.")
|
|
127
122
|
raise MsprobeException(
|
|
128
123
|
MsprobeException.INVALID_PARAM_ERROR, error_info)
|
|
129
124
|
else:
|
|
@@ -24,8 +24,11 @@ from msprobe.pytorch.common.log import logger
|
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
def wrap_setup_backward_hook(func):
|
|
27
|
-
def requires_clone(tensor):
|
|
28
|
-
|
|
27
|
+
def requires_clone(tensor, need_check_leaf=False):
|
|
28
|
+
need_clone = isinstance(tensor, torch.Tensor) and tensor.requires_grad and torch.is_grad_enabled()
|
|
29
|
+
if need_check_leaf:
|
|
30
|
+
need_clone &= tensor.grad_fn is not None
|
|
31
|
+
return need_clone
|
|
29
32
|
|
|
30
33
|
@recursion_depth_decorator("Dump: wrap_setup_backward_hook.parse_tensor", max_depth=Const.DUMP_MAX_DEPTH)
|
|
31
34
|
def parse_tensor(item, tensor_list):
|
|
@@ -39,20 +42,20 @@ def wrap_setup_backward_hook(func):
|
|
|
39
42
|
parse_tensor(value, tensor_list)
|
|
40
43
|
|
|
41
44
|
@recursion_depth_decorator("Dump: wrap_setup_backward_hook.rebuild_args", max_depth=Const.DUMP_MAX_DEPTH)
|
|
42
|
-
def rebuild_args(item, tensor_iter):
|
|
43
|
-
if requires_clone(item):
|
|
45
|
+
def rebuild_args(item, tensor_iter, need_check_leaf=False):
|
|
46
|
+
if requires_clone(item, need_check_leaf):
|
|
44
47
|
result = next(tensor_iter)
|
|
45
48
|
if hasattr(result, "_base") and result._base is not None:
|
|
46
49
|
if torch._C._autograd._get_creation_meta(result) != torch._C._autograd.CreationMeta(0):
|
|
47
50
|
torch._C._autograd._set_creation_meta(result, torch._C._autograd.CreationMeta(0))
|
|
48
|
-
return result
|
|
51
|
+
return result
|
|
49
52
|
if isinstance(item, list):
|
|
50
53
|
for index, value in enumerate(item):
|
|
51
|
-
item[index] = rebuild_args(value, tensor_iter)
|
|
54
|
+
item[index] = rebuild_args(value, tensor_iter, need_check_leaf=True)
|
|
52
55
|
return item
|
|
53
56
|
if isinstance(item, dict):
|
|
54
57
|
for key, value in item.items():
|
|
55
|
-
item[key] = rebuild_args(value, tensor_iter)
|
|
58
|
+
item[key] = rebuild_args(value, tensor_iter, need_check_leaf=True)
|
|
56
59
|
return item
|
|
57
60
|
if isinstance(item, tuple):
|
|
58
61
|
if hasattr(item, '_fields'):
|
|
@@ -21,25 +21,18 @@ import torch
|
|
|
21
21
|
from torch.utils.hooks import BackwardHook, RemovableHandle
|
|
22
22
|
|
|
23
23
|
from msprobe.core.common.const import Const
|
|
24
|
+
from msprobe.core.common.runtime import Runtime
|
|
24
25
|
from msprobe.core.common.utils import ModuleQueue, ThreadSafe
|
|
26
|
+
from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
|
|
25
27
|
from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
|
|
26
28
|
from msprobe.pytorch.common.log import logger
|
|
27
29
|
from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
|
|
28
30
|
from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_output_hook
|
|
29
31
|
|
|
30
32
|
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
def checkpoint_without_early_stop(*args, **kwargs):
|
|
36
|
-
with set_checkpoint_early_stop(False):
|
|
37
|
-
return origin_checkpoint(*args, **kwargs)
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def replace_checkpoint():
    """
    Rebind ``torch.utils.checkpoint.checkpoint`` to
    ``checkpoint_without_early_stop`` when ``torch_version_above_or_equal_2``
    is true; otherwise leave it untouched.
    """
    if not torch_version_above_or_equal_2:
        return
    checkpoint_module = torch.utils.checkpoint
    checkpoint_module.checkpoint = checkpoint_without_early_stop
|
|
33
|
+
# Compare (major, minor) numerically: comparing version strings would order
# e.g. '10.0' before '2.1'.
torch_version_above_or_equal_21 = tuple(
    int(part) for part in torch.__version__.split('+')[0].split('.')[:2]
) >= (2, 1)
if torch_version_above_or_equal_21:
    # Private torch symbol; it is only imported behind this version gate.
    from torch.utils.checkpoint import _StopRecomputationError
|
|
43
36
|
|
|
44
37
|
|
|
45
38
|
def wrap_megatron_deallocate(func):
|
|
@@ -53,6 +46,27 @@ def wrap_megatron_deallocate(func):
|
|
|
53
46
|
return wrapper_func
|
|
54
47
|
|
|
55
48
|
|
|
49
|
+
def wrap_forward_with_hook_safety(module):
    """
    Wrap ``module.forward`` so the forward hook still runs when the forward
    pass is aborted by ``_StopRecomputationError``.

    The wrapper is installed only when ``torch_version_above_or_equal_21`` is
    true, because ``_StopRecomputationError`` is imported only in that case;
    otherwise the module is left unchanged.

    Args:
        module: object whose ``forward`` attribute is replaced in place. It
            must expose ``_forward_hooks``, a mapping whose values are hook
            callables.

    Returns:
        None.

    Raises:
        Nothing itself; the installed wrapper re-raises the original
        ``_StopRecomputationError`` after running the hook.
    """
    if not torch_version_above_or_equal_21:
        return

    from functools import wraps

    original_forward = module.forward

    # ``wraps`` keeps the original name/docstring and sets ``__wrapped__``, so
    # ``inspect.signature(module.forward)`` still reports the real parameters
    # instead of ``(*args, **kwargs)``.
    @wraps(original_forward)
    def wrapped_forward(*args, **kwargs):
        try:
            return original_forward(*args, **kwargs)
        except _StopRecomputationError:
            if module._forward_hooks:
                # The msprobe forward hook is expected to be the first one
                # registered; only that hook is executed here.
                hook_fn = next(iter(module._forward_hooks.values()))
                # forward was interrupted, so there is no output to hand over.
                hook_fn(module, args, kwargs, None)
            # Bare raise re-raises the active exception unchanged.
            raise

    module.forward = wrapped_forward
|
|
68
|
+
|
|
69
|
+
|
|
56
70
|
class ModuleProcesser:
|
|
57
71
|
module_queue = ModuleQueue()
|
|
58
72
|
module_count = {}
|
|
@@ -66,11 +80,12 @@ class ModuleProcesser:
|
|
|
66
80
|
def __init__(self, scope):
|
|
67
81
|
self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
|
|
68
82
|
wrap_setup_input_output_hook()
|
|
69
|
-
replace_checkpoint()
|
|
70
83
|
try:
|
|
71
84
|
from megatron.core.pipeline_parallel import schedules
|
|
72
85
|
origin_func_id = id(schedules.deallocate_output_tensor)
|
|
73
86
|
schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor)
|
|
87
|
+
schedules.forward_step = wrap_megatron_step(schedules.forward_step)
|
|
88
|
+
schedules.backward_step = wrap_megatron_step(schedules.backward_step, is_forward=False)
|
|
74
89
|
for module in list(sys.modules.values()):
|
|
75
90
|
if module.__name__ == 'schedules':
|
|
76
91
|
continue
|
|
@@ -155,6 +170,7 @@ class ModuleProcesser:
|
|
|
155
170
|
f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
|
|
156
171
|
)
|
|
157
172
|
ModuleProcesser.module_with_backward_hook[prefix_name] = True
|
|
173
|
+
wrap_forward_with_hook_safety(module)
|
|
158
174
|
register_forward_pre_hook(module, forward_pre_hook)
|
|
159
175
|
|
|
160
176
|
def build_module_hook(self, module_name, build_data_hook):
|
|
@@ -163,6 +179,9 @@ class ModuleProcesser:
|
|
|
163
179
|
if kwargs is None:
|
|
164
180
|
kwargs = {}
|
|
165
181
|
|
|
182
|
+
if not Runtime.is_running:
|
|
183
|
+
return (args, kwargs) if torch_version_above_or_equal_2 else args
|
|
184
|
+
|
|
166
185
|
if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump:
|
|
167
186
|
return (args, kwargs) if torch_version_above_or_equal_2 else args
|
|
168
187
|
|
|
@@ -243,14 +262,16 @@ class ModuleProcesser:
|
|
|
243
262
|
ModuleProcesser.module_stack[tid] = []
|
|
244
263
|
|
|
245
264
|
if self.module_stack[tid]:
|
|
246
|
-
ModuleProcesser.module_node[full_name] = self.module_stack[tid][-1]
|
|
265
|
+
ModuleProcesser.module_node[full_name] = self.module_stack[tid][-1] if not is_megatron() \
|
|
266
|
+
else [self.module_stack[tid][-1], get_micro_step()]
|
|
247
267
|
else:
|
|
248
268
|
parent_name = ModuleProcesser.module_queue.find_last(full_name)
|
|
249
|
-
ModuleProcesser.module_node[full_name] = parent_name
|
|
269
|
+
ModuleProcesser.module_node[full_name] = parent_name if not is_megatron() \
|
|
270
|
+
else [parent_name, get_micro_step()]
|
|
250
271
|
|
|
251
272
|
ModuleProcesser.module_queue.add_name(full_name)
|
|
252
273
|
ModuleProcesser.module_stack[tid].append(full_name)
|
|
253
|
-
ModuleProcesser.api_parent_node[tid] = full_name
|
|
274
|
+
ModuleProcesser.api_parent_node[tid] = full_name if not is_megatron() else [full_name, get_micro_step()]
|
|
254
275
|
if self.scope:
|
|
255
276
|
self.scope.begin_module(full_name)
|
|
256
277
|
|
|
@@ -258,14 +279,15 @@ class ModuleProcesser:
|
|
|
258
279
|
tid = threading.get_ident()
|
|
259
280
|
if torch_version_above_or_equal_2 or is_forward:
|
|
260
281
|
ModuleProcesser.module_queue.remove_name(full_name)
|
|
261
|
-
ModuleProcesser.api_parent_node[tid] = None
|
|
282
|
+
ModuleProcesser.api_parent_node[tid] = None if not is_megatron() else [None, get_micro_step()]
|
|
262
283
|
if self.module_stack.get(tid):
|
|
263
284
|
ModuleProcesser.module_stack[tid].pop()
|
|
264
285
|
if self.module_stack.get(tid):
|
|
265
|
-
ModuleProcesser.api_parent_node[tid] = ModuleProcesser.module_stack[tid][-1]
|
|
286
|
+
ModuleProcesser.api_parent_node[tid] = ModuleProcesser.module_stack[tid][-1] if not is_megatron() \
|
|
287
|
+
else [ModuleProcesser.module_stack[tid][-1], get_micro_step()]
|
|
266
288
|
if self.scope:
|
|
267
289
|
self.scope.end_module(full_name)
|
|
268
290
|
else:
|
|
269
291
|
if self.scope:
|
|
270
292
|
self.scope.begin_module(full_name)
|
|
271
|
-
ModuleProcesser.api_parent_node[tid] = full_name
|
|
293
|
+
ModuleProcesser.api_parent_node[tid] = full_name if not is_megatron() else [full_name, get_micro_step()]
|
|
@@ -17,8 +17,8 @@ from abc import ABC
|
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
19
|
from msprobe.core.common.const import Const
|
|
20
|
+
from msprobe.core.common.utils import replace_last_occurrence
|
|
20
21
|
from msprobe.pytorch.free_benchmark import logger
|
|
21
|
-
from msprobe.pytorch.free_benchmark.common.constant import CommonField
|
|
22
22
|
from msprobe.pytorch.free_benchmark.common.enums import (
|
|
23
23
|
DeviceType,
|
|
24
24
|
FuzzLevel,
|
|
@@ -37,6 +37,7 @@ from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
|
|
|
37
37
|
|
|
38
38
|
|
|
39
39
|
class FreeBenchmarkCheck(ABC):
|
|
40
|
+
grad_saver_dict = {}
|
|
40
41
|
|
|
41
42
|
def __init__(self, config) -> None:
|
|
42
43
|
super().__init__()
|
|
@@ -68,7 +69,9 @@ class FreeBenchmarkCheck(ABC):
|
|
|
68
69
|
grad_saver.kwargs = kwargs
|
|
69
70
|
grad_saver.register_compare_func_for_inputs(args, data_processor)
|
|
70
71
|
grad_saver.cache_backward_input(args)
|
|
71
|
-
|
|
72
|
+
|
|
73
|
+
backward_name = replace_last_occurrence(name, Const.FORWARD, Const.BACKWARD)
|
|
74
|
+
FreeBenchmarkCheck.grad_saver_dict[backward_name] = grad_saver
|
|
72
75
|
|
|
73
76
|
def forward(self, name, module, args, kwargs, output):
|
|
74
77
|
if not self.config.fuzz_stage == Const.FORWARD:
|
|
@@ -92,16 +95,16 @@ class FreeBenchmarkCheck(ABC):
|
|
|
92
95
|
return perturbed_output, handler.get_unequal_rows()
|
|
93
96
|
|
|
94
97
|
def backward(self, name, module, grad_output):
|
|
95
|
-
|
|
96
98
|
if not self.config.fuzz_stage == Const.BACKWARD:
|
|
97
99
|
return
|
|
98
100
|
try:
|
|
99
|
-
grad_saver =
|
|
101
|
+
grad_saver = FreeBenchmarkCheck.grad_saver_dict[name]
|
|
100
102
|
except AttributeError:
|
|
101
103
|
logger.warning_on_rank_0(
|
|
102
104
|
f"[msprobe] Free benchmark: get grad saver failed. api_name:{name}"
|
|
103
105
|
)
|
|
104
106
|
return
|
|
107
|
+
del FreeBenchmarkCheck.grad_saver_dict[name]
|
|
105
108
|
|
|
106
109
|
_new_grad_output = grad_output
|
|
107
110
|
try:
|