mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -30,12 +30,7 @@ from msprobe.core.common.file_utils import save_workbook
30
30
  from msprobe.core.common.log import logger
31
31
  from msprobe.core.common.utils import get_header_index, safe_get_value
32
32
  from msprobe.core.compare.utils import table_value_is_valid, get_name_and_state, CompareException
33
-
34
-
35
- class HighlightCheck(abc.ABC):
36
- @abc.abstractmethod
37
- def apply(self, info, color_columns, dump_mode):
38
- raise NotImplementedError
33
+ from msprobe.core.compare.config import ModeConfig
39
34
 
40
35
 
41
36
  def add_highlight_row_info(color_list, num, highlight_err_msg):
@@ -46,6 +41,12 @@ def add_highlight_row_info(color_list, num, highlight_err_msg):
46
41
  color_list.append((num, [highlight_err_msg]))
47
42
 
48
43
 
44
+ class HighlightCheck(abc.ABC):
45
+ @abc.abstractmethod
46
+ def apply(self, info, color_columns, dump_mode):
47
+ raise NotImplementedError
48
+
49
+
49
50
  class CheckOrderMagnitude(HighlightCheck):
50
51
  """检查Max diff的数量级差异"""
51
52
 
@@ -75,12 +76,12 @@ class CheckOneThousandErrorRatio(HighlightCheck):
75
76
  if (api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and
76
77
  api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED):
77
78
  add_highlight_row_info(color_columns.red, num,
78
- "The input/parameters's one thousandth err ratio exceeds 0.9, "
79
+ "The input/parameter's one thousandth err ratio exceeds 0.9, "
79
80
  "while the output's is below 0.6")
80
81
  elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
81
82
  add_highlight_row_info(color_columns.yellow, num,
82
83
  "The output's one thousandth err ratio decreases by more than 0.1 "
83
- "compared to the input/parameters's")
84
+ "compared to the input/parameter's")
84
85
 
85
86
 
86
87
  class CheckCosineSimilarity(HighlightCheck):
@@ -94,7 +95,7 @@ class CheckCosineSimilarity(HighlightCheck):
94
95
  if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW:
95
96
  add_highlight_row_info(color_columns.yellow, num,
96
97
  "The output's cosine decreases by more than 0.1 "
97
- "compared to the input/parameters's")
98
+ "compared to the input/parameter's")
98
99
 
99
100
 
100
101
  class CheckMaxRelativeDiff(HighlightCheck):
@@ -117,7 +118,7 @@ class CheckMaxRelativeDiff(HighlightCheck):
117
118
  input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
118
119
  add_highlight_row_info(color_columns.yellow, num,
119
120
  "The output's maximum relative error exceeds 0.1, "
120
- "while the input/parameters's is below 0.01")
121
+ "while the input/parameter's is below 0.01")
121
122
 
122
123
 
123
124
  class CheckOverflow(HighlightCheck):
@@ -146,84 +147,19 @@ class HighlightRules:
146
147
  }
147
148
 
148
149
  # 用于比较输入和输出的规则
150
+ # 真实数据检查规则
149
151
  compare_rules = {
150
152
  "check_order_magnitude": CheckOrderMagnitude(),
151
153
  "check_one_thousand_error": CheckOneThousandErrorRatio(),
152
154
  "check_cosine_similarity": CheckCosineSimilarity()
153
155
  }
156
+ # 统计量数据检查规则
154
157
  summary_compare_rules = {
155
158
  "check_order_magnitude": CheckOrderMagnitude(),
156
159
  "check_max_relative_diff": CheckMaxRelativeDiff(),
157
160
  }
158
161
 
159
162
 
160
- def check_indices_numeric(api_items, indices: list):
161
- """检查指定索引处的值是否都为数字类型(int 或 float)"""
162
- return all(isinstance(api_items[i], (float, int)) for i in indices)
163
-
164
-
165
- def apply_comparison_rules(api_info, dump_mode, color_columns):
166
- """output与input/params的比较"""
167
- if dump_mode == Const.SUMMARY:
168
- for rule in HighlightRules.summary_compare_rules.values():
169
- rule.apply(api_info, color_columns, dump_mode)
170
- else:
171
- for rule in HighlightRules.compare_rules.values():
172
- rule.apply(api_info, color_columns, dump_mode)
173
-
174
-
175
- def find_error_rows(result, api_batch, highlight_dict, dump_mode):
176
- """找到单个API中需要高亮的行"""
177
- if dump_mode == Const.MD5:
178
- return
179
- npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
180
- bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
181
- max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
182
- else CompareConst.MAX_ABS_ERR, dump_mode)
183
-
184
- red_lines, yellow_lines = [], []
185
- LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
186
- ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer'])
187
- ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
188
- color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)
189
-
190
- api_batch_start = api_batch.start # result_df的input起始全局索引
191
- api_batch_params_end_index = api_batch.params_end_index # result_df的params结束全局索引 + 1
192
- api_batch_output_end_index = api_batch.output_end_index # result_df的output结束全局索引 + 1
193
- api_batch_params_slice_index_local = api_batch_params_end_index - api_batch_start # result的params结束局部切片索引
194
- api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start # result的output结束局部切片索引
195
-
196
- # 对单行API的输入或输出进行误差判断
197
- for i, line in enumerate(result):
198
- index = api_batch_start + i
199
- line_info = LineInfo(line_data=line, num_pointer=index)
200
- for rule in HighlightRules.basic_rules.values():
201
- rule.apply(line_info, color_columns, dump_mode)
202
-
203
- # 对API的输出与输入比较,进行误差判断
204
- for n, api_out in enumerate(result[api_batch_params_slice_index_local: api_batch_output_slice_index_local]):
205
- index = api_batch_start + api_batch_params_slice_index_local + n
206
- # 单行检查只有溢出检查(红色),如果已经溢出,不进一步检查
207
- if index in red_lines:
208
- continue
209
- if not check_indices_numeric(api_out, [npu_max_index, bench_max_index, max_diff_index]):
210
- continue
211
-
212
- # input/parameters的比较检查, 这里api_in包括input、parameters
213
- for _, api_in in enumerate(result[0: api_batch_params_slice_index_local]):
214
- if not check_indices_numeric(api_in, [npu_max_index, bench_max_index, max_diff_index]):
215
- continue
216
- api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index)
217
- apply_comparison_rules(api_info, dump_mode, color_columns)
218
-
219
- red_lines_num_set = {x[0] for x in red_lines}
220
- yellow_lines_num_set = {x[0] for x in yellow_lines}
221
- highlight_dict.get('red_rows', set()).update(red_lines_num_set)
222
- highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set)
223
- highlight_dict.get('red_lines', []).extend(red_lines)
224
- highlight_dict.get('yellow_lines', []).extend(yellow_lines)
225
-
226
-
227
163
  class ApiBatch:
228
164
  def __init__(self, api_name: str, start: int):
229
165
  self.api_name = api_name
@@ -257,159 +193,225 @@ class ApiBatch:
257
193
  self.params_grad_end_index += 1
258
194
 
259
195
 
260
- def api_batches_update(api_batches, api_name, state, index):
261
- """
262
- 当一个api的所有item更新完后,input, output的索引范围:
263
- input: [start: start+input_len]
264
- output: [start+input_len: output_end_index]
265
- params: [output_end_index: params_end_index]
266
- """
267
- if not api_batches:
268
- api_batches.append(ApiBatch(api_name, index))
269
- else:
270
- api_batch = api_batches[-1]
271
- if api_batch.api_name == api_name or (
272
- not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name):
273
- try:
274
- api_batch.increment(state)
275
- except ValueError as e:
276
- logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}")
277
- raise CompareException(CompareException.INVALID_STATE_ERROR) from e
278
- else:
279
- api_batches.append(ApiBatch(api_name, index))
196
+ class HighLight:
197
+ def __init__(self, mode_config: ModeConfig):
198
+ self.mode_config = mode_config
280
199
 
281
-
282
- def find_compare_result_error_rows(result_df, highlight_dict, dump_mode):
283
- """将dataframe根据API分组,并找到有误差的算子用于高亮"""
284
- result = result_df.values
285
- api_batches = []
286
- for i, res_i in enumerate(result):
287
- api_full_name = safe_get_value(res_i, 0, "res_i")
288
- api_name, state = get_name_and_state(api_full_name)
289
- api_batches_update(api_batches, api_name, state, i)
290
- with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar:
291
- for api_batch in api_batches:
292
- find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch, highlight_dict,
293
- dump_mode)
294
- progress_bar.update(1)
295
-
296
-
297
- def value_check(value, api_name=None, i=None, result_df_columns=None):
298
- if not table_value_is_valid(value):
299
- if result_df_columns:
300
- logger.error(f"Malicious value [{value}] at api_name [{api_name}], column [{result_df_columns[i]}], "
301
- f"is not allowed to be written into the compare result xlsx.")
200
+ @staticmethod
201
+ def api_batches_update(api_batches, api_name, state, index):
202
+ """
203
+ 当一个api的所有item更新完后,input, output的索引范围:
204
+ input: [start: start+input_len]
205
+ output: [start+input_len: output_end_index]
206
+ params: [output_end_index: params_end_index]
207
+ """
208
+ if not api_batches:
209
+ api_batches.append(ApiBatch(api_name, index))
302
210
  else:
303
- logger.error(f"Malicious value [{value}] is not allowed to be written into the compare result xlsx.")
304
-
305
-
306
- def df_malicious_value_check(df_chunk, result_df_columns):
307
- for row in df_chunk.itertuples(index=False):
308
- api_name = row[0]
309
- for i, value in enumerate(row):
310
- value_check(value, api_name, i, result_df_columns)
311
-
312
-
313
- def handle_multi_process_malicious_value_check(func, result_df):
314
- result_total_nums = len(result_df)
315
- process_num = int((multiprocessing.cpu_count() + 1) / 2)
316
-
317
- if result_total_nums <= process_num:
318
- process_num = 1
319
- chunks = [result_df]
320
- else:
321
- chunk_size = result_total_nums // process_num
322
- chunks = [result_df.iloc[i: i + chunk_size] for i in range(0, result_total_nums, chunk_size)]
323
-
324
- pool = multiprocessing.Pool(process_num)
325
-
326
- def err_call(args):
327
- logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args))
328
- try:
329
- pool.terminate()
330
- except OSError:
331
- logger.error("Pool terminate failed")
332
-
333
- result_df_columns = result_df.columns.tolist()
334
- for column in result_df_columns:
335
- value_check(column)
336
- for df_chunk in chunks:
337
- pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
338
-
339
- pool.close()
340
- pool.join()
341
-
342
-
343
- def compare_result_df_convert(value):
344
- if not isinstance(value, (float, int)) or isinstance(value, bool): # bool类型或者非数字类型转str
345
- value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else str(value)
346
- if isinstance(value, float):
347
- value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else value
348
- return value
349
-
350
-
351
- def highlight_rows_xlsx(result_df, highlight_dict, file_path):
352
- """Write and highlight results in Excel"""
211
+ api_batch = api_batches[-1]
212
+ if api_batch.api_name == api_name or (
213
+ not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name):
214
+ try:
215
+ api_batch.increment(state)
216
+ except ValueError as e:
217
+ logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}")
218
+ raise CompareException(CompareException.INVALID_STATE_ERROR) from e
219
+ else:
220
+ api_batches.append(ApiBatch(api_name, index))
221
+
222
+ @staticmethod
223
+ def check_indices_numeric(api_items, indices: list):
224
+ """检查指定索引处的值是否都为数字类型(int 或 float)"""
225
+ return all(isinstance(api_items[i], (float, int)) for i in indices)
226
+
227
+ @staticmethod
228
+ def update_highlight_err_msg(result_df, highlight_dict):
229
+ if result_df.shape[1] <= 1:
230
+ return
353
231
 
354
- update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg
232
+ if CompareConst.NPU_MD5 in result_df.columns:
233
+ return
355
234
 
356
- wb = openpyxl.Workbook()
357
- ws = wb.active
235
+ err_msg = result_df.get(CompareConst.ERROR_MESSAGE)
236
+ red_lines_num_set = highlight_dict.get('red_rows')
237
+
238
+ for color in ['red', 'yellow']:
239
+ line_key = f'{color}_lines'
240
+ lines = highlight_dict.get(line_key, [])
241
+ for line_index, messages in lines:
242
+ if color == 'yellow' and line_index in red_lines_num_set:
243
+ continue # 如果是 yellow 行,且已被 red 行覆盖,跳过
244
+
245
+ for msg in messages:
246
+ if err_msg[line_index] == '':
247
+ err_msg[line_index] = msg
248
+ else:
249
+ err_msg[line_index] += '\n' + msg
250
+
251
+ if color == 'red':
252
+ red_lines_num_set.add(line_index)
253
+
254
+ result_df[CompareConst.ERROR_MESSAGE] = err_msg
255
+
256
+ @staticmethod
257
+ def compare_result_df_convert(value):
258
+ if not isinstance(value, (float, int)) or isinstance(value, bool): # bool类型或者非数字类型转str
259
+ value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else str(value)
260
+ if isinstance(value, float):
261
+ value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else value
262
+ return value
263
+
264
+ @staticmethod
265
+ def value_check(value, api_name=None, i=None, result_df_columns=None):
266
+ if not table_value_is_valid(value):
267
+ if result_df_columns:
268
+ logger.error(f"Malicious value [{value}] at api_name [{api_name}], column [{result_df_columns[i]}], "
269
+ f"is not allowed to be written into the compare result xlsx.")
270
+ else:
271
+ logger.error(f"Malicious value [{value}] is not allowed to be written into the compare result xlsx.")
272
+
273
+ def find_compare_result_error_rows(self, result_df, highlight_dict):
274
+ """将dataframe根据API分组,并找到有误差的算子用于高亮"""
275
+ result = result_df.values
276
+ api_batches = []
277
+ for i, res_i in enumerate(result):
278
+ api_full_name = safe_get_value(res_i, 0, "res_i")
279
+ api_name, state = get_name_and_state(api_full_name)
280
+ self.api_batches_update(api_batches, api_name, state, i)
281
+ with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar:
282
+ for api_batch in api_batches:
283
+ self.find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch,
284
+ highlight_dict)
285
+ progress_bar.update(1)
286
+
287
+ def find_error_rows(self, result, api_batch, highlight_dict):
288
+ """找到单个API中需要高亮的行"""
289
+ if self.mode_config.dump_mode == Const.MD5:
290
+ return
291
+ npu_max_index = get_header_index(CompareConst.NPU_MAX, self.mode_config.dump_mode)
292
+ bench_max_index = get_header_index(CompareConst.BENCH_MAX, self.mode_config.dump_mode)
293
+ max_diff_index = get_header_index(CompareConst.MAX_DIFF if self.mode_config.dump_mode == Const.SUMMARY
294
+ else CompareConst.MAX_ABS_ERR, self.mode_config.dump_mode)
295
+
296
+ red_lines, yellow_lines = [], []
297
+ LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
298
+ ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer'])
299
+ ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
300
+ color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)
301
+
302
+ api_batch_start = api_batch.start # result_df的input起始全局索引
303
+ api_batch_params_end_index = api_batch.params_end_index # result_df的params结束全局索引 + 1
304
+ api_batch_output_end_index = api_batch.output_end_index # result_df的output结束全局索引 + 1
305
+ api_batch_params_slice_index_local = api_batch_params_end_index - api_batch_start # result的params结束局部切片索引
306
+ api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start # result的output结束局部切片索引
307
+
308
+ # 对单行API的输入或输出进行误差判断
309
+ for i, line in enumerate(result):
310
+ index = api_batch_start + i
311
+ line_info = LineInfo(line_data=line, num_pointer=index)
312
+ for rule in HighlightRules.basic_rules.values():
313
+ rule.apply(line_info, color_columns, self.mode_config.dump_mode)
314
+
315
+ # 对API的输出与输入比较,进行误差判断
316
+ for n, api_out in enumerate(result[api_batch_params_slice_index_local: api_batch_output_slice_index_local]):
317
+ index = api_batch_start + api_batch_params_slice_index_local + n
318
+ # 单行检查只有溢出检查(红色),如果已经溢出,不进一步检查
319
+ if index in red_lines:
320
+ continue
321
+ if not self.check_indices_numeric(api_out, [npu_max_index, bench_max_index, max_diff_index]):
322
+ continue
358
323
 
359
- # write header
360
- logger.info('Initializing Excel file.')
324
+ # input/parameters的比较检查, 这里api_in包括input、parameters
325
+ for api_in in result[0: api_batch_params_slice_index_local]:
326
+ if not self.check_indices_numeric(api_in, [npu_max_index, bench_max_index, max_diff_index]):
327
+ continue
328
+ api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index)
329
+ self.apply_comparison_rules(api_info, color_columns)
330
+
331
+ red_lines_num_set = {x[0] for x in red_lines}
332
+ yellow_lines_num_set = {x[0] for x in yellow_lines}
333
+ highlight_dict.get('red_rows', set()).update(red_lines_num_set)
334
+ highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set)
335
+ highlight_dict.get('red_lines', []).extend(red_lines)
336
+ highlight_dict.get('yellow_lines', []).extend(yellow_lines)
337
+
338
+ def apply_comparison_rules(self, api_info, color_columns):
339
+ """output与input/params的比较"""
340
+ if self.mode_config.dump_mode == Const.SUMMARY:
341
+ for rule in HighlightRules.summary_compare_rules.values():
342
+ rule.apply(api_info, color_columns, self.mode_config.dump_mode)
343
+ else:
344
+ for rule in HighlightRules.compare_rules.values():
345
+ rule.apply(api_info, color_columns, self.mode_config.dump_mode)
361
346
 
362
- handle_multi_process_malicious_value_check(df_malicious_value_check, result_df)
347
+ def highlight_rows_xlsx(self, result_df, highlight_dict, file_path):
348
+ """Write and highlight results in Excel"""
363
349
 
364
- result_df_convert = result_df.applymap(compare_result_df_convert)
350
+ self.update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg
365
351
 
366
- for row in dataframe_to_rows(result_df_convert, index=False, header=True):
367
- ws.append(row)
352
+ wb = openpyxl.Workbook()
353
+ ws = wb.active
368
354
 
369
- # 对可疑数据标色
370
- logger.info('Coloring Excel in progress.')
371
- col_len = len(result_df.columns)
372
- red_fill = PatternFill(
373
- start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
374
- )
375
- yellow_fill = PatternFill(
376
- start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
377
- )
378
- for i in highlight_dict.get("red_rows", []):
379
- for j in range(1, col_len + 1):
380
- ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始
381
- for i in highlight_dict.get("yellow_rows", []):
382
- for j in range(1, col_len + 1):
383
- ws.cell(row=i + 2, column=j).fill = yellow_fill
355
+ # write header
356
+ logger.info('Initializing Excel file.')
384
357
 
385
- logger.info('Saving Excel file to disk: %s' % file_path)
386
- save_workbook(wb, file_path)
358
+ self.handle_multi_process_malicious_value_check(self.df_malicious_value_check, result_df)
387
359
 
360
+ result_df_convert = result_df.applymap(self.compare_result_df_convert)
388
361
 
389
- def update_highlight_err_msg(result_df, highlight_dict):
390
- if result_df.shape[1] <= 1:
391
- return
362
+ for row in dataframe_to_rows(result_df_convert, index=False, header=True):
363
+ ws.append(row)
392
364
 
393
- if CompareConst.NPU_MD5 in result_df.columns:
394
- return
365
+ # 对可疑数据标色
366
+ logger.info('Coloring Excel in progress.')
367
+ col_len = len(result_df.columns)
368
+ red_fill = PatternFill(
369
+ start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
370
+ )
371
+ yellow_fill = PatternFill(
372
+ start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
373
+ )
374
+ for i in highlight_dict.get("red_rows", []):
375
+ for j in range(1, col_len + 1):
376
+ ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始
377
+ for i in highlight_dict.get("yellow_rows", []):
378
+ for j in range(1, col_len + 1):
379
+ ws.cell(row=i + 2, column=j).fill = yellow_fill
395
380
 
396
- err_msg = result_df.get(CompareConst.ERROR_MESSAGE)
397
- red_lines_num_set = highlight_dict.get('red_rows')
381
+ logger.info('Saving Excel file to disk: %s' % file_path)
382
+ save_workbook(wb, file_path)
398
383
 
399
- for color in ['red', 'yellow']:
400
- line_key = f'{color}_lines'
401
- lines = highlight_dict.get(line_key, [])
402
- for line_index, messages in lines:
403
- if color == 'yellow' and line_index in red_lines_num_set:
404
- continue # 如果是 yellow 行,且已被 red 行覆盖,跳过
384
+ def handle_multi_process_malicious_value_check(self, func, result_df):
385
+ result_total_nums = len(result_df)
386
+ process_num = int((multiprocessing.cpu_count() + 1) / 2)
405
387
 
406
- for msg in messages:
407
- if err_msg[line_index] == '':
408
- err_msg[line_index] = msg
409
- else:
410
- err_msg[line_index] += '\n' + msg
388
+ if result_total_nums <= process_num:
389
+ process_num = 1
390
+ chunks = [result_df]
391
+ else:
392
+ chunk_size = result_total_nums // process_num
393
+ chunks = [result_df.iloc[i: i + chunk_size] for i in range(0, result_total_nums, chunk_size)]
411
394
 
412
- if color == 'red':
413
- red_lines_num_set.add(line_index)
395
+ pool = multiprocessing.Pool(process_num)
414
396
 
415
- result_df[CompareConst.ERROR_MESSAGE] = err_msg
397
+ def err_call(args):
398
+ logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args))
399
+ try:
400
+ pool.close()
401
+ except OSError:
402
+ logger.error("Pool terminate failed")
403
+
404
+ result_df_columns = result_df.columns.tolist()
405
+ for column in result_df_columns:
406
+ self.value_check(column)
407
+ for df_chunk in chunks:
408
+ pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
409
+
410
+ pool.close()
411
+ pool.join()
412
+
413
+ def df_malicious_value_check(self, df_chunk, result_df_columns):
414
+ for row in df_chunk.itertuples(index=False):
415
+ api_name = row[0]
416
+ for i, value in enumerate(row):
417
+ self.value_check(value, api_name, i, result_df_columns)
@@ -23,7 +23,7 @@ from msprobe.core.common.utils import (add_time_with_yaml,
23
23
  get_stack_construct_by_dump_json_path)
24
24
  from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items
25
25
  from msprobe.core.compare.utils import read_op, reorder_op_name_list
26
-
26
+ from msprobe.core.common.decorator import recursion_depth_decorator
27
27
 
28
28
 
29
29
  class LayerTrie:
@@ -71,6 +71,7 @@ class LayerTrie:
71
71
  file_path = os.path.join(os.path.realpath(output_path), file_name)
72
72
  save_yaml(file_path, result)
73
73
 
