mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
- msprobe/README.md +27 -22
- msprobe/core/common/const.py +129 -60
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +43 -33
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +16 -9
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +30 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +94 -10
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
- msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
- msprobe/core/data_dump/json_writer.py +61 -40
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +27 -1
- msprobe/docs/02.config_introduction.md +27 -23
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +103 -16
- msprobe/docs/06.data_dump_MindSpore.md +76 -32
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +332 -273
- msprobe/docs/21.visualization_PyTorch.md +42 -13
- msprobe/docs/22.visualization_MindSpore.md +43 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +301 -27
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
- msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +48 -18
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +31 -6
- msprobe/mindspore/debugger/precision_debugger.py +45 -14
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +21 -15
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +873 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +309 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +114 -34
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +97 -4
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +24 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +8 -2
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +18 -14
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +238 -193
- msprobe/pytorch/monitor/module_metric.py +9 -6
- msprobe/pytorch/monitor/optimizer_collect.py +100 -67
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +76 -44
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +30 -29
- msprobe/pytorch/service.py +114 -32
- msprobe/visualization/builder/graph_builder.py +75 -10
- msprobe/visualization/builder/msprobe_adapter.py +7 -6
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +11 -3
- msprobe/visualization/graph/distributed_analyzer.py +71 -3
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +4 -3
- msprobe/visualization/graph_service.py +4 -5
- msprobe/visualization/utils.py +12 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -282,6 +282,8 @@ class Comparator:
|
|
|
282
282
|
result = []
|
|
283
283
|
bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all)
|
|
284
284
|
for ms_op_name, bench_op_name in self.data_mapping_dict.items():
|
|
285
|
+
check_op_str_pattern_valid(ms_op_name)
|
|
286
|
+
check_op_str_pattern_valid(bench_op_name)
|
|
285
287
|
if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all:
|
|
286
288
|
npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
|
|
287
289
|
bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
|
|
@@ -311,9 +313,9 @@ class Comparator:
|
|
|
311
313
|
]
|
|
312
314
|
|
|
313
315
|
if self.dump_mode == Const.SUMMARY:
|
|
314
|
-
result_item = base_result_item + [" "] * 8
|
|
316
|
+
result_item = base_result_item + [" "] * 8 # 8个统计量数据情况的比对指标
|
|
315
317
|
else:
|
|
316
|
-
result_item = base_result_item + [" "] *
|
|
318
|
+
result_item = base_result_item + [" "] * 6 # 6个真实数据情况的比对指标
|
|
317
319
|
|
|
318
320
|
npu_summary_data = npu_ops_all.get(ms_op_name).get("summary")
|
|
319
321
|
result_item.extend(npu_summary_data)
|
|
@@ -329,8 +331,11 @@ class Comparator:
|
|
|
329
331
|
else:
|
|
330
332
|
result_item.append(CompareConst.NONE)
|
|
331
333
|
if self.dump_mode == Const.ALL:
|
|
332
|
-
|
|
334
|
+
ms_data_name = npu_ops_all.get(ms_op_name).get("data_name", None)
|
|
335
|
+
pt_data_name = bench_ops_all.get(bench_op_name).get("data_name", None)
|
|
336
|
+
result_item.append([ms_data_name, pt_data_name])
|
|
333
337
|
result.append(result_item)
|
|
338
|
+
logger.info(f"{ms_op_name}, {bench_op_name} compared.")
|
|
334
339
|
elif ms_op_name not in npu_ops_all:
|
|
335
340
|
logger.warning(f'Can not find npu op name : `{ms_op_name}` in npu dump json file.')
|
|
336
341
|
elif bench_op_name not in npu_ops_all:
|
|
@@ -349,47 +354,48 @@ class Comparator:
|
|
|
349
354
|
result_df = self.make_result_table(result)
|
|
350
355
|
return result_df
|
|
351
356
|
|
|
352
|
-
def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param
|
|
357
|
+
def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
|
|
353
358
|
"""
|
|
354
359
|
:param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0
|
|
355
360
|
:param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0
|
|
356
361
|
:param op_name_mapping_dict: op_name和npy或pt文件的映射关系
|
|
357
362
|
:param input_param: npu_json_path/bench_json_path/stack_json_path等参数
|
|
358
|
-
:param bench_data: bench的dump数据中"data"字段
|
|
359
363
|
:return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息
|
|
360
|
-
用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt
|
|
364
|
+
用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、欧式距离
|
|
361
365
|
最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息
|
|
362
366
|
"""
|
|
363
|
-
npu_bench_name_list = op_name_mapping_dict[npu_op_name]
|
|
364
|
-
data_name = safe_get_value(npu_bench_name_list, 1, "npu_bench_name_list")
|
|
365
367
|
error_file, relative_err, error_flag = None, None, False
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
368
|
+
|
|
369
|
+
data_name_pair = op_name_mapping_dict.get(npu_op_name)
|
|
370
|
+
npu_data_name = data_name_pair[0]
|
|
371
|
+
bench_data_name = data_name_pair[1]
|
|
372
|
+
|
|
373
|
+
if str(npu_data_name) == '-1': # 没有npu真实数据
|
|
374
|
+
n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
|
|
375
|
+
elif str(bench_data_name) == '-1': # 没有bench真实数据
|
|
371
376
|
n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
|
|
372
377
|
error_file = 'no_bench_data'
|
|
373
378
|
else:
|
|
379
|
+
npu_dir = input_param.get("npu_dump_data_dir")
|
|
380
|
+
bench_dir = input_param.get("bench_dump_data_dir")
|
|
374
381
|
try:
|
|
375
|
-
read_npy_data = getattr(self, "read_npy_data")
|
|
376
382
|
frame_name = getattr(self, "frame_name")
|
|
383
|
+
read_npy_data = getattr(self, "read_npy_data")
|
|
377
384
|
if frame_name == "MSComparator":
|
|
378
|
-
n_value = read_npy_data(
|
|
385
|
+
n_value = read_npy_data(npu_dir, npu_data_name)
|
|
379
386
|
if self.cross_frame:
|
|
380
|
-
b_value = read_npy_data(
|
|
381
|
-
load_pt_file=True)
|
|
387
|
+
b_value = read_npy_data(bench_dir, bench_data_name, load_pt_file=True)
|
|
382
388
|
else:
|
|
383
|
-
b_value = read_npy_data(
|
|
389
|
+
b_value = read_npy_data(bench_dir, bench_data_name)
|
|
384
390
|
else:
|
|
385
|
-
n_value = read_npy_data(
|
|
386
|
-
b_value = read_npy_data(
|
|
391
|
+
n_value = read_npy_data(npu_dir, npu_data_name)
|
|
392
|
+
b_value = read_npy_data(bench_dir, bench_data_name)
|
|
387
393
|
except IOError as error:
|
|
388
394
|
error_file = error.filename
|
|
389
395
|
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
390
396
|
error_flag = True
|
|
391
397
|
except (FileCheckException, CompareException):
|
|
392
|
-
error_file =
|
|
398
|
+
error_file = npu_data_name
|
|
393
399
|
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
394
400
|
error_flag = True
|
|
395
401
|
|
|
@@ -427,7 +433,9 @@ class Comparator:
|
|
|
427
433
|
logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
|
|
428
434
|
file_name = add_time_with_xlsx("compare_result" + suffix)
|
|
429
435
|
file_path = os.path.join(os.path.realpath(output_path), file_name)
|
|
430
|
-
|
|
436
|
+
if os.path.exists(file_path):
|
|
437
|
+
logger.warning(f"{file_path} will be deleted.")
|
|
438
|
+
remove_path(file_path)
|
|
431
439
|
highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
|
|
432
440
|
|
|
433
441
|
npu_json = input_param.get("npu_json_path")
|
|
@@ -456,21 +464,23 @@ class Comparator:
|
|
|
456
464
|
|
|
457
465
|
def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
|
|
458
466
|
cos_result = []
|
|
467
|
+
euc_dist_result = []
|
|
459
468
|
max_err_result = []
|
|
460
469
|
max_relative_err_result = []
|
|
461
|
-
err_mess = []
|
|
462
470
|
one_thousand_err_ratio_result = []
|
|
463
471
|
five_thousand_err_ratio_result = []
|
|
472
|
+
err_mess = []
|
|
473
|
+
|
|
464
474
|
is_print_compare_log = input_param.get("is_print_compare_log")
|
|
465
|
-
|
|
475
|
+
|
|
466
476
|
for i in range(len(result_df)):
|
|
467
477
|
npu_op_name = result_df.iloc[i, 0]
|
|
468
478
|
bench_op_name = result_df.iloc[i, 1]
|
|
469
479
|
if is_print_compare_log:
|
|
470
480
|
logger.info("start compare: {}".format(npu_op_name))
|
|
471
481
|
|
|
472
|
-
cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg
|
|
473
|
-
self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param
|
|
482
|
+
cos_sim, euc_dist, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg \
|
|
483
|
+
= self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
|
|
474
484
|
|
|
475
485
|
if is_print_compare_log:
|
|
476
486
|
logger.info(
|
|
@@ -479,71 +489,30 @@ class Comparator:
|
|
|
479
489
|
"five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err,
|
|
480
490
|
err_msg, one_thousand_err_ratio, five_thousand_err_ratio))
|
|
481
491
|
cos_result.append(cos_sim)
|
|
492
|
+
euc_dist_result.append(euc_dist)
|
|
482
493
|
max_err_result.append(max_abs_err)
|
|
483
494
|
max_relative_err_result.append(max_relative_err)
|
|
484
|
-
err_mess.append(err_msg)
|
|
485
495
|
one_thousand_err_ratio_result.append(one_thousand_err_ratio)
|
|
486
496
|
five_thousand_err_ratio_result.append(five_thousand_err_ratio)
|
|
497
|
+
err_mess.append(err_msg)
|
|
487
498
|
|
|
488
499
|
cr = ComparisonResult(
|
|
489
500
|
cos_result=cos_result,
|
|
501
|
+
euc_dist_result=euc_dist_result,
|
|
490
502
|
max_err_result=max_err_result,
|
|
491
503
|
max_relative_err_result=max_relative_err_result,
|
|
492
|
-
err_msgs=err_mess,
|
|
493
504
|
one_thousand_err_ratio_result=one_thousand_err_ratio_result,
|
|
494
|
-
five_thousand_err_ratio_result=five_thousand_err_ratio_result
|
|
505
|
+
five_thousand_err_ratio_result=five_thousand_err_ratio_result,
|
|
506
|
+
err_msgs=err_mess
|
|
495
507
|
)
|
|
496
508
|
|
|
497
509
|
return _save_cmp_result(idx, cr, result_df, lock)
|
|
498
510
|
|
|
499
|
-
def do_multi_process(self,
|
|
511
|
+
def do_multi_process(self, input_param, result_df):
|
|
500
512
|
try:
|
|
501
|
-
result_df = _handle_multi_process(self.compare_ops,
|
|
513
|
+
result_df = _handle_multi_process(self.compare_ops, input_param, result_df,
|
|
502
514
|
multiprocessing.Manager().RLock())
|
|
503
515
|
return result_df
|
|
504
516
|
except ValueError as e:
|
|
505
517
|
logger.error('result dataframe is not found.')
|
|
506
518
|
raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
def get_bench_data_name(bench_op_name, bench_data):
|
|
510
|
-
bench_name_list = re.split(r'\.(input|output|kwargs|parameters|parameters_grad)\.', bench_op_name)
|
|
511
|
-
if len(bench_name_list) > 1 and bench_name_list[1] == Const.PARAMS_GRAD:
|
|
512
|
-
bench_data_bundle = bench_data.get(bench_name_list[0] + Const.SEP + bench_name_list[1], {})
|
|
513
|
-
else:
|
|
514
|
-
bench_data_bundle = bench_data.get(bench_name_list[0], {})
|
|
515
|
-
if not bench_data_bundle or len(bench_name_list) < 3:
|
|
516
|
-
return None
|
|
517
|
-
layers = bench_name_list[2].split(Const.SEP)
|
|
518
|
-
|
|
519
|
-
def _get(key, container):
|
|
520
|
-
if isinstance(container, dict):
|
|
521
|
-
return container.get(key)
|
|
522
|
-
if isinstance(container, list):
|
|
523
|
-
try:
|
|
524
|
-
return container[int(key)]
|
|
525
|
-
except (ValueError, IndexError):
|
|
526
|
-
return None
|
|
527
|
-
return None
|
|
528
|
-
|
|
529
|
-
def get_by_layer(container, params_grad=False):
|
|
530
|
-
data = container
|
|
531
|
-
# dump.json中parameters_grad的结构为key:[{}], 如果存在key,有且只有一个列表元素,而op_name中只命名到了key,因此加'0'
|
|
532
|
-
if params_grad:
|
|
533
|
-
layers.append('0')
|
|
534
|
-
for layer in layers:
|
|
535
|
-
data = _get(layer, data)
|
|
536
|
-
return _get(CompareConst.DATA_NAME.lower(), data)
|
|
537
|
-
|
|
538
|
-
if Const.INPUT == bench_name_list[1]:
|
|
539
|
-
return get_by_layer(bench_data_bundle.get(Const.INPUT, bench_data_bundle.get(Const.INPUT_ARGS)))
|
|
540
|
-
elif Const.KWARGS == bench_name_list[1]:
|
|
541
|
-
return get_by_layer(bench_data_bundle.get(Const.INPUT_KWARGS))
|
|
542
|
-
elif Const.OUTPUT == bench_name_list[1]:
|
|
543
|
-
return get_by_layer(bench_data_bundle.get(Const.OUTPUT))
|
|
544
|
-
elif Const.PARAMS == bench_name_list[1]:
|
|
545
|
-
return get_by_layer(bench_data_bundle.get(Const.PARAMS))
|
|
546
|
-
elif Const.PARAMS_GRAD == bench_name_list[1]:
|
|
547
|
-
return get_by_layer(bench_data_bundle, params_grad=True)
|
|
548
|
-
else:
|
|
549
|
-
return None
|
msprobe/core/compare/check.py
CHANGED
|
@@ -82,12 +82,8 @@ def check_type_shape_match(npu_struct, bench_struct):
|
|
|
82
82
|
f'should both be 2, please check!')
|
|
83
83
|
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
84
84
|
shape_match = npu_shape == bench_shape
|
|
85
|
-
type_match = npu_type == bench_type
|
|
86
|
-
|
|
87
|
-
if ([npu_type, bench_type] in CompareConst.MS_TYPE) or ([npu_type, bench_type] in CompareConst.TORCH_TYPE):
|
|
88
|
-
type_match = True
|
|
89
|
-
else:
|
|
90
|
-
type_match = False
|
|
85
|
+
type_match = ((npu_type == bench_type) or
|
|
86
|
+
any(npu_type in group and bench_type in group for group in CompareConst.DTYPE_MATCH_GROUPS))
|
|
91
87
|
struct_match = shape_match and type_match
|
|
92
88
|
if not struct_match:
|
|
93
89
|
return False
|
|
@@ -146,11 +146,13 @@ class HighlightRules:
|
|
|
146
146
|
}
|
|
147
147
|
|
|
148
148
|
# 用于比较输入和输出的规则
|
|
149
|
+
# 真实数据检查规则
|
|
149
150
|
compare_rules = {
|
|
150
151
|
"check_order_magnitude": CheckOrderMagnitude(),
|
|
151
152
|
"check_one_thousand_error": CheckOneThousandErrorRatio(),
|
|
152
153
|
"check_cosine_similarity": CheckCosineSimilarity()
|
|
153
154
|
}
|
|
155
|
+
# 统计量数据检查规则
|
|
154
156
|
summary_compare_rules = {
|
|
155
157
|
"check_order_magnitude": CheckOrderMagnitude(),
|
|
156
158
|
"check_max_relative_diff": CheckMaxRelativeDiff(),
|
|
@@ -112,7 +112,7 @@ class DumpDataItem:
|
|
|
112
112
|
self.layer_scope = Const.SEP.join(data_list[:Const.TYPE_NAME_INDEX])
|
|
113
113
|
else:
|
|
114
114
|
self.layer_scope = Const.TOP_LAYER
|
|
115
|
-
if construct_info:
|
|
115
|
+
if construct_info and Const.SEP in construct_info:
|
|
116
116
|
construct_list = construct_info.split(Const.SEP)
|
|
117
117
|
if len(construct_list) < abs(Const.LAYER_NAME_INDEX):
|
|
118
118
|
logger.error(
|
|
@@ -23,7 +23,7 @@ from msprobe.core.common.utils import (add_time_with_yaml,
|
|
|
23
23
|
get_stack_construct_by_dump_json_path)
|
|
24
24
|
from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items
|
|
25
25
|
from msprobe.core.compare.utils import read_op, reorder_op_name_list
|
|
26
|
-
|
|
26
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class LayerTrie:
|
|
@@ -71,6 +71,7 @@ class LayerTrie:
|
|
|
71
71
|
file_path = os.path.join(os.path.realpath(output_path), file_name)
|
|
72
72
|
save_yaml(file_path, result)
|
|
73
73
|
|
|
74
|
+
@recursion_depth_decorator("LayerMapping: LayerTrie.convert_to_dict", max_depth=100)
|
|
74
75
|
def convert_to_dict(self, node):
|
|
75
76
|
result = {}
|
|
76
77
|
result["data_item"] = {st: [dt.data_name for dt in dts] for st, dts in node.data_items.items()}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -21,10 +21,12 @@ from functools import partial
|
|
|
21
21
|
import pandas as pd
|
|
22
22
|
from tqdm import tqdm
|
|
23
23
|
|
|
24
|
-
from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory
|
|
24
|
+
from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory, \
|
|
25
|
+
remove_path
|
|
25
26
|
from msprobe.core.common.const import FileCheckConst, Const, CompareConst
|
|
26
27
|
from msprobe.core.common.utils import CompareException, add_time_with_xlsx
|
|
27
28
|
from msprobe.core.compare.utils import table_value_is_valid
|
|
29
|
+
from msprobe.core.compare.merge_result.utils import replace_compare_index_dict, check_config
|
|
28
30
|
|
|
29
31
|
|
|
30
32
|
def check_compare_result_name(file_name):
|
|
@@ -62,6 +64,7 @@ def get_result_path(input_dir):
|
|
|
62
64
|
for f in os.listdir(input_dir) if f.endswith(FileCheckConst.XLSX_SUFFIX)]
|
|
63
65
|
filt_compare_result_path_list = []
|
|
64
66
|
for file_path in compare_result_path_list:
|
|
67
|
+
FileChecker(file_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
|
|
65
68
|
file_name = os.path.basename(file_path)
|
|
66
69
|
if check_compare_result_name(file_name):
|
|
67
70
|
compare_result_path_checker = FileChecker(file_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE)
|
|
@@ -170,6 +173,8 @@ def search_api_index_result(api_list, compare_index_list, result_df, rank_num, c
|
|
|
170
173
|
table_value_check(index_value)
|
|
171
174
|
api_index_dict.setdefault(api_full_name, {})[rank_num] = index_value # update api_index_dict
|
|
172
175
|
compare_index_dict[compare_index] = api_index_dict
|
|
176
|
+
|
|
177
|
+
compare_index_dict = replace_compare_index_dict(compare_index_dict, compare_index_list, rank_num)
|
|
173
178
|
return compare_index_dict
|
|
174
179
|
|
|
175
180
|
|
|
@@ -203,10 +208,13 @@ def result_process(compare_result_path_list, api_list):
|
|
|
203
208
|
compare_index_list = check_index_dump_mode_consistent(dump_mode, rank_num)
|
|
204
209
|
if len(compare_index_list) == 0:
|
|
205
210
|
return [], [], []
|
|
206
|
-
|
|
211
|
+
compare_index_list.extend([CompareConst.NPU_MAX, CompareConst.BENCH_MAX])
|
|
212
|
+
compare_index_dict = search_api_index_result(api_list, compare_index_list,
|
|
207
213
|
result_df, rank_num, compare_index_dict)
|
|
208
214
|
compare_index_dict_list.append(compare_index_dict)
|
|
209
215
|
rank_num_list.append(rank_num)
|
|
216
|
+
compare_index_list.pop()
|
|
217
|
+
compare_index_list.pop()
|
|
210
218
|
else:
|
|
211
219
|
logger.warning(f"Rank{rank_num} compare result is empty and will not shown in merged result.")
|
|
212
220
|
|
|
@@ -323,6 +331,10 @@ def generate_merge_result(all_compare_index_dict_list, all_rank_num_list, all_co
|
|
|
323
331
|
for i, df in enumerate(merge_df_list):
|
|
324
332
|
# merge_df_list中df与compare_index_list中compare_index一一对应
|
|
325
333
|
final_result_df_list.append((df, compare_index_list[i]))
|
|
334
|
+
|
|
335
|
+
if os.path.exists(output_path):
|
|
336
|
+
logger.warning(f"{output_path} will be deleted.")
|
|
337
|
+
remove_path(output_path)
|
|
326
338
|
save_excel(output_path, final_result_df_list)
|
|
327
339
|
logger.info(f"The compare results of the multi-ranks are merged and saved in: {output_path}.")
|
|
328
340
|
|
|
@@ -362,13 +374,8 @@ def merge_result(input_dir, output_dir, config_path):
|
|
|
362
374
|
compare_result_path_list = get_result_path(input_dir) # 获得的input_dir中所有比对结果件的全路径,数量少于2,便提示退出
|
|
363
375
|
|
|
364
376
|
config = load_yaml(config_path)
|
|
365
|
-
|
|
366
|
-
logger.error('config.yaml is empty, please check.')
|
|
367
|
-
raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
|
|
377
|
+
config = check_config(config)
|
|
368
378
|
api_list = config.get('api')
|
|
369
|
-
if not api_list:
|
|
370
|
-
logger.error('The APIs required to merge data were not found')
|
|
371
|
-
raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
|
|
372
379
|
|
|
373
380
|
# 初始化共享全局变量share_compare_index_list
|
|
374
381
|
initialize_compare_index(config)
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.core.common.const import CompareConst
|
|
17
|
+
from msprobe.core.common.file_utils import logger
|
|
18
|
+
from msprobe.core.common.utils import CompareException
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def replace_compare_index_dict(compare_index_dict, compare_index_list, rank_num):
    """
    Replace placeholder compare-index values with the item's NPU max and Bench max.

    For every op item on ``rank_num`` whose compare-index value is N/A, unsupported
    or Nan, the value is replaced by a string built from that item's NPU max and
    Bench max statistics. This targets items whose statistics are all the same
    value (e.g. a communication group), where the max value is more informative
    than a placeholder.

    Example:
        The compare-index value of Distributed.all_reduce.0.forward.output.group is N/A.
        After replacement it becomes:
            NPU:tp-0-1-2-3 Bench:tp-0-1-2-3

    Args:
        compare_index_dict: mapping of compare_index -> {op_name -> {rank_num -> value}}.
            Replacement only happens when it also holds the CompareConst.NPU_MAX and
            CompareConst.BENCH_MAX helper entries. Modified in place.
        compare_index_list: compare-index names to process; the NPU_MAX / BENCH_MAX
            helper names are skipped wherever they appear in the list.
        rank_num: the rank whose values are inspected and replaced.

    Returns:
        The same compare_index_dict, with the NPU_MAX and BENCH_MAX helper entries removed.
    """
    # Detach the helper columns first: they are only inputs for the replacement and
    # must never show up in the merged result, whichever path is taken below.
    npu_max_dict = compare_index_dict.pop(CompareConst.NPU_MAX, None)
    bench_max_dict = compare_index_dict.pop(CompareConst.BENCH_MAX, None)
    if npu_max_dict is None or bench_max_dict is None:
        return compare_index_dict

    placeholder_values = (CompareConst.N_A, CompareConst.UNSUPPORTED, CompareConst.NAN)
    helper_indexes = (CompareConst.NPU_MAX, CompareConst.BENCH_MAX)
    for compare_index in compare_index_list:
        # Skip the helper indexes by name instead of assuming they are the last two entries.
        if compare_index in helper_indexes:
            continue
        for op_name, index_value in compare_index_dict.get(compare_index, {}).items():
            # Only placeholder values are replaced; the max statistics are not type-checked.
            if index_value.get(rank_num) not in placeholder_values:
                continue
            npu_rank_values = npu_max_dict.get(op_name, {})
            bench_rank_values = bench_max_dict.get(op_name, {})
            if rank_num not in npu_rank_values or rank_num not in bench_rank_values:
                # No max statistics recorded for this item on this rank: keep the placeholder.
                continue
            npu_max = npu_rank_values[rank_num]
            bench_max = bench_rank_values[rank_num]
            index_value[rank_num] = f'NPU:{str(npu_max)} Bench:{str(bench_max)}'

    return compare_index_dict
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def check_config(config):
    """
    Validate the content of config.yaml used to merge compare results.

    Args:
        config: object loaded from config.yaml. It must be a non-empty dict whose
            'api' entry is a non-empty list; 'compare_index', if given, must be a list.

    Returns:
        The same config object. An explicit null 'compare_index' is normalized to [].

    Raises:
        CompareException: MERGE_COMPARE_RESULT_ERROR when any check fails.
    """
    if not config:
        logger.error('config.yaml is empty, please check.')
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
    # A YAML document can also load as a list or a scalar; reject it here instead of
    # failing later with an AttributeError on config.get().
    if not isinstance(config, dict):
        logger.error("The config format of config.yaml is incorrect, please check.")
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)

    api_list = config.get('api')
    if not api_list:
        logger.error('The APIs required to merge data were not found.')
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
    if not isinstance(api_list, list):
        logger.error("The config format of 'api' is incorrect, please check.")
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)

    compare_index_list = config.get('compare_index', [])
    if compare_index_list is None:
        # 'compare_index:' with no value loads as None; treat it as "no extra indexes".
        compare_index_list = []
        config['compare_index'] = compare_index_list
    if not isinstance(compare_index_list, list):
        logger.error("The config format of 'compare_index' is incorrect, please check.")
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)

    return config
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -15,14 +15,17 @@
|
|
|
15
15
|
|
|
16
16
|
import multiprocessing
|
|
17
17
|
from dataclasses import dataclass
|
|
18
|
+
from functools import partial
|
|
19
|
+
|
|
18
20
|
import pandas as pd
|
|
19
21
|
from tqdm import tqdm
|
|
22
|
+
|
|
20
23
|
from msprobe.core.common.log import logger
|
|
21
24
|
from msprobe.core.common.utils import CompareException
|
|
22
25
|
from msprobe.core.common.const import CompareConst
|
|
23
26
|
|
|
24
27
|
|
|
25
|
-
def _handle_multi_process(func,
|
|
28
|
+
def _handle_multi_process(func, input_param, result_df, lock):
|
|
26
29
|
process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
|
|
27
30
|
op_name_mapping_dict = read_dump_data(result_df)
|
|
28
31
|
|
|
@@ -44,7 +47,7 @@ def _handle_multi_process(func, input_parma, result_df, lock):
|
|
|
44
47
|
|
|
45
48
|
progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
|
|
46
49
|
|
|
47
|
-
def update_progress(size, progress_lock):
|
|
50
|
+
def update_progress(size, progress_lock, extra_param=None):
|
|
48
51
|
with progress_lock:
|
|
49
52
|
progress_bar.update(size)
|
|
50
53
|
|
|
@@ -52,10 +55,12 @@ def _handle_multi_process(func, input_parma, result_df, lock):
|
|
|
52
55
|
idx = df_chunk_size * process_idx
|
|
53
56
|
chunk_size = len(df_chunk)
|
|
54
57
|
result = pool.apply_async(func,
|
|
55
|
-
args=(idx, op_name_mapping_dict, df_chunk, lock,
|
|
58
|
+
args=(idx, op_name_mapping_dict, df_chunk, lock, input_param),
|
|
56
59
|
error_callback=err_call,
|
|
57
|
-
callback=update_progress
|
|
60
|
+
callback=partial(update_progress, chunk_size, lock)
|
|
61
|
+
)
|
|
58
62
|
results.append(result)
|
|
63
|
+
|
|
59
64
|
final_results = [r.get() for r in results]
|
|
60
65
|
pool.close()
|
|
61
66
|
pool.join()
|
|
@@ -92,12 +97,12 @@ def _ms_graph_handle_multi_process(func, result_df, mode):
|
|
|
92
97
|
def read_dump_data(result_df):
|
|
93
98
|
try:
|
|
94
99
|
npu_dump_name_list = result_df.iloc[0:, 0].tolist()
|
|
95
|
-
|
|
100
|
+
dump_tensor_pair_list = result_df.iloc[0:, -1].tolist()
|
|
96
101
|
op_name_mapping_dict = {}
|
|
97
102
|
for index, _ in enumerate(npu_dump_name_list):
|
|
98
103
|
npu_dump_name = npu_dump_name_list[index]
|
|
99
|
-
|
|
100
|
-
op_name_mapping_dict[npu_dump_name] =
|
|
104
|
+
dump_tensor_pair = dump_tensor_pair_list[index]
|
|
105
|
+
op_name_mapping_dict[npu_dump_name] = dump_tensor_pair
|
|
101
106
|
return op_name_mapping_dict
|
|
102
107
|
except ValueError as e:
|
|
103
108
|
logger.error('result dataframe is not found.')
|
|
@@ -110,11 +115,12 @@ def read_dump_data(result_df):
|
|
|
110
115
|
@dataclass
class ComparisonResult:
    """Metrics produced for one chunk of compared rows.

    Every field is a list; the lists are parallel and are read by row
    position when results are written back into the result dataframe.
    The field order is part of the positional constructor signature.
    """

    cos_result: list                      # cosine similarity per row
    euc_dist_result: list                 # Euclidean distance per row
    max_err_result: list                  # max absolute error per row
    max_relative_err_result: list         # max relative error per row
    one_thousand_err_ratio_result: list   # one-thousandth error ratio per row
    five_thousand_err_ratio_result: list  # five-thousandths error ratio per row
    err_msgs: list                        # error message text per row
|
|
118
124
|
|
|
119
125
|
|
|
120
126
|
def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
|
|
@@ -135,15 +141,16 @@ def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
|
|
|
135
141
|
for i, _ in enumerate(result.cos_result):
|
|
136
142
|
process_index = i + offset
|
|
137
143
|
result_df.loc[process_index, CompareConst.COSINE] = result.cos_result[i]
|
|
144
|
+
result_df.loc[process_index, CompareConst.EUC_DIST] = result.euc_dist_result[i]
|
|
138
145
|
result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
|
|
139
146
|
result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
|
|
140
|
-
result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
|
|
141
|
-
result_df.loc[process_index, CompareConst.ACCURACY] = (
|
|
142
|
-
check_accuracy(result.cos_result[i], result.max_err_result[i]))
|
|
143
147
|
result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = (
|
|
144
148
|
result.one_thousand_err_ratio_result)[i]
|
|
145
149
|
result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = (
|
|
146
150
|
result.five_thousand_err_ratio_result)[i]
|
|
151
|
+
result_df.loc[process_index, CompareConst.ACCURACY] = (
|
|
152
|
+
check_accuracy(result.cos_result[i], result.max_err_result[i]))
|
|
153
|
+
result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
|
|
147
154
|
return result_df
|
|
148
155
|
except ValueError as e:
|
|
149
156
|
logger.error('result dataframe is not found.')
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -70,7 +70,7 @@ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None):
|
|
|
70
70
|
error_flag = True
|
|
71
71
|
return CompareConst.NONE, CompareConst.NONE, error_flag, err_msg
|
|
72
72
|
if not n_value.shape: # 判断数据是否为0维张量
|
|
73
|
-
err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', "
|
|
73
|
+
err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', '{CompareConst.EUC_DIST}', "
|
|
74
74
|
f"'{CompareConst.ONE_THOUSANDTH_ERR_RATIO}' and '{CompareConst.FIVE_THOUSANDTHS_ERR_RATIO}'. ")
|
|
75
75
|
error_flag = False # 0-d tensor 最大绝对误差、最大相对误差仍然支持计算,因此error_flag设置为False,不做统一处理
|
|
76
76
|
return n_value, b_value, error_flag, err_msg
|
|
@@ -168,8 +168,9 @@ def statistics_data_check(result_dict):
|
|
|
168
168
|
|
|
169
169
|
class TensorComparisonBasic(abc.ABC):
    """Template for a single comparison metric between NPU and bench npy data.

    Concrete metrics override ``apply`` and hand back a ``(result, msg)``
    pair; the caller appends ``msg`` to its running error message.
    """

    @abc.abstractmethod
    def apply(self, n_value, b_value, relative_err, err_msg):
        """Compute this metric for one NPU/bench data pair.

        Args:
            n_value: NPU-side data.
            b_value: bench-side data.
            relative_err: relative error between the two, computed by the caller.
            err_msg: error message accumulated so far for this pair.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError
|
|
174
175
|
|
|
175
176
|
|
|
@@ -190,6 +191,7 @@ def get_relative_err(n_value, b_value):
|
|
|
190
191
|
|
|
191
192
|
class GetCosineSimilarity(TensorComparisonBasic):
|
|
192
193
|
"""计算cosine相似度"""
|
|
194
|
+
|
|
193
195
|
@staticmethod
|
|
194
196
|
def correct_data(result):
|
|
195
197
|
if result == CompareConst.NAN:
|
|
@@ -198,9 +200,9 @@ class GetCosineSimilarity(TensorComparisonBasic):
|
|
|
198
200
|
return round(float(result), 6)
|
|
199
201
|
return result
|
|
200
202
|
|
|
201
|
-
def apply(self, n_value, b_value, relative_err):
|
|
202
|
-
if
|
|
203
|
-
return CompareConst.UNSUPPORTED,
|
|
203
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
204
|
+
if "This is type of 0-d tensor" in err_msg:
|
|
205
|
+
return CompareConst.UNSUPPORTED, err_msg
|
|
204
206
|
|
|
205
207
|
with np.errstate(divide="ignore", invalid="ignore"):
|
|
206
208
|
if len(n_value) == 1:
|
|
@@ -224,9 +226,22 @@ class GetCosineSimilarity(TensorComparisonBasic):
|
|
|
224
226
|
return result, ""
|
|
225
227
|
|
|
226
228
|
|
|
229
|
+
class GetEuclideanDistance(TensorComparisonBasic):
    """Compute the Euclidean (L2) distance between NPU and bench data."""

    def apply(self, n_value, b_value, relative_err, err_msg):
        """
        Args:
            n_value: NPU data (numpy array).
            b_value: bench data; assumed broadcast-compatible with n_value
                (normally the same shape after reshape_value) — confirm at caller.
            relative_err: unused by this metric; kept for the shared signature.
            err_msg: error message accumulated so far by compare_ops_apply.

        Returns:
            (distance, msg): the L2 distance of the element-wise difference and
            an extra message. For 0-d tensors returns (CompareConst.UNSUPPORTED, "");
            when the distance is NaN returns (CompareConst.NAN, explanation).
        """
        if "This is type of 0-d tensor" in err_msg:
            # The 0-d explanation is already part of err_msg, and the caller
            # does `err_msg += msg`; echoing it back would duplicate the text.
            return CompareConst.UNSUPPORTED, ""

        # Flatten explicitly: with ord=2, np.linalg.norm on a 2-D array is the
        # spectral norm, not the element-wise Euclidean distance.
        diff = np.ravel(n_value - b_value)
        distance = np.linalg.norm(diff, ord=2)
        if np.isnan(distance):
            # Match the other metrics, which report NaN via CompareConst.NAN so
            # that downstream result merging recognizes it.
            msg = "Cannot compare by EuclideanDistance, the data contains nan/inf/-inf in dump data."
            return CompareConst.NAN, msg

        return distance, ""
|
|
239
|
+
|
|
240
|
+
|
|
227
241
|
class GetMaxAbsErr(TensorComparisonBasic):
|
|
228
242
|
"""计算最大绝对误差"""
|
|
229
|
-
|
|
243
|
+
|
|
244
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
230
245
|
temp_res = n_value - b_value
|
|
231
246
|
max_value = np.max(np.abs(temp_res))
|
|
232
247
|
if np.isnan(max_value):
|
|
@@ -237,7 +252,8 @@ class GetMaxAbsErr(TensorComparisonBasic):
|
|
|
237
252
|
|
|
238
253
|
class GetMaxRelativeErr(TensorComparisonBasic):
|
|
239
254
|
"""计算最大相对误差"""
|
|
240
|
-
|
|
255
|
+
|
|
256
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
241
257
|
max_relative_err = np.max(np.abs(relative_err))
|
|
242
258
|
if np.isnan(max_relative_err):
|
|
243
259
|
msg = "Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data."
|
|
@@ -247,12 +263,13 @@ class GetMaxRelativeErr(TensorComparisonBasic):
|
|
|
247
263
|
|
|
248
264
|
class GetErrRatio(TensorComparisonBasic):
|
|
249
265
|
"""计算相对误差小于指定阈值(千分之一、千分之五)的比例"""
|
|
266
|
+
|
|
250
267
|
def __init__(self, threshold):
|
|
251
268
|
self.threshold = threshold
|
|
252
269
|
|
|
253
|
-
def apply(self, n_value, b_value, relative_err):
|
|
254
|
-
if
|
|
255
|
-
return CompareConst.UNSUPPORTED,
|
|
270
|
+
def apply(self, n_value, b_value, relative_err, err_msg):
|
|
271
|
+
if "This is type of 0-d tensor" in err_msg:
|
|
272
|
+
return CompareConst.UNSUPPORTED, err_msg
|
|
256
273
|
|
|
257
274
|
if not np.size(relative_err):
|
|
258
275
|
return CompareConst.NAN, ""
|
|
@@ -264,6 +281,7 @@ class GetErrRatio(TensorComparisonBasic):
|
|
|
264
281
|
class CompareOps:
|
|
265
282
|
compare_ops = {
|
|
266
283
|
"cosine_similarity": GetCosineSimilarity(),
|
|
284
|
+
"euclidean_distance": GetEuclideanDistance(),
|
|
267
285
|
"max_abs_error": GetMaxAbsErr(),
|
|
268
286
|
"max_relative_error": GetMaxRelativeErr(),
|
|
269
287
|
"one_thousand_err_ratio": GetErrRatio(CompareConst.THOUSAND_RATIO_THRESHOLD),
|
|
@@ -295,7 +313,7 @@ def compare_ops_apply(n_value, b_value, error_flag, err_msg):
|
|
|
295
313
|
n_value, b_value = reshape_value(n_value, b_value)
|
|
296
314
|
|
|
297
315
|
for op in CompareOps.compare_ops.values():
|
|
298
|
-
result, msg = op.apply(n_value, b_value, relative_err)
|
|
316
|
+
result, msg = op.apply(n_value, b_value, relative_err, err_msg)
|
|
299
317
|
result_list.append(result)
|
|
300
318
|
err_msg += msg
|
|
301
319
|
return result_list, err_msg
|