mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,51 +15,28 @@
15
15
 
16
16
  import multiprocessing
17
17
  from dataclasses import dataclass
18
+ from functools import partial
19
+
18
20
  import pandas as pd
19
21
  from tqdm import tqdm
22
+
20
23
  from msprobe.core.common.log import logger
21
24
  from msprobe.core.common.utils import CompareException
22
25
  from msprobe.core.common.const import CompareConst
26
+ from msprobe.core.common.exceptions import FileCheckException
27
+ from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_flag_and_msg
28
+ from msprobe.core.compare.config import ModeConfig
23
29
 
24
30
 
25
- def _handle_multi_process(func, input_parma, result_df, lock):
26
- process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
27
- op_name_mapping_dict = read_dump_data(result_df)
28
-
29
- df_chunk_size = len(result_df) // process_num
30
- if df_chunk_size > 0:
31
- df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
32
- else:
33
- df_chunks = [result_df]
34
-
35
- results = []
36
- pool = multiprocessing.Pool(process_num)
37
-
38
- def err_call(args):
39
- logger.error('multiprocess compare failed! Reason: {}'.format(args))
40
- try:
41
- pool.terminate()
42
- except OSError as e:
43
- logger.error("pool terminate failed")
44
-
45
- progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
46
-
47
- def update_progress(size, progress_lock):
48
- with progress_lock:
49
- progress_bar.update(size)
50
-
51
- for process_idx, df_chunk in enumerate(df_chunks):
52
- idx = df_chunk_size * process_idx
53
- chunk_size = len(df_chunk)
54
- result = pool.apply_async(func,
55
- args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma),
56
- error_callback=err_call,
57
- callback=update_progress(chunk_size, lock))
58
- results.append(result)
59
- final_results = [r.get() for r in results]
60
- pool.close()
61
- pool.join()
62
- return pd.concat(final_results, ignore_index=True)
31
+ @dataclass
32
+ class ComparisonResult:
33
+ cos_result: list
34
+ euc_dist_result: list
35
+ max_err_result: list
36
+ max_relative_err_result: list
37
+ one_thousand_err_ratio_result: list
38
+ five_thousand_err_ratio_result: list
39
+ err_msgs: list
63
40
 
64
41
 
65
42
  def _ms_graph_handle_multi_process(func, result_df, mode):
@@ -76,9 +53,9 @@ def _ms_graph_handle_multi_process(func, result_df, mode):
76
53
  def err_call(args):
77
54
  logger.error('multiprocess compare failed! Reason: {}'.format(args))
78
55
  try:
79
- pool.terminate()
56
+ pool.close()
80
57
  except OSError as e:
81
- logger.error("pool terminate failed")
58
+ logger.error(f'pool terminate failed: {str(e)}')
82
59
 
83
60
  for df_chunk in df_chunks:
84
61
  result = pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call)
@@ -89,72 +66,6 @@ def _ms_graph_handle_multi_process(func, result_df, mode):
89
66
  return pd.concat(final_results, ignore_index=True)
90
67
 
91
68
 
92
- def read_dump_data(result_df):
93
- try:
94
- npu_dump_name_list = result_df.iloc[0:, 0].tolist()
95
- npu_dump_tensor_list = result_df.iloc[0:, -1].tolist()
96
- op_name_mapping_dict = {}
97
- for index, _ in enumerate(npu_dump_name_list):
98
- npu_dump_name = npu_dump_name_list[index]
99
- npu_dump_tensor = npu_dump_tensor_list[index]
100
- op_name_mapping_dict[npu_dump_name] = [npu_dump_tensor, npu_dump_tensor]
101
- return op_name_mapping_dict
102
- except ValueError as e:
103
- logger.error('result dataframe is not found.')
104
- raise CompareException(CompareException.INVALID_DATA_ERROR) from e
105
- except IndexError as e:
106
- logger.error('result dataframe elements can not be access.')
107
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
108
-
109
-
110
- @dataclass
111
- class ComparisonResult:
112
- cos_result: list
113
- max_err_result: list
114
- max_relative_err_result: list
115
- err_msgs: list
116
- one_thousand_err_ratio_result: list
117
- five_thousand_err_ratio_result: list
118
-
119
-
120
- def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
121
- """
122
- Save comparison results into the result DataFrame with thread safety.
123
- Args:
124
- offset: offset for index
125
- result: data struct of ComparisonResult
126
- result_df: result of DataFrame
127
- lock: thread lock
128
-
129
- Returns:
130
- comparison results in DataFrame
131
- """
132
-
133
- lock.acquire()
134
- try:
135
- for i, _ in enumerate(result.cos_result):
136
- process_index = i + offset
137
- result_df.loc[process_index, CompareConst.COSINE] = result.cos_result[i]
138
- result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
139
- result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
140
- result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
141
- result_df.loc[process_index, CompareConst.ACCURACY] = (
142
- check_accuracy(result.cos_result[i], result.max_err_result[i]))
143
- result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = (
144
- result.one_thousand_err_ratio_result)[i]
145
- result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = (
146
- result.five_thousand_err_ratio_result)[i]
147
- return result_df
148
- except ValueError as e:
149
- logger.error('result dataframe is not found.')
150
- raise CompareException(CompareException.INVALID_DATA_ERROR) from e
151
- except IndexError as e:
152
- logger.error('result dataframe elements can not be access.')
153
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
154
- finally:
155
- lock.release()
156
-
157
-
158
69
  def check_accuracy(cos, max_abs_err):
