mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
- msprobe/README.md +25 -20
- msprobe/core/common/const.py +110 -66
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/utils.py +30 -34
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -2
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +20 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_processor/base.py +2 -2
- msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
- msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
- msprobe/core/data_dump/json_writer.py +38 -35
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +2 -1
- msprobe/docs/02.config_introduction.md +17 -15
- msprobe/docs/05.data_dump_PyTorch.md +70 -2
- msprobe/docs/06.data_dump_MindSpore.md +33 -12
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +124 -62
- msprobe/docs/21.visualization_PyTorch.md +32 -13
- msprobe/docs/22.visualization_MindSpore.md +32 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +31 -19
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +6 -4
- msprobe/mindspore/debugger/precision_debugger.py +22 -10
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +14 -9
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +354 -302
- msprobe/mindspore/monitor/utils.py +46 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +23 -17
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/common/utils.py +29 -7
- msprobe/pytorch/debugger/precision_debugger.py +10 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
- msprobe/pytorch/monitor/csv2tb.py +8 -2
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +131 -105
- msprobe/pytorch/monitor/module_metric.py +3 -0
- msprobe/pytorch/monitor/optimizer_collect.py +55 -4
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +68 -1
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +11 -7
- msprobe/pytorch/service.py +11 -8
- msprobe/visualization/builder/graph_builder.py +44 -5
- msprobe/visualization/builder/msprobe_adapter.py +0 -1
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +8 -1
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +1 -1
- msprobe/visualization/utils.py +2 -33
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/parse.py +0 -19
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -282,6 +282,8 @@ class Comparator:
|
|
|
282
282
|
result = []
|
|
283
283
|
bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all)
|
|
284
284
|
for ms_op_name, bench_op_name in self.data_mapping_dict.items():
|
|
285
|
+
check_op_str_pattern_valid(ms_op_name)
|
|
286
|
+
check_op_str_pattern_valid(bench_op_name)
|
|
285
287
|
if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all:
|
|
286
288
|
npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
|
|
287
289
|
bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
|
|
@@ -311,9 +313,9 @@ class Comparator:
|
|
|
311
313
|
]
|
|
312
314
|
|
|
313
315
|
if self.dump_mode == Const.SUMMARY:
|
|
314
|
-
result_item = base_result_item + [" "] * 8
|
|
316
|
+
result_item = base_result_item + [" "] * 8 # 8个统计量数据情况的比对指标
|
|
315
317
|
else:
|
|
316
|
-
result_item = base_result_item + [" "] *
|
|
318
|
+
result_item = base_result_item + [" "] * 6 # 6个真实数据情况的比对指标
|
|
317
319
|
|
|
318
320
|
npu_summary_data = npu_ops_all.get(ms_op_name).get("summary")
|
|
319
321
|
result_item.extend(npu_summary_data)
|
|
@@ -329,8 +331,11 @@ class Comparator:
|
|
|
329
331
|
else:
|
|
330
332
|
result_item.append(CompareConst.NONE)
|
|
331
333
|
if self.dump_mode == Const.ALL:
|
|
332
|
-
|
|
334
|
+
ms_data_name = npu_ops_all.get(ms_op_name).get("data_name", None)
|
|
335
|
+
pt_data_name = bench_ops_all.get(bench_op_name).get("data_name", None)
|
|
336
|
+
result_item.append([ms_data_name, pt_data_name])
|
|
333
337
|
result.append(result_item)
|
|
338
|
+
logger.info(f"{ms_op_name}, {bench_op_name} compared.")
|
|
334
339
|
elif ms_op_name not in npu_ops_all:
|
|
335
340
|
logger.warning(f'Can not find npu op name : `{ms_op_name}` in npu dump json file.')
|
|
336
341
|
elif bench_op_name not in npu_ops_all:
|
|
@@ -349,47 +354,48 @@ class Comparator:
|
|
|
349
354
|
result_df = self.make_result_table(result)
|
|
350
355
|
return result_df
|
|
351
356
|
|
|
352
|
-
def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param
|
|
357
|
+
def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
|
|
353
358
|
"""
|
|
354
359
|
:param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0
|
|
355
360
|
:param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0
|
|
356
361
|
:param op_name_mapping_dict: op_name和npy或pt文件的映射关系
|
|
357
362
|
:param input_param: npu_json_path/bench_json_path/stack_json_path等参数
|
|
358
|
-
:param bench_data: bench的dump数据中"data"字段
|
|
359
363
|
:return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息
|
|
360
|
-
用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt
|
|
364
|
+
用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、欧式距离
|
|
361
365
|
最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息
|
|
362
366
|
"""
|
|
363
|
-
npu_bench_name_list = op_name_mapping_dict[npu_op_name]
|
|
364
|
-
data_name = safe_get_value(npu_bench_name_list, 1, "npu_bench_name_list")
|
|
365
367
|
error_file, relative_err, error_flag = None, None, False
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
368
|
+
|
|
369
|
+
data_name_pair = op_name_mapping_dict.get(npu_op_name)
|
|
370
|
+
npu_data_name = data_name_pair[0]
|
|
371
|
+
bench_data_name = data_name_pair[1]
|
|
372
|
+
|
|
373
|
+
if str(npu_data_name) == '-1': # 没有npu真实数据
|
|
374
|
+
n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
|
|
375
|
+
elif str(bench_data_name) == '-1': # 没有bench真实数据
|
|
371
376
|
n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
|
|
372
377
|
error_file = 'no_bench_data'
|
|
373
378
|
else:
|
|
379
|
+
npu_dir = input_param.get("npu_dump_data_dir")
|
|
380
|
+
bench_dir = input_param.get("bench_dump_data_dir")
|
|
374
381
|
try:
|
|
375
|
-
read_npy_data = getattr(self, "read_npy_data")
|
|
376
382
|
frame_name = getattr(self, "frame_name")
|
|
383
|
+
read_npy_data = getattr(self, "read_npy_data")
|
|
377
384
|
if frame_name == "MSComparator":
|
|
378
|
-
n_value = read_npy_data(
|
|
385
|
+
n_value = read_npy_data(npu_dir, npu_data_name)
|
|
379
386
|
if self.cross_frame:
|
|
380
|
-
b_value = read_npy_data(
|
|
381
|
-
load_pt_file=True)
|
|
387
|
+
b_value = read_npy_data(bench_dir, bench_data_name, load_pt_file=True)
|
|
382
388
|
else:
|
|
383
|
-
b_value = read_npy_data(
|
|
389
|
+
b_value = read_npy_data(bench_dir, bench_data_name)
|
|
384
390
|
else:
|
|
385
|
-
n_value = read_npy_data(
|
|
386
|
-
b_value = read_npy_data(
|
|
391
|
+
n_value = read_npy_data(npu_dir, npu_data_name)
|
|
392
|
+
b_value = read_npy_data(bench_dir, bench_data_name)
|
|
387
393
|
except IOError as error:
|
|
388
394
|
error_file = error.filename
|
|
389
395
|
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
390
396
|
error_flag = True
|
|
391
397
|
except (FileCheckException, CompareException):
|
|
392
|
-
error_file =
|
|
398
|
+
error_file = npu_data_name
|
|
393
399
|
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
394
400
|
error_flag = True
|
|
395
401
|
|
|
@@ -427,7 +433,9 @@ class Comparator:
|
|
|
427
433
|
logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
|
|
428
434
|
file_name = add_time_with_xlsx("compare_result" + suffix)
|
|
429
435
|
file_path = os.path.join(os.path.realpath(output_path), file_name)
|
|
430
|
-
|
|
436
|
+
if os.path.exists(file_path):
|
|
437
|
+
logger.warning(f"{file_path} will be deleted.")
|
|
438
|
+
remove_path(file_path)
|
|
431
439
|
highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
|
|
432
440
|
|
|
433
441
|
npu_json = input_param.get("npu_json_path")
|
|
@@ -456,21 +464,23 @@ class Comparator:
|
|
|
456
464
|
|
|
457
465
|
def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
|
|
458
466
|
cos_result = []
|
|
467
|
+
euc_dist_result = []
|
|
459
468
|
max_err_result = []
|
|
460
469
|
max_relative_err_result = []
|
|
461
|
-
err_mess = []
|
|
462
470
|
one_thousand_err_ratio_result = []
|
|
463
471
|
five_thousand_err_ratio_result = []
|
|
472
|
+
err_mess = []
|
|
473
|
+
|
|
464
474
|
is_print_compare_log = input_param.get("is_print_compare_log")
|
|
465
|
-
|
|
475
|
+
|
|
466
476
|
for i in range(len(result_df)):
|
|
467
477
|
npu_op_name = result_df.iloc[i, 0]
|
|
468
478
|
bench_op_name = result_df.iloc[i, 1]
|
|
469
479
|
if is_print_compare_log:
|
|
470
480
|
logger.info("start compare: {}".format(npu_op_name))
|
|
471
481
|
|
|
472
|
-
cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg
|
|
473
|
-
self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param
|
|
482
|
+
cos_sim, euc_dist, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg \
|
|
483
|
+
= self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
|
|
474
484
|
|
|
475
485
|
if is_print_compare_log:
|
|
476
486
|
logger.info(
|
|
@@ -479,71 +489,30 @@ class Comparator:
|
|
|
479
489
|
"five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err,
|
|
480
490
|
err_msg, one_thousand_err_ratio, five_thousand_err_ratio))
|
|
481
491
|
cos_result.append(cos_sim)
|
|
492
|
+
euc_dist_result.append(euc_dist)
|
|
482
493
|
max_err_result.append(max_abs_err)
|
|
483
494
|
max_relative_err_result.append(max_relative_err)
|
|
484
|
-
err_mess.append(err_msg)
|
|
485
495
|
one_thousand_err_ratio_result.append(one_thousand_err_ratio)
|
|
486
496
|
five_thousand_err_ratio_result.append(five_thousand_err_ratio)
|
|
497
|
+
err_mess.append(err_msg)
|
|
487
498
|
|
|
488
499
|
cr = ComparisonResult(
|
|
489
500
|
cos_result=cos_result,
|
|
501
|
+
euc_dist_result=euc_dist_result,
|
|
490
502
|
max_err_result=max_err_result,
|
|
491
503
|
max_relative_err_result=max_relative_err_result,
|
|
492
|
-
err_msgs=err_mess,
|
|
493
504
|
one_thousand_err_ratio_result=one_thousand_err_ratio_result,
|
|
494
|
-
five_thousand_err_ratio_result=five_thousand_err_ratio_result
|
|
505
|
+
five_thousand_err_ratio_result=five_thousand_err_ratio_result,
|
|
506
|
+
err_msgs=err_mess
|
|
495
507
|
)
|
|
496
508
|
|
|
497
509
|
return _save_cmp_result(idx, cr, result_df, lock)
|
|
498
510
|
|
|
499
|
-
def do_multi_process(self,
|
|
511
|
+
def do_multi_process(self, input_param, result_df):
|
|
500
512
|
try:
|
|
501
|
-
result_df = _handle_multi_process(self.compare_ops,
|
|
513
|
+
result_df = _handle_multi_process(self.compare_ops, input_param, result_df,
|
|
502
514
|
multiprocessing.Manager().RLock())
|
|
503
515
|
return result_df
|
|
504
516
|
except ValueError as e:
|
|
505
517
|
logger.error('result dataframe is not found.')
|
|
506
518
|
raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
def get_bench_data_name(bench_op_name, bench_data):
|
|
510
|
-
bench_name_list = re.split(r'\.(input|output|kwargs|parameters|parameters_grad)\.', bench_op_name)
|
|
511
|
-
if len(bench_name_list) > 1 and bench_name_list[1] == Const.PARAMS_GRAD:
|
|
512
|
-
bench_data_bundle = bench_data.get(bench_name_list[0] + Const.SEP + bench_name_list[1], {})
|
|
513
|
-
else:
|
|
514
|
-
bench_data_bundle = bench_data.get(bench_name_list[0], {})
|
|
515
|
-
if not bench_data_bundle or len(bench_name_list) < 3:
|
|
516
|
-
return None
|
|
517
|
-
layers = bench_name_list[2].split(Const.SEP)
|
|
518
|
-
|
|
519
|
-
def _get(key, container):
|
|
520
|
-
if isinstance(container, dict):
|
|
521
|
-
return container.get(key)
|
|
522
|
-
if isinstance(container, list):
|
|
523
|
-
try:
|
|
524
|
-
return container[int(key)]
|
|
525
|
-
except (ValueError, IndexError):
|
|
526
|
-
return None
|
|
527
|
-
return None
|
|
528
|
-
|
|
529
|
-
def get_by_layer(container, params_grad=False):
|
|
530
|
-
data = container
|
|
531
|
-
# dump.json中parameters_grad的结构为key:[{}], 如果存在key,有且只有一个列表元素,而op_name中只命名到了key,因此加'0'
|
|
532
|
-
if params_grad:
|
|
533
|
-
layers.append('0')
|
|
534
|
-
for layer in layers:
|
|
535
|
-
data = _get(layer, data)
|
|
536
|
-
return _get(CompareConst.DATA_NAME.lower(), data)
|
|
537
|
-
|
|
538
|
-
if Const.INPUT == bench_name_list[1]:
|
|
539
|
-
return get_by_layer(bench_data_bundle.get(Const.INPUT, bench_data_bundle.get(Const.INPUT_ARGS)))
|
|
540
|
-
elif Const.KWARGS == bench_name_list[1]:
|
|
541
|
-
return get_by_layer(bench_data_bundle.get(Const.INPUT_KWARGS))
|
|
542
|
-
elif Const.OUTPUT == bench_name_list[1]:
|
|
543
|
-
return get_by_layer(bench_data_bundle.get(Const.OUTPUT))
|
|
544
|
-
elif Const.PARAMS == bench_name_list[1]:
|
|
545
|
-
return get_by_layer(bench_data_bundle.get(Const.PARAMS))
|
|
546
|
-
elif Const.PARAMS_GRAD == bench_name_list[1]:
|
|
547
|
-
return get_by_layer(bench_data_bundle, params_grad=True)
|
|
548
|
-
else:
|
|
549
|
-
return None
|
msprobe/core/compare/check.py
CHANGED
|
@@ -82,12 +82,8 @@ def check_type_shape_match(npu_struct, bench_struct):
|
|
|
82
82
|
f'should both be 2, please check!')
|
|
83
83
|
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
84
84
|
shape_match = npu_shape == bench_shape
|
|
85
|
-
type_match = npu_type == bench_type
|
|
86
|
-
|
|
87
|
-
if ([npu_type, bench_type] in CompareConst.MS_TYPE) or ([npu_type, bench_type] in CompareConst.TORCH_TYPE):
|
|
88
|
-
type_match = True
|
|
89
|
-
else:
|
|
90
|
-
type_match = False
|
|
85
|
+
type_match = ((npu_type == bench_type) or
|
|
86
|
+
any(npu_type in group and bench_type in group for group in CompareConst.DTYPE_MATCH_GROUPS))
|
|
91
87
|
struct_match = shape_match and type_match
|
|
92
88
|
if not struct_match:
|
|
93
89
|
return False
|
|
@@ -146,11 +146,13 @@ class HighlightRules:
|
|
|
146
146
|
}
|
|
147
147
|
|
|
148
148
|
# 用于比较输入和输出的规则
|
|
149
|
+
# 真实数据检查规则
|
|
149
150
|
compare_rules = {
|
|
150
151
|
"check_order_magnitude": CheckOrderMagnitude(),
|
|
151
152
|
"check_one_thousand_error": CheckOneThousandErrorRatio(),
|
|
152
153
|
"check_cosine_similarity": CheckCosineSimilarity()
|
|
153
154
|
}
|
|
155
|
+
# 统计量数据检查规则
|
|
154
156
|
summary_compare_rules = {
|
|
155
157
|
"check_order_magnitude": CheckOrderMagnitude(),
|
|
156
158
|
"check_max_relative_diff": CheckMaxRelativeDiff(),
|
|
@@ -23,7 +23,7 @@ from msprobe.core.common.utils import (add_time_with_yaml,
|
|
|
23
23
|
get_stack_construct_by_dump_json_path)
|
|
24
24
|
from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items
|
|
25
25
|
from msprobe.core.compare.utils import read_op, reorder_op_name_list
|
|
26
|
-
|
|
26
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class LayerTrie:
|
|
@@ -71,6 +71,7 @@ class LayerTrie:
|
|
|
71
71
|
file_path = os.path.join(os.path.realpath(output_path), file_name)
|
|
72
72
|
save_yaml(file_path, result)
|
|
73
73
|
|
|
74
|
+
@recursion_depth_decorator("LayerMapping: LayerTrie.convert_to_dict", max_depth=100)
|
|
74
75
|
def convert_to_dict(self, node):
|
|
75
76
|
result = {}
|
|
76
77
|
result["data_item"] = {st: [dt.data_name for dt in dts] for st, dts in node.data_items.items()}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -21,7 +21,8 @@ from functools import partial
|
|
|
21
21
|
import pandas as pd
|
|
22
22
|
from tqdm import tqdm
|
|
23
23
|
|
|
24
|
-
from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory
|
|
24
|
+
from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory, \
|
|
25
|
+
remove_path
|
|
25
26
|
from msprobe.core.common.const import FileCheckConst, Const, CompareConst
|
|
26
27
|
from msprobe.core.common.utils import CompareException, add_time_with_xlsx
|
|
27
28
|
from msprobe.core.compare.utils import table_value_is_valid
|
|
@@ -63,6 +64,7 @@ def get_result_path(input_dir):
|
|
|
63
64
|
for f in os.listdir(input_dir) if f.endswith(FileCheckConst.XLSX_SUFFIX)]
|
|
64
65
|
filt_compare_result_path_list = []
|
|
65
66
|
for file_path in compare_result_path_list:
|
|
67
|
+
FileChecker(file_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
|
|
66
68
|
file_name = os.path.basename(file_path)
|
|
67
69
|
if check_compare_result_name(file_name):
|
|
68
70
|
compare_result_path_checker = FileChecker(file_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE)
|
|
@@ -329,6 +331,10 @@ def generate_merge_result(all_compare_index_dict_list, all_rank_num_list, all_co
|
|
|
329
331
|
for i, df in enumerate(merge_df_list):
|
|
330
332
|
# merge_df_list中df与compare_index_list中compare_index一一对应
|
|
331
333
|
final_result_df_list.append((df, compare_index_list[i]))
|
|
334
|
+
|
|
335
|
+
if os.path.exists(output_path):
|
|
336
|
+
logger.warning(f"{output_path} will be deleted.")
|
|
337
|
+
remove_path(output_path)
|
|
332
338
|
save_excel(output_path, final_result_df_list)
|
|
333
339
|
logger.info(f"The compare results of the multi-ranks are merged and saved in: {output_path}.")
|
|
334
340
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -15,14 +15,17 @@
|
|
|
15
15
|
|
|
16
16
|
import multiprocessing
|
|
17
17
|
from dataclasses import dataclass
|
|
18
|
+
from functools import partial
|
|
19
|
+
|
|
18
20
|
import pandas as pd
|
|
19
21
|
from tqdm import tqdm
|
|
22
|
+
|
|
20
23
|
from msprobe.core.common.log import logger
|
|
21
24
|
from msprobe.core.common.utils import CompareException
|
|
22
25
|
from msprobe.core.common.const import CompareConst
|
|
23
26
|
|
|
24
27
|
|
|
25
|
-
def _handle_multi_process(func,
|
|
28
|
+
def _handle_multi_process(func, input_param, result_df, lock):
|
|
26
29
|
process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
|
|
27
30
|
op_name_mapping_dict = read_dump_data(result_df)
|
|
28
31
|
|
|
@@ -44,7 +47,7 @@ def _handle_multi_process(func, input_parma, result_df, lock):
|
|
|
44
47
|
|
|
45
48
|
progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
|
|
46
49
|
|
|
47
|
-
def update_progress(size, progress_lock):
|
|
50
|
+
def update_progress(size, progress_lock, extra_param=None):
|
|
48
51
|
with progress_lock:
|
|
49
52
|
progress_bar.update(size)
|
|
50
53
|
|
|
@@ -52,10 +55,12 @@ def _handle_multi_process(func, input_parma, result_df, lock):
|
|
|
52
55
|
idx = df_chunk_size * process_idx
|
|
53
56
|
chunk_size = len(df_chunk)
|
|
54
57
|
result = pool.apply_async(func,
|
|
55
|
-
args=(idx, op_name_mapping_dict, df_chunk, lock,
|
|
58
|
+
args=(idx, op_name_mapping_dict, df_chunk, lock, input_param),
|
|
56
59
|
error_callback=err_call,
|
|
57
|
-
callback=update_progress
|
|
60
|
+
callback=partial(update_progress, chunk_size, lock)
|
|
61
|
+
)
|
|
58
62
|
results.append(result)
|
|
63
|
+
|
|
59
64
|
final_results = [r.get() for r in results]
|
|
60
65
|
pool.close()
|
|
61
66
|
pool.join()
|
|
@@ -92,12 +97,12 @@ def _ms_graph_handle_multi_process(func, result_df, mode):
|
|
|
92
97
|
def read_dump_data(result_df):
|
|
93
98
|
try:
|
|
94
99
|
npu_dump_name_list = result_df.iloc[0:, 0].tolist()
|
|
95
|
-
|
|
100
|
+
dump_tensor_pair_list = result_df.iloc[0:, -1].tolist()
|
|
96
101
|
op_name_mapping_dict = {}
|
|
97
102
|
for index, _ in enumerate(npu_dump_name_list):
|
|
98
103
|
npu_dump_name = npu_dump_name_list[index]
|
|
99
|
-
|
|
100
|
-
op_name_mapping_dict[npu_dump_name] =
|
|
104
|
+
dump_tensor_pair = dump_tensor_pair_list[index]
|
|
105
|
+
op_name_mapping_dict[npu_dump_name] = dump_tensor_pair
|
|
101
106
|
return op_name_mapping_dict
|
|
102
107
|
except ValueError as e:
|
|
103
108
|
logger.error('result dataframe is not found.')
|
|
@@ -110,11 +115,12 @@ def read_dump_data(result_df):
|
|
|
110
115
|
@dataclass
|
|
111
116
|
class ComparisonResult:
|
|
112
117
|
cos_result: list
|
|
118
|
+
euc_dist_result: list
|
|
113
119
|
max_err_result: list
|
|
114
120
|
max_relative_err_result: list
|
|
115
|
-
err_msgs: list
|
|
116
121
|
one_thousand_err_ratio_result: list
|
|
117
122
|
five_thousand_err_ratio_result: list
|
|
123
|
+
err_msgs: list
|
|
118
124
|
|
|
119
125
|
|
|
120
126
|
def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
|
|
@@ -135,15 +141,16 @@ def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
|
|
|
135
141
|
for i, _ in enumerate(result.cos_result):
|
|
136
142
|
process_index = i + offset
|
|
137
143
|
result_df.loc[process_index, CompareConst.COSINE] = result.cos_result[i]
|
|
144
|
+
result_df.loc[process_index, CompareConst.EUC_DIST] = result.euc_dist_result[i]
|
|
138
145
|
result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
|
|
139
146
|
result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
|
|
140
|
-
result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
|
|
141
|
-
result_df.loc[process_index, CompareConst.ACCURACY] = (
|
|
142
|
-
check_accuracy(result.cos_result[i], result.max_err_result[i]))
|
|
143
147
|
result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = (
|
|
144
148
|
result.one_thousand_err_ratio_result)[i]
|
|
145
149
|
result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = (
|
|
146
150
|
result.five_thousand_err_ratio_result)[i]
|
|
151
|
+
result_df.loc[process_index, CompareConst.ACCURACY] = (
|
|
152
|
+
check_accuracy(result.cos_result[i], result.max_err_result[i]))
|
|
153
|
+
result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
|
|
147
154
|
return result_df
|
|
148
155
|
except ValueError as e:
|
|
149
156
|
logger.error('result dataframe is not found.')
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -70,7 +70,7 @@ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None):
|
|
|
70
70
|
error_flag = True
|
|
71
71
|
return CompareConst.NONE, CompareConst.NONE, error_flag, err_msg
|
|
72
72
|
if not n_value.shape: # 判断数据是否为0维张量
|
|
73
|
-
err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', "
|
|
73
|
+
err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', '{CompareConst.EUC_DIST}', "
|
|
74
74
|
f"'{CompareConst.ONE_THOUSANDTH_ERR_RATIO}' and '{CompareConst.FIVE_THOUSANDTHS_ERR_RATIO}'. ")
|
|
75
75
|
error_flag = False # 0-d tensor 最大绝对误差、最大相对误差仍然支持计算,因此error_flag设置为False,不做统一处理
|
|
76
76
|
return n_value, b_value, error_flag, err_msg
|
|
@@ -168,8 +168,9 @@ def statistics_data_check(result_dict):
|
|
|
168
168
|
|
|
169
169
|
class TensorComparisonBasic(abc.ABC):
|
|
170
170
|
"""NPU和bench中npy数据的比较模板"""
|
|
171
|
+
|
|
171
172
|
@abc.abstractmethod
|
|
172
|
-
def apply(self, n_value, b_value, relative_err):
|
|
173
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
173
174
|
raise NotImplementedError
|
|
174
175
|
|
|
175
176
|
|
|
@@ -190,6 +191,7 @@ def get_relative_err(n_value, b_value):
|
|
|
190
191
|
|
|
191
192
|
class GetCosineSimilarity(TensorComparisonBasic):
|
|
192
193
|
"""计算cosine相似度"""
|
|
194
|
+
|
|
193
195
|
@staticmethod
|
|
194
196
|
def correct_data(result):
|
|
195
197
|
if result == CompareConst.NAN:
|
|
@@ -198,9 +200,9 @@ class GetCosineSimilarity(TensorComparisonBasic):
|
|
|
198
200
|
return round(float(result), 6)
|
|
199
201
|
return result
|
|
200
202
|
|
|
201
|
-
def apply(self, n_value, b_value, relative_err):
|
|
202
|
-
if
|
|
203
|
-
return CompareConst.UNSUPPORTED,
|
|
203
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
204
|
+
if "This is type of 0-d tensor" in err_msg:
|
|
205
|
+
return CompareConst.UNSUPPORTED, err_msg
|
|
204
206
|
|
|
205
207
|
with np.errstate(divide="ignore", invalid="ignore"):
|
|
206
208
|
if len(n_value) == 1:
|
|
@@ -224,9 +226,22 @@ class GetCosineSimilarity(TensorComparisonBasic):
|
|
|
224
226
|
return result, ""
|
|
225
227
|
|
|
226
228
|
|
|
229
|
+
class GetEuclideanDistance(TensorComparisonBasic):
|
|
230
|
+
"""计算欧式距离"""
|
|
231
|
+
|
|
232
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
233
|
+
if "This is type of 0-d tensor" in err_msg:
|
|
234
|
+
return CompareConst.UNSUPPORTED, err_msg
|
|
235
|
+
|
|
236
|
+
distance = np.linalg.norm(n_value - b_value, ord=2)
|
|
237
|
+
|
|
238
|
+
return distance, ""
|
|
239
|
+
|
|
240
|
+
|
|
227
241
|
class GetMaxAbsErr(TensorComparisonBasic):
|
|
228
242
|
"""计算最大绝对误差"""
|
|
229
|
-
|
|
243
|
+
|
|
244
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
230
245
|
temp_res = n_value - b_value
|
|
231
246
|
max_value = np.max(np.abs(temp_res))
|
|
232
247
|
if np.isnan(max_value):
|
|
@@ -237,7 +252,8 @@ class GetMaxAbsErr(TensorComparisonBasic):
|
|
|
237
252
|
|
|
238
253
|
class GetMaxRelativeErr(TensorComparisonBasic):
|
|
239
254
|
"""计算最大相对误差"""
|
|
240
|
-
|
|
255
|
+
|
|
256
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
241
257
|
max_relative_err = np.max(np.abs(relative_err))
|
|
242
258
|
if np.isnan(max_relative_err):
|
|
243
259
|
msg = "Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data."
|
|
@@ -247,12 +263,13 @@ class GetMaxRelativeErr(TensorComparisonBasic):
|
|
|
247
263
|
|
|
248
264
|
class GetErrRatio(TensorComparisonBasic):
|
|
249
265
|
"""计算相对误差小于指定阈值(千分之一、千分之五)的比例"""
|
|
266
|
+
|
|
250
267
|
def __init__(self, threshold):
|
|
251
268
|
self.threshold = threshold
|
|
252
269
|
|
|
253
|
-
def apply(self, n_value, b_value, relative_err):
|
|
254
|
-
if
|
|
255
|
-
return CompareConst.UNSUPPORTED,
|
|
270
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
271
|
+
if "This is type of 0-d tensor" in err_msg:
|
|
272
|
+
return CompareConst.UNSUPPORTED, err_msg
|
|
256
273
|
|
|
257
274
|
if not np.size(relative_err):
|
|
258
275
|
return CompareConst.NAN, ""
|
|
@@ -264,6 +281,7 @@ class GetErrRatio(TensorComparisonBasic):
|
|
|
264
281
|
class CompareOps:
|
|
265
282
|
compare_ops = {
|
|
266
283
|
"cosine_similarity": GetCosineSimilarity(),
|
|
284
|
+
"euclidean_distance": GetEuclideanDistance(),
|
|
267
285
|
"max_abs_error": GetMaxAbsErr(),
|
|
268
286
|
"max_relative_error": GetMaxRelativeErr(),
|
|
269
287
|
"one_thousand_err_ratio": GetErrRatio(CompareConst.THOUSAND_RATIO_THRESHOLD),
|
|
@@ -295,7 +313,7 @@ def compare_ops_apply(n_value, b_value, error_flag, err_msg):
|
|
|
295
313
|
n_value, b_value = reshape_value(n_value, b_value)
|
|
296
314
|
|
|
297
315
|
for op in CompareOps.compare_ops.values():
|
|
298
|
-
result, msg = op.apply(n_value, b_value, relative_err)
|
|
316
|
+
result, msg = op.apply(n_value, b_value, relative_err, err_msg)
|
|
299
317
|
result_list.append(result)
|
|
300
318
|
err_msg += msg
|
|
301
319
|
return result_list, err_msg
|
msprobe/core/compare/utils.py
CHANGED
|
@@ -285,9 +285,9 @@ def result_item_init(n_info, b_info, dump_mode):
|
|
|
285
285
|
md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF
|
|
286
286
|
result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result])
|
|
287
287
|
elif dump_mode == Const.SUMMARY:
|
|
288
|
-
result_item.extend([" "] * 8)
|
|
288
|
+
result_item.extend([" "] * 8) # 8个统计量数据情况的比对指标
|
|
289
289
|
else:
|
|
290
|
-
result_item.extend([" "] *
|
|
290
|
+
result_item.extend([" "] * 6) # 6个真实数据情况的比对指标
|
|
291
291
|
else:
|
|
292
292
|
err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \
|
|
293
293
|
f"npu_info_struct is {n_info.struct}\n" \
|
|
@@ -321,8 +321,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
|
321
321
|
has_stack = npu_stack_info and bench_stack_info
|
|
322
322
|
|
|
323
323
|
if dump_mode == Const.ALL:
|
|
324
|
-
|
|
325
|
-
|
|
324
|
+
npu_data_name_list = n_dict.get("data_name", None)
|
|
325
|
+
bench_data_name_list = b_dict.get("data_name", None)
|
|
326
326
|
|
|
327
327
|
for index in range(min_len):
|
|
328
328
|
n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name")
|
|
@@ -353,7 +353,9 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
|
353
353
|
result_item.append(err_msg)
|
|
354
354
|
result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
|
|
355
355
|
if dump_mode == Const.ALL:
|
|
356
|
-
|
|
356
|
+
npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list")
|
|
357
|
+
bench_data_name = safe_get_value(bench_data_name_list, b_start + index, "bench_data_name_list")
|
|
358
|
+
result_item.append([npu_data_name, bench_data_name])
|
|
357
359
|
|
|
358
360
|
result.append(result_item)
|
|
359
361
|
|
|
@@ -371,7 +373,7 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
|
371
373
|
continue
|
|
372
374
|
result_item = [
|
|
373
375
|
n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
|
|
374
|
-
" ", " ", " ", " ", " "
|
|
376
|
+
" ", " ", " ", " ", " ", " "
|
|
375
377
|
]
|
|
376
378
|
summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
|
|
377
379
|
result_item.extend(summary_data)
|
|
@@ -388,7 +390,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
|
388
390
|
result_item.append(err_msg)
|
|
389
391
|
result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
|
|
390
392
|
if dump_mode == Const.ALL:
|
|
391
|
-
|
|
393
|
+
npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list")
|
|
394
|
+
result_item.append([npu_data_name, "-1"])
|
|
392
395
|
|
|
393
396
|
result.append(result_item)
|
|
394
397
|
|
|
@@ -453,9 +456,9 @@ def get_un_match_accuracy(result, n_dict, dump_mode):
|
|
|
453
456
|
result.append(result_item)
|
|
454
457
|
continue
|
|
455
458
|
if dump_mode == Const.SUMMARY:
|
|
456
|
-
result_item.extend([CompareConst.N_A] * 8)
|
|
459
|
+
result_item.extend([CompareConst.N_A] * 8) # 8个统计量数据情况的比对指标
|
|
457
460
|
if dump_mode == Const.ALL:
|
|
458
|
-
result_item.extend([CompareConst.N_A] *
|
|
461
|
+
result_item.extend([CompareConst.N_A] * 6) # 6个真实数据情况的比对指标
|
|
459
462
|
|
|
460
463
|
npu_summary_data = safe_get_value(summary_reorder, index, "summary_reorder")
|
|
461
464
|
bench_summary_data = [CompareConst.N_A] * 4
|
|
@@ -467,7 +470,7 @@ def get_un_match_accuracy(result, n_dict, dump_mode):
|
|
|
467
470
|
result_item.append(err_msg)
|
|
468
471
|
append_stack_info(result_item, npu_stack_info, index)
|
|
469
472
|
if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
|
|
470
|
-
result_item.extend(["-1"])
|
|
473
|
+
result_item.extend([["-1", "-1"]])
|
|
471
474
|
result.append(result_item)
|
|
472
475
|
|
|
473
476
|
|
|
@@ -542,10 +545,17 @@ def get_name_and_state(name):
|
|
|
542
545
|
|
|
543
546
|
state type: input, output, kwargs, parameters, parameters_grad
|
|
544
547
|
"""
|
|
548
|
+
if not isinstance(name, str):
|
|
549
|
+
logger.error(f'Invalid name: {name}, type should be string, please check.')
|
|
550
|
+
raise CompareException(CompareException.INVALID_API_NAME_ERROR)
|
|
551
|
+
|
|
545
552
|
if Const.PARAMS_GRAD in name.split(Const.SEP):
|
|
546
553
|
return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD
|
|
547
554
|
|
|
548
555
|
split = re.split(Const.REGEX_FORWARD_BACKWARD, name)
|
|
556
|
+
if len(split) < 3:
|
|
557
|
+
logger.error(f'Invalid name string: {name}, can not be split by forward/backward, please check.')
|
|
558
|
+
raise CompareException(CompareException.INVALID_API_NAME_ERROR)
|
|
549
559
|
api = f'{split[0]}.{split[1]}.'
|
|
550
560
|
state_str = split[2]
|
|
551
561
|
match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str)
|