mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
  3. msprobe/README.md +25 -20
  4. msprobe/core/common/const.py +110 -66
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/utils.py +30 -34
  9. msprobe/core/compare/acc_compare.py +43 -74
  10. msprobe/core/compare/check.py +2 -6
  11. msprobe/core/compare/highlight.py +2 -0
  12. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  13. msprobe/core/compare/merge_result/merge_result.py +8 -2
  14. msprobe/core/compare/multiprocessing_compute.py +19 -12
  15. msprobe/core/compare/npy_compare.py +30 -12
  16. msprobe/core/compare/utils.py +20 -10
  17. msprobe/core/data_dump/api_registry.py +176 -0
  18. msprobe/core/data_dump/data_processor/base.py +2 -2
  19. msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
  20. msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
  21. msprobe/core/data_dump/json_writer.py +38 -35
  22. msprobe/core/grad_probe/constant.py +1 -0
  23. msprobe/core/grad_probe/grad_compare.py +1 -1
  24. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  25. msprobe/docs/01.installation.md +2 -1
  26. msprobe/docs/02.config_introduction.md +17 -15
  27. msprobe/docs/05.data_dump_PyTorch.md +70 -2
  28. msprobe/docs/06.data_dump_MindSpore.md +33 -12
  29. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  30. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  31. msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
  32. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  33. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  34. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  35. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  36. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  37. msprobe/docs/18.online_dispatch.md +1 -1
  38. msprobe/docs/19.monitor.md +124 -62
  39. msprobe/docs/21.visualization_PyTorch.md +32 -13
  40. msprobe/docs/22.visualization_MindSpore.md +32 -13
  41. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  42. msprobe/docs/27.dump_json_instruction.md +278 -8
  43. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  44. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  45. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  46. msprobe/docs/FAQ.md +3 -11
  47. msprobe/docs/img/compare_result.png +0 -0
  48. msprobe/docs/img/merge_result.png +0 -0
  49. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  50. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  51. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  52. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  53. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  54. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  55. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  56. msprobe/mindspore/__init__.py +4 -3
  57. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
  58. msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
  59. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  60. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  61. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  62. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  63. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  64. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  65. msprobe/mindspore/common/const.py +61 -0
  66. msprobe/mindspore/common/utils.py +31 -19
  67. msprobe/mindspore/compare/ms_compare.py +27 -19
  68. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  69. msprobe/mindspore/debugger/debugger_config.py +6 -4
  70. msprobe/mindspore/debugger/precision_debugger.py +22 -10
  71. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  72. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  73. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  74. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  75. msprobe/mindspore/dump/jit_dump.py +14 -9
  76. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  77. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  78. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  79. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  80. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  81. msprobe/mindspore/grad_probe/global_context.py +2 -0
  82. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  83. msprobe/mindspore/grad_probe/hook.py +2 -4
  84. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  85. msprobe/mindspore/monitor/module_hook.py +354 -302
  86. msprobe/mindspore/monitor/utils.py +46 -4
  87. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  88. msprobe/mindspore/service.py +23 -17
  89. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  90. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
  91. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  92. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  93. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  94. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  95. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  96. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  97. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  98. msprobe/pytorch/common/utils.py +29 -7
  99. msprobe/pytorch/debugger/precision_debugger.py +10 -1
  100. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  101. msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
  102. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  103. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  104. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  105. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  106. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  107. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  108. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  109. msprobe/pytorch/function_factory.py +1 -1
  110. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  111. msprobe/pytorch/hook_module/api_register.py +131 -0
  112. msprobe/pytorch/hook_module/hook_module.py +19 -14
  113. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  114. msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
  115. msprobe/pytorch/monitor/csv2tb.py +8 -2
  116. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  117. msprobe/pytorch/monitor/module_hook.py +131 -105
  118. msprobe/pytorch/monitor/module_metric.py +3 -0
  119. msprobe/pytorch/monitor/optimizer_collect.py +55 -4
  120. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  121. msprobe/pytorch/monitor/utils.py +68 -1
  122. msprobe/pytorch/online_dispatch/compare.py +0 -2
  123. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  124. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  125. msprobe/pytorch/online_dispatch/utils.py +3 -0
  126. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  127. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  128. msprobe/pytorch/pt_config.py +11 -7
  129. msprobe/pytorch/service.py +11 -8
  130. msprobe/visualization/builder/graph_builder.py +44 -5
  131. msprobe/visualization/builder/msprobe_adapter.py +0 -1
  132. msprobe/visualization/compare/graph_comparator.py +42 -38
  133. msprobe/visualization/compare/mode_adapter.py +0 -19
  134. msprobe/visualization/graph/base_node.py +8 -1
  135. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  136. msprobe/visualization/graph/graph.py +0 -11
  137. msprobe/visualization/graph/node_op.py +1 -2
  138. msprobe/visualization/graph_service.py +1 -1
  139. msprobe/visualization/utils.py +2 -33
  140. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  141. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  142. msprobe/pytorch/hook_module/api_registry.py +0 -166
  143. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  144. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  145. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  146. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  147. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  148. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  149. msprobe/pytorch/parse.py +0 -19
  150. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  151. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  152. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  153. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -282,6 +282,8 @@ class Comparator:
282
282
  result = []
283
283
  bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all)
284
284
  for ms_op_name, bench_op_name in self.data_mapping_dict.items():
285
+ check_op_str_pattern_valid(ms_op_name)
286
+ check_op_str_pattern_valid(bench_op_name)
285
287
  if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all:
286
288
  npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
287
289
  bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
@@ -311,9 +313,9 @@ class Comparator:
311
313
  ]
312
314
 
313
315
  if self.dump_mode == Const.SUMMARY:
314
- result_item = base_result_item + [" "] * 8
316
+ result_item = base_result_item + [" "] * 8 # 8个统计量数据情况的比对指标
315
317
  else:
316
- result_item = base_result_item + [" "] * 5
318
+ result_item = base_result_item + [" "] * 6 # 6个真实数据情况的比对指标
317
319
 
318
320
  npu_summary_data = npu_ops_all.get(ms_op_name).get("summary")
319
321
  result_item.extend(npu_summary_data)
@@ -329,8 +331,11 @@ class Comparator:
329
331
  else:
330
332
  result_item.append(CompareConst.NONE)
331
333
  if self.dump_mode == Const.ALL:
332
- result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None))
334
+ ms_data_name = npu_ops_all.get(ms_op_name).get("data_name", None)
335
+ pt_data_name = bench_ops_all.get(bench_op_name).get("data_name", None)
336
+ result_item.append([ms_data_name, pt_data_name])
333
337
  result.append(result_item)
338
+ logger.info(f"{ms_op_name}, {bench_op_name} compared.")
334
339
  elif ms_op_name not in npu_ops_all:
335
340
  logger.warning(f'Can not find npu op name : `{ms_op_name}` in npu dump json file.')
336
341
  elif bench_op_name not in npu_ops_all:
@@ -349,47 +354,48 @@ class Comparator:
349
354
  result_df = self.make_result_table(result)
350
355
  return result_df
351
356
 
352
- def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param, bench_data):
357
+ def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
353
358
  """
354
359
  :param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0
355
360
  :param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0
356
361
  :param op_name_mapping_dict: op_name和npy或pt文件的映射关系
357
362
  :param input_param: npu_json_path/bench_json_path/stack_json_path等参数
358
- :param bench_data: bench的dump数据中"data"字段
359
363
  :return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息
360
- 用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、
364
+ 用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、欧式距离
361
365
  最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息
362
366
  """
363
- npu_bench_name_list = op_name_mapping_dict[npu_op_name]
364
- data_name = safe_get_value(npu_bench_name_list, 1, "npu_bench_name_list")
365
367
  error_file, relative_err, error_flag = None, None, False
366
- bench_data_name = get_bench_data_name(bench_op_name, bench_data)
367
- if data_name == '-1' or data_name == -1: # 没有真实数据路径
368
- n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
369
- error_flag = True
370
- elif not bench_data_name:
368
+
369
+ data_name_pair = op_name_mapping_dict.get(npu_op_name)
370
+ npu_data_name = data_name_pair[0]
371
+ bench_data_name = data_name_pair[1]
372
+
373
+ if str(npu_data_name) == '-1': # 没有npu真实数据
374
+ n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
375
+ elif str(bench_data_name) == '-1': # 没有bench真实数据
371
376
  n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