159
70
  if cos == CompareConst.SHAPE_UNMATCH:
160
71
  return CompareConst.ACCURACY_CHECK_UNMATCH
@@ -172,3 +83,212 @@ def check_accuracy(cos, max_abs_err):
172
83
  if cos < CompareConst.COS_MAX_THRESHOLD or max_abs_err > CompareConst.MAX_ABS_ERR_MAX_THRESHOLD:
173
84
  return CompareConst.ACCURACY_CHECK_NO
174
85
  return CompareConst.ACCURACY_CHECK_YES
86
+
87
+
88
+ class CompareRealData:
89
+ def __init__(self, file_reader, mode_config: ModeConfig, cross_frame):
90
+ self.file_reader = file_reader
91
+ self.mode_config = mode_config
92
+ self.cross_frame = cross_frame
93
+
94
+ @staticmethod
95
+ def read_dump_data(result_df):
96
+ try:
97
+ npu_dump_name_list = result_df.iloc[0:, 0].tolist()
98
+ dump_tensor_pair_list = result_df.iloc[0:, -1].tolist()
99
+ op_name_mapping_dict = {}
100
+ for index, npu_dump_name in enumerate(npu_dump_name_list):
101
+ dump_tensor_pair = dump_tensor_pair_list[index]
102
+ op_name_mapping_dict[npu_dump_name] = dump_tensor_pair
103
+ return op_name_mapping_dict
104
+ except ValueError as e:
105
+ logger.error('result dataframe is not found.')
106
+ raise CompareException(CompareException.INVALID_DATA_ERROR) from e
107
+ except IndexError as e:
108
+ logger.error('result dataframe elements can not be access.')
109
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
110
+
111
+ @staticmethod
112
+ def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
113
+ """
114
+ Save comparison results into the result DataFrame with thread safety.
115
+ Args:
116
+ offset: offset for index
117
+ result: data struct of ComparisonResult
118
+ result_df: result of DataFrame
119
+ lock: thread lock
120
+
121
+ Returns:
122
+ comparison results in DataFrame
123
+ """
124
+
125
+ lock.acquire()
126
+ try:
127
+ for i, cos_item in enumerate(result.cos_result):
128
+ process_index = i + offset
129
+ result_df.loc[process_index, CompareConst.COSINE] = cos_item
130
+ result_df.loc[process_index, CompareConst.EUC_DIST] = result.euc_dist_result[i]
131
+ result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
132
+ result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
133
+ result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = (
134
+ result.one_thousand_err_ratio_result)[i]
135
+ result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = (
136
+ result.five_thousand_err_ratio_result)[i]
137
+ result_df.loc[process_index, CompareConst.ACCURACY] = (
138
+ check_accuracy(result.cos_result[i], result.max_err_result[i]))
139
+ result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
140
+ return result_df
141
+ except ValueError as e:
142
+ logger.error('result dataframe is not found.')
143
+ raise CompareException(CompareException.INVALID_DATA_ERROR) from e
144
+ except IndexError as e:
145
+ logger.error('result dataframe elements can not be access.')
146
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
147
+ finally:
148
+ lock.release()
149
+
150
+ def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
151
+ """
152
+ :param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0
153
+ :param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0
154
+ :param op_name_mapping_dict: op_name和npy或pt文件的映射关系
155
+ :param input_param: npu_json_path/bench_json_path/stack_json_path等参数
156
+ :return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息
157
+ 用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、欧式距离
158
+ 最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息
159
+ """
160
+ error_file, relative_err, error_flag = None, None, False
161
+
162
+ data_name_pair = op_name_mapping_dict.get(npu_op_name)
163
+ npu_data_name = data_name_pair[0]
164
+ bench_data_name = data_name_pair[1]
165
+
166
+ if str(npu_data_name) == CompareConst.NO_REAL_DATA_FLAG: # 没有npu真实数据
167
+ n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
168
+ elif str(bench_data_name) == CompareConst.NO_REAL_DATA_FLAG: # 没有bench真实数据
169
+ n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
170
+ error_file = 'no_bench_data'
171
+ elif str(bench_data_name) == CompareConst.N_A: # bench没匹配
172
+ n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
173
+ error_file = None
174
+ else:
175
+ npu_dir = input_param.get(CompareConst.NPU_DUMP_DATA_DIR)
176
+ bench_dir = input_param.get(CompareConst.BENCH_DUMP_DATA_DIR)
177
+ try:
178
+ n_value, b_value = self.file_reader(npu_dir, npu_data_name, bench_dir, bench_data_name,
179
+ self.cross_frame)
180
+ except IOError as error:
181
+ error_file = error.filename
182
+ n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
183
+ error_flag = True
184
+ except (FileCheckException, CompareException):
185
+ error_file = data_name_pair
186
+ n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
187
+ error_flag = True
188
+
189
+ # 通过n_value, b_value同时得到错误标志和错误信息
190
+ n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value,
191
+ error_flag=error_flag, error_file=error_file)
192
+
193
+ result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg)
194
+
195
+ if self.mode_config.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
196
+ err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
197
+ result_list.append(err_msg)
198
+ return result_list
199
+
200
+ def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
201
+ cos_result = []
202
+ euc_dist_result = []
203
+ max_err_result = []
204
+ max_relative_err_result = []
205
+ one_thousand_err_ratio_result = []
206
+ five_thousand_err_ratio_result = []
207
+ err_mess = []
208
+
209
+ is_print_compare_log = input_param.get("is_print_compare_log")
210
+
211
+ for i in range(len(result_df)):
212
+ npu_op_name = result_df.iloc[i, 0]
213
+ bench_op_name = result_df.iloc[i, 1]
214
+ if is_print_compare_log:
215
+ logger.info("start compare: {}".format(npu_op_name))
216
+
217
+ cos_sim, euc_dist, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg \
218
+ = self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
219
+
220
+ if is_print_compare_log:
221
+ logger.info(
222
+ "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, \
223
+ one_thousand_err_ratio {}, "
224
+ "five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err,
225
+ err_msg, one_thousand_err_ratio, five_thousand_err_ratio))
226
+ cos_result.append(cos_sim)
227
+ euc_dist_result.append(euc_dist)
228
+ max_err_result.append(max_abs_err)
229
+ max_relative_err_result.append(max_relative_err)
230
+ one_thousand_err_ratio_result.append(one_thousand_err_ratio)
231
+ five_thousand_err_ratio_result.append(five_thousand_err_ratio)
232
+ err_mess.append(err_msg)
233
+
234
+ cr = ComparisonResult(
235
+ cos_result=cos_result,
236
+ euc_dist_result=euc_dist_result,
237
+ max_err_result=max_err_result,
238
+ max_relative_err_result=max_relative_err_result,
239
+ one_thousand_err_ratio_result=one_thousand_err_ratio_result,
240
+ five_thousand_err_ratio_result=five_thousand_err_ratio_result,
241
+ err_msgs=err_mess
242
+ )
243
+
244
+ return self._save_cmp_result(idx, cr, result_df, lock)
245
+
246
+ def do_multi_process(self, input_param, result_df):
247
+ try:
248
+ result_df = self._handle_multi_process(self.compare_ops, input_param, result_df,
249
+ multiprocessing.Manager().RLock())
250
+ return result_df
251
+ except ValueError as e:
252
+ logger.error('result dataframe is not found.')
253
+ raise CompareException(CompareException.INVALID_DATA_ERROR) from e
254
+
255
+ def _handle_multi_process(self, func, input_param, result_df, lock):
256
+ process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
257
+ op_name_mapping_dict = self.read_dump_data(result_df)
258
+
259
+ df_chunk_size = len(result_df) // process_num
260
+ if df_chunk_size > 0:
261
+ df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
262
+ else:
263
+ df_chunks = [result_df]
264
+
265
+ results = []
266
+ pool = multiprocessing.Pool(process_num)
267
+
268
+ def err_call(args):
269
+ logger.error('multiprocess compare failed! Reason: {}'.format(args))
270
+ try:
271
+ pool.close()
272
+ except OSError:
273
+ logger.error("pool terminate failed")
274
+
275
+ progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
276
+
277
+ def update_progress(size, progress_lock, extra_param=None):
278
+ with progress_lock:
279
+ progress_bar.update(size)
280
+
281
+ for process_idx, df_chunk in enumerate(df_chunks):
282
+ idx = df_chunk_size * process_idx
283
+ chunk_size = len(df_chunk)
284
+ result = pool.apply_async(func,
285
+ args=(idx, op_name_mapping_dict, df_chunk, lock, input_param),
286
+ error_callback=err_call,
287
+ callback=partial(update_progress, chunk_size, lock)
288
+ )
289
+ results.append(result)
290
+
291
+ final_results = [r.get() for r in results]
292
+ pool.close()
293
+ pool.join()
294
+ return pd.concat(final_results, ignore_index=True)
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -59,7 +59,7 @@ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None):
59
59
  if error_file == "no_bench_data":