74
+ @recursion_depth_decorator("LayerMapping: LayerTrie.convert_to_dict", max_depth=100)
74
75
  def convert_to_dict(self, node):
75
76
  result = {}
76
77
  result["data_item"] = {st: [dt.data_name for dt in dts] for st, dts in node.data_items.items()}
@@ -163,6 +164,8 @@ def preprocess_layer_mapping(mapping):
163
164
  for key, value in name_map.items():
164
165
  key_list = key.split('.')
165
166
  prefix = key_list[0] # 取前缀
167
+ value_list = value.split('(')
168
+ value = value_list[0] # 取前缀
166
169
  key_len = len(key_list)
167
170
  if prefix not in final_mapping[type_name]:
168
171
  final_mapping[type_name][prefix] = []
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,7 +21,8 @@ from functools import partial
21
21
  import pandas as pd
22
22
  from tqdm import tqdm
23
23
 
24
- from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory
24
+ from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory, \
25
+ remove_path
25
26
  from msprobe.core.common.const import FileCheckConst, Const, CompareConst
26
27
  from msprobe.core.common.utils import CompareException, add_time_with_xlsx
27
28
  from msprobe.core.compare.utils import table_value_is_valid
@@ -32,8 +33,8 @@ def check_compare_result_name(file_name):
32
33
  """
33
34
  check whether the compare result name is as expected
34
35
  """
35
- single_rank_pattern = r"^compare_result_rank-rank_\d{14}.xlsx$"
36
- multi_ranks_pattern = r"^compare_result_rank(\d+)-rank\1_\d{14}.xlsx$"
36
+ single_rank_pattern = r"^compare_result_(rank|rank-rank)_\d{14}\.xlsx$"
37
+ multi_ranks_pattern = r"^compare_result_rank(\d+)(?:-rank\1)?_\d{14}\.xlsx$"
37
38
  if re.match(multi_ranks_pattern, file_name):
38
39
  return True
39
40
  if re.match(single_rank_pattern, file_name):
@@ -47,7 +48,7 @@ def reorder_path(compare_result_path_list):
47
48
  """
48
49
  reorder compare results by rank num
49
50
  """
50
- rank_pattern = r"compare_result_rank(\d+)-rank"
51
+ rank_pattern = r"compare_result_rank(\d+)"
51
52
  reorder_path_list = sorted(
52
53
  compare_result_path_list,
53
54
  key=lambda path: int(re.search(rank_pattern, os.path.basename(path)).group(1))
@@ -63,6 +64,7 @@ def get_result_path(input_dir):
63
64
  for f in os.listdir(input_dir) if f.endswith(FileCheckConst.XLSX_SUFFIX)]
64
65
  filt_compare_result_path_list = []
65
66
  for file_path in compare_result_path_list:
67
+ FileChecker(file_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
66
68
  file_name = os.path.basename(file_path)
67
69
  if check_compare_result_name(file_name):
68
70
  compare_result_path_checker = FileChecker(file_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE)
@@ -236,7 +238,7 @@ def handle_multi_process(func, func_args, lock):
236
238
  def err_call(args):
237
239
  logger.error('Multiprocess merge result failed! Reason: {}'.format(args))
238
240
  try:
239
- pool.terminate()
241
+ pool.close()
240
242
  except OSError:
241
243
  logger.error("Pool terminate failed")
242
244
 
@@ -329,6 +331,10 @@ def generate_merge_result(all_compare_index_dict_list, all_rank_num_list, all_co
329
331
  for i, df in enumerate(merge_df_list):
330
332
  # merge_df_list中df与compare_index_list中compare_index一一对应
331
333
  final_result_df_list.append((df, compare_index_list[i]))
334
+
335
+ if os.path.exists(output_path):
336
+ logger.warning(f"{output_path} will be deleted.")
337
+ remove_path(output_path)
332
338
  save_excel(output_path, final_result_df_list)
333
339
  logger.info(f"The compare results of the multi-ranks are merged and saved in: {output_path}.")
334
340