372
377
  error_file = 'no_bench_data'
373
378
  else:
379
+ npu_dir = input_param.get("npu_dump_data_dir")
380
+ bench_dir = input_param.get("bench_dump_data_dir")
374
381
  try:
375
- read_npy_data = getattr(self, "read_npy_data")
376
382
  frame_name = getattr(self, "frame_name")
383
+ read_npy_data = getattr(self, "read_npy_data")
377
384
  if frame_name == "MSComparator":
378
- n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX)
385
+ n_value = read_npy_data(npu_dir, npu_data_name)
379
386
  if self.cross_frame:
380
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name,
381
- load_pt_file=True)
387
+ b_value = read_npy_data(bench_dir, bench_data_name, load_pt_file=True)
382
388
  else:
383
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
389
+ b_value = read_npy_data(bench_dir, bench_data_name)
384
390
  else:
385
- n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX)
386
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
391
+ n_value = read_npy_data(npu_dir, npu_data_name)
392
+ b_value = read_npy_data(bench_dir, bench_data_name)
387
393
  except IOError as error:
388
394
  error_file = error.filename
389
395
  n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
390
396
  error_flag = True
391
397
  except (FileCheckException, CompareException):
392
- error_file = data_name
398
+ error_file = npu_data_name
393
399
  n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
394
400
  error_flag = True
395
401
 
@@ -427,7 +433,9 @@ class Comparator:
427
433
  logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
428
434
  file_name = add_time_with_xlsx("compare_result" + suffix)
429
435
  file_path = os.path.join(os.path.realpath(output_path), file_name)
430
- remove_path(file_path)
436
+ if os.path.exists(file_path):
437
+ logger.warning(f"{file_path} will be deleted.")
438
+ remove_path(file_path)
431
439
  highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
432
440
 
433
441
  npu_json = input_param.get("npu_json_path")
@@ -456,21 +464,23 @@ class Comparator:
456
464
 
457
465
  def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
458
466
  cos_result = []
467
+ euc_dist_result = []
459
468
  max_err_result = []
460
469
  max_relative_err_result = []
461
- err_mess = []
462
470
  one_thousand_err_ratio_result = []
463
471
  five_thousand_err_ratio_result = []
472
+ err_mess = []
473
+
464
474
  is_print_compare_log = input_param.get("is_print_compare_log")
465
- bench_data = load_json(input_param.get("bench_json_path")).get('data')
475
+
466
476
  for i in range(len(result_df)):
467
477
  npu_op_name = result_df.iloc[i, 0]
468
478
  bench_op_name = result_df.iloc[i, 1]
469
479
  if is_print_compare_log:
470
480
  logger.info("start compare: {}".format(npu_op_name))
471
481
 
472
- cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \
473
- self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param, bench_data)
482
+ cos_sim, euc_dist, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg \
483
+ = self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
474
484
 
475
485
  if is_print_compare_log:
476
486
  logger.info(
@@ -479,71 +489,30 @@ class Comparator:
479
489
  "five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err,
480
490
  err_msg, one_thousand_err_ratio, five_thousand_err_ratio))
481
491
  cos_result.append(cos_sim)
492
+ euc_dist_result.append(euc_dist)
482
493
  max_err_result.append(max_abs_err)
483
494
  max_relative_err_result.append(max_relative_err)
484
- err_mess.append(err_msg)
485
495
  one_thousand_err_ratio_result.append(one_thousand_err_ratio)
486
496
  five_thousand_err_ratio_result.append(five_thousand_err_ratio)
497
+ err_mess.append(err_msg)
487
498
 
488
499
  cr = ComparisonResult(
489
500
  cos_result=cos_result,
501
+ euc_dist_result=euc_dist_result,
490
502
  max_err_result=max_err_result,
491
503
  max_relative_err_result=max_relative_err_result,
492
- err_msgs=err_mess,
493
504
  one_thousand_err_ratio_result=one_thousand_err_ratio_result,
494
- five_thousand_err_ratio_result=five_thousand_err_ratio_result
505
+ five_thousand_err_ratio_result=five_thousand_err_ratio_result,
506
+ err_msgs=err_mess
495
507
  )
496
508
 
497
509
  return _save_cmp_result(idx, cr, result_df, lock)
498
510
 
499
- def do_multi_process(self, input_parma, result_df):
511
+ def do_multi_process(self, input_param, result_df):
500
512
  try:
501
- result_df = _handle_multi_process(self.compare_ops, input_parma, result_df,
513
+ result_df = _handle_multi_process(self.compare_ops, input_param, result_df,
502
514
  multiprocessing.Manager().RLock())
503
515
  return result_df
504
516
  except ValueError as e:
505
517
  logger.error('result dataframe is not found.')
506
518
  raise CompareException(CompareException.INVALID_DATA_ERROR) from e
507
-
508
-
509
- def get_bench_data_name(bench_op_name, bench_data):
510
- bench_name_list = re.split(r'\.(input|output|kwargs|parameters|parameters_grad)\.', bench_op_name)
511
- if len(bench_name_list) > 1 and bench_name_list[1] == Const.PARAMS_GRAD:
512
- bench_data_bundle = bench_data.get(bench_name_list[0] + Const.SEP + bench_name_list[1], {})
513
- else:
514
- bench_data_bundle = bench_data.get(bench_name_list[0], {})
515
- if not bench_data_bundle or len(bench_name_list) < 3:
516
- return None
517
- layers = bench_name_list[2].split(Const.SEP)
518
-
519
- def _get(key, container):
520
- if isinstance(container, dict):
521
- return container.get(key)
522
- if isinstance(container, list):
523
- try:
524
- return container[int(key)]
525
- except (ValueError, IndexError):
526
- return None
527
- return None
528
-
529
- def get_by_layer(container, params_grad=False):
530
- data = container
531
- # dump.json中parameters_grad的结构为key:[{}], 如果存在key,有且只有一个列表元素,而op_name中只命名到了key,因此加'0'
532
- if params_grad:
533
- layers.append('0')
534
- for layer in layers:
535
- data = _get(layer, data)
536
- return _get(CompareConst.DATA_NAME.lower(), data)
537
-
538
- if Const.INPUT == bench_name_list[1]:
539
- return get_by_layer(bench_data_bundle.get(Const.INPUT, bench_data_bundle.get(Const.INPUT_ARGS)))
540
- elif Const.KWARGS == bench_name_list[1]:
541
- return get_by_layer(bench_data_bundle.get(Const.INPUT_KWARGS))
542
- elif Const.OUTPUT == bench_name_list[1]:
543
- return get_by_layer(bench_data_bundle.get(Const.OUTPUT))
544
- elif Const.PARAMS == bench_name_list[1]:
545
- return get_by_layer(bench_data_bundle.get(Const.PARAMS))
546
- elif Const.PARAMS_GRAD == bench_name_list[1]:
547
- return get_by_layer(bench_data_bundle, params_grad=True)
548
- else:
549
- return None
@@ -82,12 +82,8 @@ def check_type_shape_match(npu_struct, bench_struct):
82
82
  f'should both be 2, please check!')
83
83
  raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
84
84
  shape_match = npu_shape == bench_shape
85
- type_match = npu_type == bench_type
86
- if not type_match:
87
- if ([npu_type, bench_type] in CompareConst.MS_TYPE) or ([npu_type, bench_type] in CompareConst.TORCH_TYPE):
88
- type_match = True
89
- else:
90
- type_match = False
85
+ type_match = ((npu_type == bench_type) or
86
+ any(npu_type in group and bench_type in group for group in CompareConst.DTYPE_MATCH_GROUPS))
91
87
  struct_match = shape_match and type_match
92
88
  if not struct_match:
93
89
  return False
@@ -146,11 +146,13 @@ class HighlightRules:
146
146
  }
147
147
 
148
148
  # 用于比较输入和输出的规则
149
+ # 真实数据检查规则
149
150
  compare_rules = {
150
151
  "check_order_magnitude": CheckOrderMagnitude(),
151
152
  "check_one_thousand_error": CheckOneThousandErrorRatio(),
152
153
  "check_cosine_similarity": CheckCosineSimilarity()
153
154
  }