60
60
  err_msg = "Bench does not have data file."
61
61
  elif error_file:
62
- err_msg = f"Dump file: {error_file} not found."
62
+ err_msg = f"Dump file: {error_file} not found or read failed."
63
63
  else:
64
64
  err_msg = CompareConst.NO_BENCH
65
65
  error_flag = True
@@ -70,7 +70,7 @@ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None):
70
70
  error_flag = True
71
71
  return CompareConst.NONE, CompareConst.NONE, error_flag, err_msg
72
72
  if not n_value.shape: # 判断数据是否为0维张量
73
- err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', "
73
+ err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', '{CompareConst.EUC_DIST}', "
74
74
  f"'{CompareConst.ONE_THOUSANDTH_ERR_RATIO}' and '{CompareConst.FIVE_THOUSANDTHS_ERR_RATIO}'. ")
75
75
  error_flag = False # 0-d tensor 最大绝对误差、最大相对误差仍然支持计算,因此error_flag设置为False,不做统一处理
76
76
  return n_value, b_value, error_flag, err_msg
@@ -168,8 +168,9 @@ def statistics_data_check(result_dict):
168
168
 
169
169
  class TensorComparisonBasic(abc.ABC):
170
170
  """NPU和bench中npy数据的比较模板"""
171
+
171
172
  @abc.abstractmethod
172
- def apply(self, n_value, b_value, relative_err):
173
+ def apply(self, n_value, b_value, relative_err, err_msg):
173
174
  raise NotImplementedError
174
175
 
175
176
 
@@ -190,6 +191,7 @@ def get_relative_err(n_value, b_value):
190
191
 
191
192
  class GetCosineSimilarity(TensorComparisonBasic):
192
193
  """计算cosine相似度"""
194
+
193
195
  @staticmethod
194
196
  def correct_data(result):
195
197
  if result == CompareConst.NAN:
@@ -198,9 +200,9 @@ class GetCosineSimilarity(TensorComparisonBasic):
198
200
  return round(float(result), 6)
199
201
  return result
200
202
 
201
- def apply(self, n_value, b_value, relative_err):
202
- if not n_value.shape:
203
- return CompareConst.UNSUPPORTED, ""
203
+ def apply(self, n_value, b_value, relative_err, err_msg):
204
+ if "This is type of 0-d tensor" in err_msg:
205
+ return CompareConst.UNSUPPORTED, err_msg
204
206
 
205
207
  with np.errstate(divide="ignore", invalid="ignore"):
206
208
  if len(n_value) == 1:
@@ -224,9 +226,22 @@ class GetCosineSimilarity(TensorComparisonBasic):
224
226
  return result, ""
225
227
 
226
228
 
229
class GetEuclideanDistance(TensorComparisonBasic):
    """Compute the Euclidean (L2) distance between NPU and bench data."""

    def apply(self, n_value, b_value, relative_err, err_msg):
        """Return ``(distance, msg)`` for the two arrays.

        Args:
            n_value: NPU-side array.
            b_value: Bench-side array, subtractable from ``n_value``.
            relative_err: Unused here; kept for the common ``apply`` signature.
            err_msg: Message accumulated so far; if it carries the 0-d tensor
                marker the metric is reported as unsupported.

        Returns:
            ``(CompareConst.UNSUPPORTED, err_msg)`` for 0-d tensors, otherwise
            ``(distance, "")``.
        """
        if "This is type of 0-d tensor" in err_msg:
            return CompareConst.UNSUPPORTED, err_msg

        # Flatten before taking the norm: with ord=2, np.linalg.norm returns the
        # spectral norm for 2-D input and raises for 0-d or >2-D input. On a
        # flattened vector it is the Euclidean distance for any shape, and the
        # result is unchanged for inputs that were already 1-D.
        distance = np.linalg.norm(np.ravel(n_value - b_value), ord=2)

        return distance, ""
239
+
240
+
227
241
  class GetMaxAbsErr(TensorComparisonBasic):
228
242
  """计算最大绝对误差"""
229
- def apply(self, n_value, b_value, relative_err):
243
+
244
+ def apply(self, n_value, b_value, relative_err, err_msg):
230
245
  temp_res = n_value - b_value
231
246
  max_value = np.max(np.abs(temp_res))
232
247
  if np.isnan(max_value):
@@ -237,7 +252,8 @@ class GetMaxAbsErr(TensorComparisonBasic):
237
252
 
238
253
  class GetMaxRelativeErr(TensorComparisonBasic):
239
254
  """计算最大相对误差"""
240
- def apply(self, n_value, b_value, relative_err):
255
+
256
+ def apply(self, n_value, b_value, relative_err, err_msg):
241
257
  max_relative_err = np.max(np.abs(relative_err))
242
258
  if np.isnan(max_relative_err):
243
259
  msg = "Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data."
@@ -247,12 +263,13 @@ class GetMaxRelativeErr(TensorComparisonBasic):
247
263
 
248
264
  class GetErrRatio(TensorComparisonBasic):
249
265
  """计算相对误差小于指定阈值(千分之一、千分之五)的比例"""
266
+
250
267
  def __init__(self, threshold):
251
268
  self.threshold = threshold
252
269
 
253
- def apply(self, n_value, b_value, relative_err):
254
- if not n_value.shape:
255
- return CompareConst.UNSUPPORTED, ""
270
+ def apply(self, n_value, b_value, relative_err, err_msg):
271
+ if "This is type of 0-d tensor" in err_msg:
272
+ return CompareConst.UNSUPPORTED, err_msg
256
273
 
257
274
  if not np.size(relative_err):
258
275
  return CompareConst.NAN, ""
@@ -264,6 +281,7 @@ class GetErrRatio(TensorComparisonBasic):
264
281
  class CompareOps:
265
282
  compare_ops = {
266
283
  "cosine_similarity": GetCosineSimilarity(),
284
+ "euclidean_distance": GetEuclideanDistance(),
267
285
  "max_abs_error": GetMaxAbsErr(),
268
286
  "max_relative_error": GetMaxRelativeErr(),
269
287
  "one_thousand_err_ratio": GetErrRatio(CompareConst.THOUSAND_RATIO_THRESHOLD),
@@ -272,10 +290,8 @@ class CompareOps:
272
290
 
273
291
 
274
292
  def error_value_process(n_value):
275
- if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
293
+ if n_value in [CompareConst.READ_NONE, CompareConst.UNREADABLE, CompareConst.NONE]:
276
294
  return CompareConst.UNSUPPORTED, ""
277
- if n_value == CompareConst.NONE:
278
- return 0, ""
279
295
  if n_value == CompareConst.SHAPE_UNMATCH:
280
296
  return CompareConst.SHAPE_UNMATCH, ""
281
297
  if n_value == CompareConst.NAN:
@@ -295,7 +311,7 @@ def compare_ops_apply(n_value, b_value, error_flag, err_msg):
295
311
  n_value, b_value = reshape_value(n_value, b_value)
296
312
 
297
313
  for op in CompareOps.compare_ops.values():
298
- result, msg = op.apply(n_value, b_value, relative_err)
314
+ result, msg = op.apply(n_value, b_value, relative_err, err_msg)
299
315
  result_list.append(result)
300
316
  err_msg += msg
301
317
  return result_list, err_msg