mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -23,48 +23,20 @@ from tqdm import tqdm
23
23
  from msprobe.core.common.log import logger
24
24
  from msprobe.core.common.utils import CompareException
25
25
  from msprobe.core.common.const import CompareConst
26
+ from msprobe.core.common.exceptions import FileCheckException
27
+ from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_flag_and_msg
28
+ from msprobe.core.compare.config import ModeConfig
26
29
 
27
30
 
28
- def _handle_multi_process(func, input_param, result_df, lock):
29
- process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
30
- op_name_mapping_dict = read_dump_data(result_df)
31
-
32
- df_chunk_size = len(result_df) // process_num
33
- if df_chunk_size > 0:
34
- df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
35
- else:
36
- df_chunks = [result_df]
37
-
38
- results = []
39
- pool = multiprocessing.Pool(process_num)
40
-
41
- def err_call(args):
42
- logger.error('multiprocess compare failed! Reason: {}'.format(args))
43
- try:
44
- pool.terminate()
45
- except OSError as e:
46
- logger.error("pool terminate failed")
47
-
48
- progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
49
-
50
- def update_progress(size, progress_lock, extra_param=None):
51
- with progress_lock:
52
- progress_bar.update(size)
53
-
54
- for process_idx, df_chunk in enumerate(df_chunks):
55
- idx = df_chunk_size * process_idx
56
- chunk_size = len(df_chunk)
57
- result = pool.apply_async(func,
58
- args=(idx, op_name_mapping_dict, df_chunk, lock, input_param),
59
- error_callback=err_call,
60
- callback=partial(update_progress, chunk_size, lock)
61
- )
62
- results.append(result)
63
-
64
- final_results = [r.get() for r in results]
65
- pool.close()
66
- pool.join()
67
- return pd.concat(final_results, ignore_index=True)
31
+ @dataclass
32
+ class ComparisonResult:
33
+ cos_result: list
34
+ euc_dist_result: list
35
+ max_err_result: list
36
+ max_relative_err_result: list
37
+ one_thousand_err_ratio_result: list
38
+ five_thousand_err_ratio_result: list
39
+ err_msgs: list
68
40
 
69
41
 
70
42
  def _ms_graph_handle_multi_process(func, result_df, mode):
@@ -81,9 +53,9 @@ def _ms_graph_handle_multi_process(func, result_df, mode):
81
53
  def err_call(args):
82
54
  logger.error('multiprocess compare failed! Reason: {}'.format(args))
83
55
  try:
84
- pool.terminate()
56
+ pool.close()
85
57
  except OSError as e:
86
- logger.error("pool terminate failed")
58
+ logger.error(f'pool terminate failed: {str(e)}')
87
59
 
88
60
  for df_chunk in df_chunks:
89
61
  result = pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call)
@@ -94,74 +66,6 @@ def _ms_graph_handle_multi_process(func, result_df, mode):
94
66
  return pd.concat(final_results, ignore_index=True)
95
67
 
96
68
 