155
+ # 统计量数据检查规则
154
156
  summary_compare_rules = {
155
157
  "check_order_magnitude": CheckOrderMagnitude(),
156
158
  "check_max_relative_diff": CheckMaxRelativeDiff(),
@@ -23,7 +23,7 @@ from msprobe.core.common.utils import (add_time_with_yaml,
23
23
  get_stack_construct_by_dump_json_path)
24
24
  from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items
25
25
  from msprobe.core.compare.utils import read_op, reorder_op_name_list
26
-
26
+ from msprobe.core.common.decorator import recursion_depth_decorator
27
27
 
28
28
 
29
29
  class LayerTrie:
@@ -71,6 +71,7 @@ class LayerTrie:
71
71
  file_path = os.path.join(os.path.realpath(output_path), file_name)
72
72
  save_yaml(file_path, result)
73
73
 
74
+ @recursion_depth_decorator("LayerMapping: LayerTrie.convert_to_dict", max_depth=100)
74
75
  def convert_to_dict(self, node):
75
76
  result = {}
76
77
  result["data_item"] = {st: [dt.data_name for dt in dts] for st, dts in node.data_items.items()}
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,7 +21,8 @@ from functools import partial
21
21
  import pandas as pd
22
22
  from tqdm import tqdm
23
23
 
24
- from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory
24
+ from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory, \
25
+ remove_path
25
26
  from msprobe.core.common.const import FileCheckConst, Const, CompareConst
26
27
  from msprobe.core.common.utils import CompareException, add_time_with_xlsx
27
28
  from msprobe.core.compare.utils import table_value_is_valid
@@ -63,6 +64,7 @@ def get_result_path(input_dir):
63
64
  for f in os.listdir(input_dir) if f.endswith(FileCheckConst.XLSX_SUFFIX)]
64
65
  filt_compare_result_path_list = []
65
66
  for file_path in compare_result_path_list:
67
+ FileChecker(file_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
66
68
  file_name = os.path.basename(file_path)
67
69
  if check_compare_result_name(file_name):
68
70
  compare_result_path_checker = FileChecker(file_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE)
@@ -329,6 +331,10 @@ def generate_merge_result(all_compare_index_dict_list, all_rank_num_list, all_co
329
331
  for i, df in enumerate(merge_df_list):
330
332
  # merge_df_list中df与compare_index_list中compare_index一一对应
331
333
  final_result_df_list.append((df, compare_index_list[i]))
334
+
335
+ if os.path.exists(output_path):
336
+ logger.warning(f"{output_path} will be deleted.")
337
+ remove_path(output_path)
332
338
  save_excel(output_path, final_result_df_list)
333
339
  logger.info(f"The compare results of the multi-ranks are merged and saved in: {output_path}.")
334
340
 
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,14 +15,17 @@
15
15
 
16
16
  import multiprocessing
17
17
  from dataclasses import dataclass
18
+ from functools import partial
19
+
18
20
  import pandas as pd
19
21
  from tqdm import tqdm
22
+
20
23
  from msprobe.core.common.log import logger
21
24
  from msprobe.core.common.utils import CompareException
22
25
  from msprobe.core.common.const import CompareConst
23
26
 
24
27
 
25
- def _handle_multi_process(func, input_parma, result_df, lock):
28
+ def _handle_multi_process(func, input_param, result_df, lock):
26
29
  process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
27
30
  op_name_mapping_dict = read_dump_data(result_df)
28
31
 
@@ -44,7 +47,7 @@ def _handle_multi_process(func, input_parma, result_df, lock):
44
47
 
45
48
  progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
46
49
 
47
- def update_progress(size, progress_lock):
50
+ def update_progress(size, progress_lock, extra_param=None):
48
51
  with progress_lock:
49
52
  progress_bar.update(size)
50
53
 
@@ -52,10 +55,12 @@ def _handle_multi_process(func, input_parma, result_df, lock):
52
55
  idx = df_chunk_size * process_idx
53
56
  chunk_size = len(df_chunk)
54
57
  result = pool.apply_async(func,
55
- args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma),
58
+ args=(idx, op_name_mapping_dict, df_chunk, lock, input_param),
56
59
  error_callback=err_call,
57
- callback=update_progress(chunk_size, lock))
60
+ callback=partial(update_progress, chunk_size, lock)
61
+ )
58
62
  results.append(result)
63
+
59
64
  final_results = [r.get() for r in results]
60
65
  pool.close()
61
66
  pool.join()
@@ -92,12 +97,12 @@ def _ms_graph_handle_multi_process(func, result_df, mode):
92
97
  def read_dump_data(result_df):
93
98
  try:
94
99
  npu_dump_name_list = result_df.iloc[0:, 0].tolist()
95
- npu_dump_tensor_list = result_df.iloc[0:, -1].tolist()
100
+ dump_tensor_pair_list = result_df.iloc[0:, -1].tolist()
96
101
  op_name_mapping_dict = {}
97
102
  for index, _ in enumerate(npu_dump_name_list):
98
103
  npu_dump_name = npu_dump_name_list[index]
99
- npu_dump_tensor = npu_dump_tensor_list[index]
100
- op_name_mapping_dict[npu_dump_name] = [npu_dump_tensor, npu_dump_tensor]
104
+ dump_tensor_pair = dump_tensor_pair_list[index]
105
+ op_name_mapping_dict[npu_dump_name] = dump_tensor_pair
101
106
  return op_name_mapping_dict
102
107
  except ValueError as e:
103
108
  logger.error('result dataframe is not found.')
@@ -110,11 +115,12 @@ def read_dump_data(result_df):
110
115
  @dataclass
111
116
  class ComparisonResult:
112
117
  cos_result: list
118
+ euc_dist_result: list
113
119
  max_err_result: list
114
120
  max_relative_err_result: list
115
- err_msgs: list
116
121
  one_thousand_err_ratio_result: list
117
122
  five_thousand_err_ratio_result: list
123
+ err_msgs: list
118
124
 
119
125
 
120
126
  def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
@@ -135,15 +141,16 @@ def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
135
141
  for i, _ in enumerate(result.cos_result):
136
142
  process_index = i + offset
137
143
  result_df.loc[process_index, CompareConst.COSINE] = result.cos_result[i]
144
+ result_df.loc[process_index, CompareConst.EUC_DIST] = result.euc_dist_result[i]
138
145
  result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
139
146
  result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
140
- result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
141
- result_df.loc[process_index, CompareConst.ACCURACY] = (
142
- check_accuracy(result.cos_result[i], result.max_err_result[i]))
143
147
  result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = (
144
148
  result.one_thousand_err_ratio_result)[i]
145
149
  result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = (
146
150
  result.five_thousand_err_ratio_result)[i]
151
+ result_df.loc[process_index, CompareConst.ACCURACY] = (
152
+ check_accuracy(result.cos_result[i], result.max_err_result[i]))
153
+ result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
147
154
  return result_df
148
155
  except ValueError as e:
149
156
  logger.error('result dataframe is not found.')
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -70,7 +70,7 @@ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None):
70
70
  error_flag = True
71
71
  return CompareConst.NONE, CompareConst.NONE, error_flag, err_msg
72
72
  if not n_value.shape: # 判断数据是否为0维张量
73
- err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', "
73
+ err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', '{CompareConst.EUC_DIST}', "
74
74
  f"'{CompareConst.ONE_THOUSANDTH_ERR_RATIO}' and '{CompareConst.FIVE_THOUSANDTHS_ERR_RATIO}'. ")
75
75
  error_flag = False # 0-d tensor 最大绝对误差、最大相对误差仍然支持计算,因此error_flag设置为False,不做统一处理
76
76
  return n_value, b_value, error_flag, err_msg
@@ -168,8 +168,9 @@ def statistics_data_check(result_dict):
168
168
 
169
169
  class TensorComparisonBasic(abc.ABC):
170
170
  """NPU和bench中npy数据的比较模板"""
171
+
171
172
  @abc.abstractmethod
172
- def apply(self, n_value, b_value, relative_err):
173
+ def apply(self, n_value, b_value, relative_err, err_msg):
173
174
  raise NotImplementedError
174
175
 
175
176
 
@@ -190,6 +191,7 @@ def get_relative_err(n_value, b_value):
190
191
 
191
192
  class GetCosineSimilarity(TensorComparisonBasic):
192
193
  """计算cosine相似度"""
194
+
193
195
  @staticmethod
194
196
  def correct_data(result):
195
197
  if result == CompareConst.NAN:
@@ -198,9 +200,9 @@ class GetCosineSimilarity(TensorComparisonBasic):
198
200
  return round(float(result), 6)
199
201
  return result
200
202
 
201
- def apply(self, n_value, b_value, relative_err):
202
- if not n_value.shape:
203
- return CompareConst.UNSUPPORTED, ""
203
+ def apply(self, n_value, b_value, relative_err, err_msg):
204
+ if "This is type of 0-d tensor" in err_msg:
205
+ return CompareConst.UNSUPPORTED, err_msg
204
206
 
205
207
  with np.errstate(divide="ignore", invalid="ignore"):
206
208
  if len(n_value) == 1:
@@ -224,9 +226,22 @@ class GetCosineSimilarity(TensorComparisonBasic):
224
226
  return result, ""
225
227
 
226
228
 
229
+ class GetEuclideanDistance(TensorComparisonBasic):
230
+ """计算欧式距离"""
231
+
232
+ def apply(self, n_value, b_value, relative_err, err_msg):
233
+ if "This is type of 0-d tensor" in err_msg:
234
+ return CompareConst.UNSUPPORTED, err_msg
235
+
236
+ distance = np.linalg.norm(n_value - b_value, ord=2)
237
+
238
+ return distance, ""
239
+
240
+
227
241
  class GetMaxAbsErr(TensorComparisonBasic):
228
242
  """计算最大绝对误差"""
229
- def apply(self, n_value, b_value, relative_err):
243
+
244
+ def apply(self, n_value, b_value, relative_err, err_msg):
230
245
  temp_res = n_value - b_value
231
246
  max_value = np.max(np.abs(temp_res))
232
247
  if np.isnan(max_value):
@@ -237,7 +252,8 @@ class GetMaxAbsErr(TensorComparisonBasic):
237
252
 
238
253
  class GetMaxRelativeErr(TensorComparisonBasic):
239
254
  """计算最大相对误差"""
240
- def apply(self, n_value, b_value, relative_err):
255
+
256
+ def apply(self, n_value, b_value, relative_err, err_msg):
241
257
  max_relative_err = np.max(np.abs(relative_err))
242
258
  if np.isnan(max_relative_err):
243
259
  msg = "Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data."
@@ -247,12 +263,13 @@ class GetMaxRelativeErr(TensorComparisonBasic):
247
263
 
248
264
  class GetErrRatio(TensorComparisonBasic):
249
265
  """计算相对误差小于指定阈值(千分之一、千分之五)的比例"""
266
+
250
267
  def __init__(self, threshold):
251
268
  self.threshold = threshold
252
269
 
253
- def apply(self, n_value, b_value, relative_err):
254
- if not n_value.shape:
255
- return CompareConst.UNSUPPORTED, ""
270
+ def apply(self, n_value, b_value, relative_err, err_msg):
271
+ if "This is type of 0-d tensor" in err_msg:
272
+ return CompareConst.UNSUPPORTED, err_msg
256
273
 
257
274
  if not np.size(relative_err):
258
275
  return CompareConst.NAN, ""
@@ -264,6 +281,7 @@ class GetErrRatio(TensorComparisonBasic):
264
281
  class CompareOps:
265
282
  compare_ops = {
266
283
  "cosine_similarity": GetCosineSimilarity(),
284
+ "euclidean_distance": GetEuclideanDistance(),
267
285
  "max_abs_error": GetMaxAbsErr(),
268
286
  "max_relative_error": GetMaxRelativeErr(),
269
287
  "one_thousand_err_ratio": GetErrRatio(CompareConst.THOUSAND_RATIO_THRESHOLD),
@@ -295,7 +313,7 @@ def compare_ops_apply(n_value, b_value, error_flag, err_msg):
295
313
  n_value, b_value = reshape_value(n_value, b_value)
296
314
 
297
315
  for op in CompareOps.compare_ops.values():
298
- result, msg = op.apply(n_value, b_value, relative_err)
316
+ result, msg = op.apply(n_value, b_value, relative_err, err_msg)
299
317
  result_list.append(result)
300
318
  err_msg += msg
301
319
  return result_list, err_msg
@@ -285,9 +285,9 @@ def result_item_init(n_info, b_info, dump_mode):
285
285
  md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF
286
286
  result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result])
287
287
  elif dump_mode == Const.SUMMARY:
288
- result_item.extend([" "] * 8)
288
+ result_item.extend([" "] * 8) # 8个统计量数据情况的比对指标
289
289
  else:
290
- result_item.extend([" "] * 5)
290
+ result_item.extend([" "] * 6) # 6个真实数据情况的比对指标
291
291
  else:
292
292
  err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \
293
293
  f"npu_info_struct is {n_info.struct}\n" \
@@ -321,8 +321,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
321
321
  has_stack = npu_stack_info and bench_stack_info
322
322
 
323
323
  if dump_mode == Const.ALL:
324
- npu_data_name = n_dict.get("data_name", None)
325
- bench_data_name = b_dict.get("data_name", None)
324
+ npu_data_name_list = n_dict.get("data_name", None)
325
+ bench_data_name_list = b_dict.get("data_name", None)
326
326
 
327
327
  for index in range(min_len):
328
328
  n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name")
@@ -353,7 +353,9 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
353
353
  result_item.append(err_msg)
354
354
  result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
355
355
  if dump_mode == Const.ALL:
356
- result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
356
+ npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list")
357
+ bench_data_name = safe_get_value(bench_data_name_list, b_start + index, "bench_data_name_list")
358
+ result_item.append([npu_data_name, bench_data_name])
357
359
 
358
360
  result.append(result_item)
359
361
 
@@ -371,7 +373,7 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
371
373
  continue
372
374
  result_item = [
373
375
  n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
374
- " ", " ", " ", " ", " "
376
+ " ", " ", " ", " ", " ", " "
375
377
  ]
376
378
  summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
377
379
  result_item.extend(summary_data)
@@ -388,7 +390,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
388
390
  result_item.append(err_msg)
389
391
  result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
390
392
  if dump_mode == Const.ALL:
391
- result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
393
+ npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list")
394
+ result_item.append([npu_data_name, "-1"])
392
395
 
393
396
  result.append(result_item)
394
397
 
@@ -453,9 +456,9 @@ def get_un_match_accuracy(result, n_dict, dump_mode):
453
456
  result.append(result_item)
454
457
  continue
455
458
  if dump_mode == Const.SUMMARY:
456
- result_item.extend([CompareConst.N_A] * 8)
459
+ result_item.extend([CompareConst.N_A] * 8) # 8个统计量数据情况的比对指标
457
460
  if dump_mode == Const.ALL:
458
- result_item.extend([CompareConst.N_A] * 5)
461
+ result_item.extend([CompareConst.N_A] * 6) # 6个真实数据情况的比对指标
459
462
 
460
463
  npu_summary_data = safe_get_value(summary_reorder, index, "summary_reorder")
461
464
  bench_summary_data = [CompareConst.N_A] * 4
@@ -467,7 +470,7 @@ def get_un_match_accuracy(result, n_dict, dump_mode):
467
470
  result_item.append(err_msg)
468
471
  append_stack_info(result_item, npu_stack_info, index)
469
472
  if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
470
- result_item.extend(["-1"])
473
+ result_item.extend([["-1", "-1"]])
471
474
  result.append(result_item)
472
475
 
473
476
 
@@ -542,10 +545,17 @@ def get_name_and_state(name):
542
545
 
543
546
  state type: input, output, kwargs, parameters, parameters_grad
544
547
  """
548
+ if not isinstance(name, str):
549
+ logger.error(f'Invalid name: {name}, type should be string, please check.')
550
+ raise CompareException(CompareException.INVALID_API_NAME_ERROR)
551
+
545
552
  if Const.PARAMS_GRAD in name.split(Const.SEP):
546
553
  return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD
547
554
 
548
555
  split = re.split(Const.REGEX_FORWARD_BACKWARD, name)
556
+ if len(split) < 3:
557
+ logger.error(f'Invalid name string: {name}, can not be split by forward/backward, please check.')
558
+ raise CompareException(CompareException.INVALID_API_NAME_ERROR)
549
559
  api = f'{split[0]}.{split[1]}.'
550
560
  state_str = split[2]
551
561
  match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str)