97
- def read_dump_data(result_df):
98
- try:
99
- npu_dump_name_list = result_df.iloc[0:, 0].tolist()
100
- dump_tensor_pair_list = result_df.iloc[0:, -1].tolist()
101
- op_name_mapping_dict = {}
102
- for index, _ in enumerate(npu_dump_name_list):
103
- npu_dump_name = npu_dump_name_list[index]
104
- dump_tensor_pair = dump_tensor_pair_list[index]
105
- op_name_mapping_dict[npu_dump_name] = dump_tensor_pair
106
- return op_name_mapping_dict
107
- except ValueError as e:
108
- logger.error('result dataframe is not found.')
109
- raise CompareException(CompareException.INVALID_DATA_ERROR) from e
110
- except IndexError as e:
111
- logger.error('result dataframe elements can not be access.')
112
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
113
-
114
-
115
- @dataclass
116
- class ComparisonResult:
117
- cos_result: list
118
- euc_dist_result: list
119
- max_err_result: list
120
- max_relative_err_result: list
121
- one_thousand_err_ratio_result: list
122
- five_thousand_err_ratio_result: list
123
- err_msgs: list
124
-
125
-
126
- def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
127
- """
128
- Save comparison results into the result DataFrame with thread safety.
129
- Args:
130
- offset: offset for index
131
- result: data struct of ComparisonResult
132
- result_df: result of DataFrame
133
- lock: thread lock
134
-
135
- Returns:
136
- comparison results in DataFrame
137
- """
138
-
139
- lock.acquire()
140
- try:
141
- for i, _ in enumerate(result.cos_result):
142
- process_index = i + offset
143
- result_df.loc[process_index, CompareConst.COSINE] = result.cos_result[i]
144
- result_df.loc[process_index, CompareConst.EUC_DIST] = result.euc_dist_result[i]
145
- result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
146
- result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
147
- result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = (
148
- result.one_thousand_err_ratio_result)[i]
149
- result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = (
150
- result.five_thousand_err_ratio_result)[i]
151
- result_df.loc[process_index, CompareConst.ACCURACY] = (
152
- check_accuracy(result.cos_result[i], result.max_err_result[i]))
153
- result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
154
- return result_df
155
- except ValueError as e:
156
- logger.error('result dataframe is not found.')
157
- raise CompareException(CompareException.INVALID_DATA_ERROR) from e
158
- except IndexError as e:
159
- logger.error('result dataframe elements can not be access.')
160
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
161
- finally:
162
- lock.release()
163
-
164
-
165
69
  def check_accuracy(cos, max_abs_err):
166
70
  if cos == CompareConst.SHAPE_UNMATCH:
167
71
  return CompareConst.ACCURACY_CHECK_UNMATCH
@@ -179,3 +83,212 @@ def check_accuracy(cos, max_abs_err):
179
83
  if cos < CompareConst.COS_MAX_THRESHOLD or max_abs_err > CompareConst.MAX_ABS_ERR_MAX_THRESHOLD:
180
84
  return CompareConst.ACCURACY_CHECK_NO
181
85
  return CompareConst.ACCURACY_CHECK_YES
86
+
87
+
88
+ class CompareRealData:
89
+ def __init__(self, file_reader, mode_config: ModeConfig, cross_frame):
90
+ self.file_reader = file_reader
91
+ self.mode_config = mode_config
92
+ self.cross_frame = cross_frame
93
+
94
+ @staticmethod
95
+ def read_dump_data(result_df):
96
+ try:
97
+ npu_dump_name_list = result_df.iloc[0:, 0].tolist()
98
+ dump_tensor_pair_list = result_df.iloc[0:, -1].tolist()
99
+ op_name_mapping_dict = {}
100
+ for index, npu_dump_name in enumerate(npu_dump_name_list):
101
+ dump_tensor_pair = dump_tensor_pair_list[index]
102
+ op_name_mapping_dict[npu_dump_name] = dump_tensor_pair
103
+ return op_name_mapping_dict
104
+ except ValueError as e:
105
+ logger.error('result dataframe is not found.')
106
+ raise CompareException(CompareException.INVALID_DATA_ERROR) from e
107
+ except IndexError as e:
108
+ logger.error('result dataframe elements can not be access.')
109
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
110
+
111
+ @staticmethod
112
+ def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
113
+ """
114
+ Save comparison results into the result DataFrame with thread safety.
115
+ Args:
116
+ offset: offset for index
117
+ result: data struct of ComparisonResult
118
+ result_df: result of DataFrame
119
+ lock: thread lock
120
+
121
+ Returns:
122
+ comparison results in DataFrame
123
+ """
124
+
125
+ lock.acquire()
126
+ try:
127
+ for i, cos_item in enumerate(result.cos_result):
128
+ process_index = i + offset
129
+ result_df.loc[process_index, CompareConst.COSINE] = cos_item
130
+ result_df.loc[process_index, CompareConst.EUC_DIST] = result.euc_dist_result[i]
131
+ result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
132
+ result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
133
+ result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = (
134
+ result.one_thousand_err_ratio_result)[i]
135
+ result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = (
136
+ result.five_thousand_err_ratio_result)[i]
137
+ result_df.loc[process_index, CompareConst.ACCURACY] = (
138
+ check_accuracy(result.cos_result[i], result.max_err_result[i]))
139
+ result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
140
+ return result_df
141
+ except ValueError as e:
142
+ logger.error('result dataframe is not found.')
143
+ raise CompareException(CompareException.INVALID_DATA_ERROR) from e
144
+ except IndexError as e:
145
+ logger.error('result dataframe elements can not be access.')
146
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
147
+ finally:
148
+ lock.release()
149
+
150
+ def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
151
+ """
152
+ :param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0
153
+ :param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0
154
+ :param op_name_mapping_dict: op_name和npy或pt文件的映射关系
155
+ :param input_param: npu_json_path/bench_json_path/stack_json_path等参数
156
+ :return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息
157
+ 用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、欧式距离
158
+ 最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息
159
+ """
160
+ error_file, relative_err, error_flag = None, None, False
161
+
162
+ data_name_pair = op_name_mapping_dict.get(npu_op_name)
163
+ npu_data_name = data_name_pair[0]
164
+ bench_data_name = data_name_pair[1]
165
+
166
+ if str(npu_data_name) == CompareConst.NO_REAL_DATA_FLAG: # 没有npu真实数据
167
+ n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
168
+ elif str(bench_data_name) == CompareConst.NO_REAL_DATA_FLAG: # 没有bench真实数据
169
+ n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
170
+ error_file = 'no_bench_data'
171
+ elif str(bench_data_name) == CompareConst.N_A: # bench没匹配
172
+ n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
173
+ error_file = None
174
+ else:
175
+ npu_dir = input_param.get(CompareConst.NPU_DUMP_DATA_DIR)
176
+ bench_dir = input_param.get(CompareConst.BENCH_DUMP_DATA_DIR)
177
+ try:
178
+ n_value, b_value = self.file_reader(npu_dir, npu_data_name, bench_dir, bench_data_name,
179
+ self.cross_frame)
180
+ except IOError as error:
181
+ error_file = error.filename
182
+ n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
183
+ error_flag = True
184
+ except (FileCheckException, CompareException):
185
+ error_file = data_name_pair
186
+ n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
187
+ error_flag = True
188
+
189
+ # 通过n_value, b_value同时得到错误标志和错误信息
190
+ n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value,
191
+ error_flag=error_flag, error_file=error_file)
192
+
193
+ result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg)
194
+
195
+ if self.mode_config.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
196
+ err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
197
+ result_list.append(err_msg)
198
+ return result_list
199
+
200
+ def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
201
+ cos_result = []
202
+ euc_dist_result = []
203
+ max_err_result = []
204
+ max_relative_err_result = []
205
+ one_thousand_err_ratio_result = []
206
+ five_thousand_err_ratio_result = []
207
+ err_mess = []
208
+
209
+ is_print_compare_log = input_param.get("is_print_compare_log")
210
+
211
+ for i in range(len(result_df)):
212
+ npu_op_name = result_df.iloc[i, 0]
213
+ bench_op_name = result_df.iloc[i, 1]
214
+ if is_print_compare_log:
215
+ logger.info("start compare: {}".format(npu_op_name))
216
+
217
+ cos_sim, euc_dist, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg \
218
+ = self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
219
+
220
+ if is_print_compare_log:
221
+ logger.info(
222
+ "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, \
223
+ one_thousand_err_ratio {}, "
224
+ "five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err,
225
+ err_msg, one_thousand_err_ratio, five_thousand_err_ratio))
226
+ cos_result.append(cos_sim)
227
+ euc_dist_result.append(euc_dist)
228
+ max_err_result.append(max_abs_err)
229
+ max_relative_err_result.append(max_relative_err)
230
+ one_thousand_err_ratio_result.append(one_thousand_err_ratio)
231
+ five_thousand_err_ratio_result.append(five_thousand_err_ratio)
232
+ err_mess.append(err_msg)
233
+
234
+ cr = ComparisonResult(
235
+ cos_result=cos_result,
236
+ euc_dist_result=euc_dist_result,
237
+ max_err_result=max_err_result,
238
+ max_relative_err_result=max_relative_err_result,
239
+ one_thousand_err_ratio_result=one_thousand_err_ratio_result,
240
+ five_thousand_err_ratio_result=five_thousand_err_ratio_result,
241
+ err_msgs=err_mess
242
+ )
243
+
244
+ return self._save_cmp_result(idx, cr, result_df, lock)
245
+
246
+ def do_multi_process(self, input_param, result_df):
247
+ try:
248
+ result_df = self._handle_multi_process(self.compare_ops, input_param, result_df,
249
+ multiprocessing.Manager().RLock())
250
+ return result_df
251
+ except ValueError as e:
252
+ logger.error('result dataframe is not found.')
253
+ raise CompareException(CompareException.INVALID_DATA_ERROR) from e
254
+
255
+ def _handle_multi_process(self, func, input_param, result_df, lock):
256
+ process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
257
+ op_name_mapping_dict = self.read_dump_data(result_df)
258
+
259
+ df_chunk_size = len(result_df) // process_num
260
+ if df_chunk_size > 0:
261
+ df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
262
+ else:
263
+ df_chunks = [result_df]
264
+
265
+ results = []
266
+ pool = multiprocessing.Pool(process_num)
267
+
268
+ def err_call(args):
269
+ logger.error('multiprocess compare failed! Reason: {}'.format(args))
270
+ try:
271
+ pool.close()
272
+ except OSError:
273
+ logger.error("pool terminate failed")
274
+
275
+ progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
276
+
277
+ def update_progress(size, progress_lock, extra_param=None):
278
+ with progress_lock:
279
+ progress_bar.update(size)
280
+
281
+ for process_idx, df_chunk in enumerate(df_chunks):
282
+ idx = df_chunk_size * process_idx
283
+ chunk_size = len(df_chunk)
284
+ result = pool.apply_async(func,
285
+ args=(idx, op_name_mapping_dict, df_chunk, lock, input_param),
286
+ error_callback=err_call,
287
+ callback=partial(update_progress, chunk_size, lock)
288
+ )
289
+ results.append(result)
290
+
291
+ final_results = [r.get() for r in results]
292
+ pool.close()
293
+ pool.join()
294
+ return pd.concat(final_results, ignore_index=True)
@@ -59,7 +59,7 @@ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None):
59
59
  if error_file == "no_bench_data":
60
60
  err_msg = "Bench does not have data file."
61
61
  elif error_file:
62
- err_msg = f"Dump file: {error_file} not found."
62
+ err_msg = f"Dump file: {error_file} not found or read failed."
63
63
  else:
64
64
  err_msg = CompareConst.NO_BENCH
65
65
  error_flag = True
@@ -290,10 +290,8 @@ class CompareOps:
290
290
 
291
291
 
292
292
  def error_value_process(n_value):
293
- if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
293
+ if n_value in [CompareConst.READ_NONE, CompareConst.UNREADABLE, CompareConst.NONE]:
294
294
  return CompareConst.UNSUPPORTED, ""
295
- if n_value == CompareConst.NONE:
296
- return 0, ""
297
295
  if n_value == CompareConst.SHAPE_UNMATCH:
298
296
  return CompareConst.SHAPE_UNMATCH, ""
299
297
  if n_value == CompareConst.NAN: