mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213) hide show
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -30,12 +30,7 @@ from msprobe.core.common.file_utils import save_workbook
30
30
  from msprobe.core.common.log import logger
31
31
  from msprobe.core.common.utils import get_header_index, safe_get_value
32
32
  from msprobe.core.compare.utils import table_value_is_valid, get_name_and_state, CompareException
33
-
34
-
35
- class HighlightCheck(abc.ABC):
36
- @abc.abstractmethod
37
- def apply(self, info, color_columns, dump_mode):
38
- raise NotImplementedError
33
+ from msprobe.core.compare.config import ModeConfig
39
34
 
40
35
 
41
36
  def add_highlight_row_info(color_list, num, highlight_err_msg):
@@ -46,6 +41,12 @@ def add_highlight_row_info(color_list, num, highlight_err_msg):
46
41
  color_list.append((num, [highlight_err_msg]))
47
42
 
48
43
 
44
+ class HighlightCheck(abc.ABC):
45
+ @abc.abstractmethod
46
+ def apply(self, info, color_columns, dump_mode):
47
+ raise NotImplementedError
48
+
49
+
49
50
  class CheckOrderMagnitude(HighlightCheck):
50
51
  """检查Max diff的数量级差异"""
51
52
 
@@ -75,12 +76,12 @@ class CheckOneThousandErrorRatio(HighlightCheck):
75
76
  if (api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and
76
77
  api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED):
77
78
  add_highlight_row_info(color_columns.red, num,
78
- "The input/parameters's one thousandth err ratio exceeds 0.9, "
79
+ "The input/parameter's one thousandth err ratio exceeds 0.9, "
79
80
  "while the output's is below 0.6")
80
81
  elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
81
82
  add_highlight_row_info(color_columns.yellow, num,
82
83
  "The output's one thousandth err ratio decreases by more than 0.1 "
83
- "compared to the input/parameters's")
84
+ "compared to the input/parameter's")
84
85
 
85
86
 
86
87
  class CheckCosineSimilarity(HighlightCheck):
@@ -94,7 +95,7 @@ class CheckCosineSimilarity(HighlightCheck):
94
95
  if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW:
95
96
  add_highlight_row_info(color_columns.yellow, num,
96
97
  "The output's cosine decreases by more than 0.1 "
97
- "compared to the input/parameters's")
98
+ "compared to the input/parameter's")
98
99
 
99
100
 
100
101
  class CheckMaxRelativeDiff(HighlightCheck):
@@ -117,7 +118,7 @@ class CheckMaxRelativeDiff(HighlightCheck):
117
118
  input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
118
119
  add_highlight_row_info(color_columns.yellow, num,
119
120
  "The output's maximum relative error exceeds 0.1, "
120
- "while the input/parameters's is below 0.01")
121
+ "while the input/parameter's is below 0.01")
121
122
 
122
123
 
123
124
  class CheckOverflow(HighlightCheck):
@@ -159,73 +160,6 @@ class HighlightRules:
159
160
  }
160
161
 
161
162
 
162
- def check_indices_numeric(api_items, indices: list):
163
- """检查指定索引处的值是否都为数字类型(int 或 float)"""
164
- return all(isinstance(api_items[i], (float, int)) for i in indices)
165
-
166
-
167
- def apply_comparison_rules(api_info, dump_mode, color_columns):
168
- """output与input/params的比较"""
169
- if dump_mode == Const.SUMMARY:
170
- for rule in HighlightRules.summary_compare_rules.values():
171
- rule.apply(api_info, color_columns, dump_mode)
172
- else:
173
- for rule in HighlightRules.compare_rules.values():
174
- rule.apply(api_info, color_columns, dump_mode)
175
-
176
-
177
- def find_error_rows(result, api_batch, highlight_dict, dump_mode):
178
- """找到单个API中需要高亮的行"""
179
- if dump_mode == Const.MD5:
180
- return
181
- npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
182
- bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
183
- max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
184
- else CompareConst.MAX_ABS_ERR, dump_mode)
185
-
186
- red_lines, yellow_lines = [], []
187
- LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
188
- ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer'])
189
- ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
190
- color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)
191
-
192
- api_batch_start = api_batch.start # result_df的input起始全局索引
193
- api_batch_params_end_index = api_batch.params_end_index # result_df的params结束全局索引 + 1
194
- api_batch_output_end_index = api_batch.output_end_index # result_df的output结束全局索引 + 1
195
- api_batch_params_slice_index_local = api_batch_params_end_index - api_batch_start # result的params结束局部切片索引
196
- api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start # result的output结束局部切片索引
197
-
198
- # 对单行API的输入或输出进行误差判断
199
- for i, line in enumerate(result):
200
- index = api_batch_start + i
201
- line_info = LineInfo(line_data=line, num_pointer=index)
202
- for rule in HighlightRules.basic_rules.values():
203
- rule.apply(line_info, color_columns, dump_mode)
204
-
205
- # 对API的输出与输入比较,进行误差判断
206
- for n, api_out in enumerate(result[api_batch_params_slice_index_local: api_batch_output_slice_index_local]):
207
- index = api_batch_start + api_batch_params_slice_index_local + n
208
- # 单行检查只有溢出检查(红色),如果已经溢出,不进一步检查
209
- if index in red_lines:
210
- continue
211
- if not check_indices_numeric(api_out, [npu_max_index, bench_max_index, max_diff_index]):
212
- continue
213
-
214
- # input/parameters的比较检查, 这里api_in包括input、parameters
215
- for _, api_in in enumerate(result[0: api_batch_params_slice_index_local]):
216
- if not check_indices_numeric(api_in, [npu_max_index, bench_max_index, max_diff_index]):
217
- continue
218
- api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index)
219
- apply_comparison_rules(api_info, dump_mode, color_columns)
220
-
221
- red_lines_num_set = {x[0] for x in red_lines}
222
- yellow_lines_num_set = {x[0] for x in yellow_lines}
223
- highlight_dict.get('red_rows', set()).update(red_lines_num_set)
224
- highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set)
225
- highlight_dict.get('red_lines', []).extend(red_lines)
226
- highlight_dict.get('yellow_lines', []).extend(yellow_lines)
227
-
228
-
229
163
  class ApiBatch:
230
164
  def __init__(self, api_name: str, start: int):
231
165
  self.api_name = api_name
@@ -259,159 +193,225 @@ class ApiBatch:
259
193
  self.params_grad_end_index += 1
260
194
 
261
195
 
262
- def api_batches_update(api_batches, api_name, state, index):
263
- """
264
- 当一个api的所有item更新完后,input, output的索引范围:
265
- input: [start: start+input_len]
266
- output: [start+input_len: output_end_index]
267
- params: [output_end_index: params_end_index]
268
- """
269
- if not api_batches:
270
- api_batches.append(ApiBatch(api_name, index))
271
- else:
272
- api_batch = api_batches[-1]
273
- if api_batch.api_name == api_name or (
274
- not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name):
275
- try:
276
- api_batch.increment(state)
277
- except ValueError as e:
278
- logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}")
279
- raise CompareException(CompareException.INVALID_STATE_ERROR) from e
280
- else:
281
- api_batches.append(ApiBatch(api_name, index))
196
+ class HighLight:
197
+ def __init__(self, mode_config: ModeConfig):
198
+ self.mode_config = mode_config
282
199
 
283
-
284
- def find_compare_result_error_rows(result_df, highlight_dict, dump_mode):
285
- """将dataframe根据API分组,并找到有误差的算子用于高亮"""
286
- result = result_df.values
287
- api_batches = []
288
- for i, res_i in enumerate(result):
289
- api_full_name = safe_get_value(res_i, 0, "res_i")
290
- api_name, state = get_name_and_state(api_full_name)
291
- api_batches_update(api_batches, api_name, state, i)
292
- with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar:
293
- for api_batch in api_batches:
294
- find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch, highlight_dict,
295
- dump_mode)
296
- progress_bar.update(1)
297
-
298
-
299
- def value_check(value, api_name=None, i=None, result_df_columns=None):
300
- if not table_value_is_valid(value):
301
- if result_df_columns:
302
- logger.error(f"Malicious value [{value}] at api_name [{api_name}], column [{result_df_columns[i]}], "
303
- f"is not allowed to be written into the compare result xlsx.")
200
+ @staticmethod
201
+ def api_batches_update(api_batches, api_name, state, index):
202
+ """
203
+ 当一个api的所有item更新完后,input, output的索引范围:
204
+ input: [start: start+input_len]
205
+ output: [start+input_len: output_end_index]
206
+ params: [output_end_index: params_end_index]
207
+ """
208
+ if not api_batches:
209
+ api_batches.append(ApiBatch(api_name, index))
304
210
  else:
305
- logger.error(f"Malicious value [{value}] is not allowed to be written into the compare result xlsx.")
306
-
307
-
308
- def df_malicious_value_check(df_chunk, result_df_columns):
309
- for row in df_chunk.itertuples(index=False):
310
- api_name = row[0]
311
- for i, value in enumerate(row):
312
- value_check(value, api_name, i, result_df_columns)
313
-
314
-
315
- def handle_multi_process_malicious_value_check(func, result_df):
316
- result_total_nums = len(result_df)
317
- process_num = int((multiprocessing.cpu_count() + 1) / 2)
318
-
319
- if result_total_nums <= process_num:
320
- process_num = 1
321
- chunks = [result_df]
322
- else:
323
- chunk_size = result_total_nums // process_num
324
- chunks = [result_df.iloc[i: i + chunk_size] for i in range(0, result_total_nums, chunk_size)]
325
-
326
- pool = multiprocessing.Pool(process_num)
327
-
328
- def err_call(args):
329
- logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args))
330
- try:
331
- pool.terminate()
332
- except OSError:
333
- logger.error("Pool terminate failed")
334
-
335
- result_df_columns = result_df.columns.tolist()
336
- for column in result_df_columns:
337
- value_check(column)
338
- for df_chunk in chunks:
339
- pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
340
-
341
- pool.close()
342
- pool.join()
343
-
344
-
345
- def compare_result_df_convert(value):
346
- if not isinstance(value, (float, int)) or isinstance(value, bool): # bool类型或者非数字类型转str
347
- value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else str(value)
348
- if isinstance(value, float):
349
- value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else value
350
- return value
351
-
352
-
353
- def highlight_rows_xlsx(result_df, highlight_dict, file_path):
354
- """Write and highlight results in Excel"""
211
+ api_batch = api_batches[-1]
212
+ if api_batch.api_name == api_name or (
213
+ not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name):
214
+ try:
215
+ api_batch.increment(state)
216
+ except ValueError as e:
217
+ logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}")
218
+ raise CompareException(CompareException.INVALID_STATE_ERROR) from e
219
+ else:
220
+ api_batches.append(ApiBatch(api_name, index))
221
+
222
+ @staticmethod
223
+ def check_indices_numeric(api_items, indices: list):
224
+ """检查指定索引处的值是否都为数字类型(int 或 float)"""
225
+ return all(isinstance(api_items[i], (float, int)) for i in indices)
226
+
227
+ @staticmethod
228
+ def update_highlight_err_msg(result_df, highlight_dict):
229
+ if result_df.shape[1] <= 1:
230
+ return
355
231
 
356
- update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg
232
+ if CompareConst.NPU_MD5 in result_df.columns:
233
+ return
357
234
 
358
- wb = openpyxl.Workbook()
359
- ws = wb.active
235
+ err_msg = result_df.get(CompareConst.ERROR_MESSAGE)
236
+ red_lines_num_set = highlight_dict.get('red_rows')
237
+
238
+ for color in ['red', 'yellow']:
239
+ line_key = f'{color}_lines'
240
+ lines = highlight_dict.get(line_key, [])
241
+ for line_index, messages in lines:
242
+ if color == 'yellow' and line_index in red_lines_num_set:
243
+ continue # 如果是 yellow 行,且已被 red 行覆盖,跳过
244
+
245
+ for msg in messages:
246
+ if err_msg[line_index] == '':
247
+ err_msg[line_index] = msg
248
+ else:
249
+ err_msg[line_index] += '\n' + msg
250
+
251
+ if color == 'red':
252
+ red_lines_num_set.add(line_index)
253
+
254
+ result_df[CompareConst.ERROR_MESSAGE] = err_msg
255
+
256
+ @staticmethod
257
+ def compare_result_df_convert(value):
258
+ if not isinstance(value, (float, int)) or isinstance(value, bool): # bool类型或者非数字类型转str
259
+ value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else str(value)
260
+ if isinstance(value, float):
261
+ value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else value
262
+ return value
263
+
264
+ @staticmethod
265
+ def value_check(value, api_name=None, i=None, result_df_columns=None):
266
+ if not table_value_is_valid(value):
267
+ if result_df_columns:
268
+ logger.error(f"Malicious value [{value}] at api_name [{api_name}], column [{result_df_columns[i]}], "
269
+ f"is not allowed to be written into the compare result xlsx.")
270
+ else:
271
+ logger.error(f"Malicious value [{value}] is not allowed to be written into the compare result xlsx.")
272
+
273
+ def find_compare_result_error_rows(self, result_df, highlight_dict):
274
+ """将dataframe根据API分组,并找到有误差的算子用于高亮"""
275
+ result = result_df.values
276
+ api_batches = []
277
+ for i, res_i in enumerate(result):
278
+ api_full_name = safe_get_value(res_i, 0, "res_i")
279
+ api_name, state = get_name_and_state(api_full_name)
280
+ self.api_batches_update(api_batches, api_name, state, i)
281
+ with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar:
282
+ for api_batch in api_batches:
283
+ self.find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch,
284
+ highlight_dict)
285
+ progress_bar.update(1)
286
+
287
+ def find_error_rows(self, result, api_batch, highlight_dict):
288
+ """找到单个API中需要高亮的行"""
289
+ if self.mode_config.dump_mode == Const.MD5:
290
+ return
291
+ npu_max_index = get_header_index(CompareConst.NPU_MAX, self.mode_config.dump_mode)
292
+ bench_max_index = get_header_index(CompareConst.BENCH_MAX, self.mode_config.dump_mode)
293
+ max_diff_index = get_header_index(CompareConst.MAX_DIFF if self.mode_config.dump_mode == Const.SUMMARY
294
+ else CompareConst.MAX_ABS_ERR, self.mode_config.dump_mode)
295
+
296
+ red_lines, yellow_lines = [], []
297
+ LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
298
+ ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer'])
299
+ ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
300
+ color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)
301
+
302
+ api_batch_start = api_batch.start # result_df的input起始全局索引
303
+ api_batch_params_end_index = api_batch.params_end_index # result_df的params结束全局索引 + 1
304
+ api_batch_output_end_index = api_batch.output_end_index # result_df的output结束全局索引 + 1
305
+ api_batch_params_slice_index_local = api_batch_params_end_index - api_batch_start # result的params结束局部切片索引
306
+ api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start # result的output结束局部切片索引
307
+
308
+ # 对单行API的输入或输出进行误差判断
309
+ for i, line in enumerate(result):
310
+ index = api_batch_start + i
311
+ line_info = LineInfo(line_data=line, num_pointer=index)
312
+ for rule in HighlightRules.basic_rules.values():
313
+ rule.apply(line_info, color_columns, self.mode_config.dump_mode)
314
+
315
+ # 对API的输出与输入比较,进行误差判断
316
+ for n, api_out in enumerate(result[api_batch_params_slice_index_local: api_batch_output_slice_index_local]):
317
+ index = api_batch_start + api_batch_params_slice_index_local + n
318
+ # 单行检查只有溢出检查(红色),如果已经溢出,不进一步检查
319
+ if index in red_lines:
320
+ continue
321
+ if not self.check_indices_numeric(api_out, [npu_max_index, bench_max_index, max_diff_index]):
322
+ continue
360
323
 
361
- # write header
362
- logger.info('Initializing Excel file.')
324
+ # input/parameters的比较检查, 这里api_in包括input、parameters
325
+ for api_in in result[0: api_batch_params_slice_index_local]:
326
+ if not self.check_indices_numeric(api_in, [npu_max_index, bench_max_index, max_diff_index]):
327
+ continue
328
+ api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index)
329
+ self.apply_comparison_rules(api_info, color_columns)
330
+
331
+ red_lines_num_set = {x[0] for x in red_lines}
332
+ yellow_lines_num_set = {x[0] for x in yellow_lines}
333
+ highlight_dict.get('red_rows', set()).update(red_lines_num_set)
334
+ highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set)
335
+ highlight_dict.get('red_lines', []).extend(red_lines)
336
+ highlight_dict.get('yellow_lines', []).extend(yellow_lines)
337
+
338
+ def apply_comparison_rules(self, api_info, color_columns):
339
+ """output与input/params的比较"""
340
+ if self.mode_config.dump_mode == Const.SUMMARY:
341
+ for rule in HighlightRules.summary_compare_rules.values():
342
+ rule.apply(api_info, color_columns, self.mode_config.dump_mode)
343
+ else:
344
+ for rule in HighlightRules.compare_rules.values():
345
+ rule.apply(api_info, color_columns, self.mode_config.dump_mode)
363
346
 
364
- handle_multi_process_malicious_value_check(df_malicious_value_check, result_df)
347
+ def highlight_rows_xlsx(self, result_df, highlight_dict, file_path):
348
+ """Write and highlight results in Excel"""
365
349
 
366
- result_df_convert = result_df.applymap(compare_result_df_convert)
350
+ self.update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg
367
351
 
368
- for row in dataframe_to_rows(result_df_convert, index=False, header=True):
369
- ws.append(row)
352
+ wb = openpyxl.Workbook()
353
+ ws = wb.active
370
354
 
371
- # 对可疑数据标色
372
- logger.info('Coloring Excel in progress.')
373
- col_len = len(result_df.columns)
374
- red_fill = PatternFill(
375
- start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
376
- )
377
- yellow_fill = PatternFill(
378
- start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
379
- )
380
- for i in highlight_dict.get("red_rows", []):
381
- for j in range(1, col_len + 1):
382
- ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始
383
- for i in highlight_dict.get("yellow_rows", []):
384
- for j in range(1, col_len + 1):
385
- ws.cell(row=i + 2, column=j).fill = yellow_fill
355
+ # write header
356
+ logger.info('Initializing Excel file.')
386
357
 
387
- logger.info('Saving Excel file to disk: %s' % file_path)
388
- save_workbook(wb, file_path)
358
+ self.handle_multi_process_malicious_value_check(self.df_malicious_value_check, result_df)
389
359
 
360
+ result_df_convert = result_df.applymap(self.compare_result_df_convert)
390
361
 
391
- def update_highlight_err_msg(result_df, highlight_dict):
392
- if result_df.shape[1] <= 1:
393
- return
362
+ for row in dataframe_to_rows(result_df_convert, index=False, header=True):
363
+ ws.append(row)
394
364
 
395
- if CompareConst.NPU_MD5 in result_df.columns:
396
- return
365
+ # 对可疑数据标色
366
+ logger.info('Coloring Excel in progress.')
367
+ col_len = len(result_df.columns)
368
+ red_fill = PatternFill(
369
+ start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
370
+ )
371
+ yellow_fill = PatternFill(
372
+ start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
373
+ )
374
+ for i in highlight_dict.get("red_rows", []):
375
+ for j in range(1, col_len + 1):
376
+ ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始
377
+ for i in highlight_dict.get("yellow_rows", []):
378
+ for j in range(1, col_len + 1):
379
+ ws.cell(row=i + 2, column=j).fill = yellow_fill
397
380
 
398
- err_msg = result_df.get(CompareConst.ERROR_MESSAGE)
399
- red_lines_num_set = highlight_dict.get('red_rows')
381
+ logger.info('Saving Excel file to disk: %s' % file_path)
382
+ save_workbook(wb, file_path)
400
383
 
401
- for color in ['red', 'yellow']:
402
- line_key = f'{color}_lines'
403
- lines = highlight_dict.get(line_key, [])
404
- for line_index, messages in lines:
405
- if color == 'yellow' and line_index in red_lines_num_set:
406
- continue # 如果是 yellow 行,且已被 red 行覆盖,跳过
384
+ def handle_multi_process_malicious_value_check(self, func, result_df):
385
+ result_total_nums = len(result_df)
386
+ process_num = int((multiprocessing.cpu_count() + 1) / 2)
407
387
 
408
- for msg in messages:
409
- if err_msg[line_index] == '':
410
- err_msg[line_index] = msg
411
- else:
412
- err_msg[line_index] += '\n' + msg
388
+ if result_total_nums <= process_num:
389
+ process_num = 1
390
+ chunks = [result_df]
391
+ else:
392
+ chunk_size = result_total_nums // process_num
393
+ chunks = [result_df.iloc[i: i + chunk_size] for i in range(0, result_total_nums, chunk_size)]
413
394
 
414
- if color == 'red':
415
- red_lines_num_set.add(line_index)
395
+ pool = multiprocessing.Pool(process_num)
416
396
 
417
- result_df[CompareConst.ERROR_MESSAGE] = err_msg
397
+ def err_call(args):
398
+ logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args))
399
+ try:
400
+ pool.close()
401
+ except OSError:
402
+ logger.error("Pool terminate failed")
403
+
404
+ result_df_columns = result_df.columns.tolist()
405
+ for column in result_df_columns:
406
+ self.value_check(column)
407
+ for df_chunk in chunks:
408
+ pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
409
+
410
+ pool.close()
411
+ pool.join()
412
+
413
+ def df_malicious_value_check(self, df_chunk, result_df_columns):
414
+ for row in df_chunk.itertuples(index=False):
415
+ api_name = row[0]
416
+ for i, value in enumerate(row):
417
+ self.value_check(value, api_name, i, result_df_columns)
@@ -164,6 +164,8 @@ def preprocess_layer_mapping(mapping):
164
164
  for key, value in name_map.items():
165
165
  key_list = key.split('.')
166
166
  prefix = key_list[0] # 取前缀
167
+ value_list = value.split('(')
168
+ value = value_list[0] # 取前缀
167
169
  key_len = len(key_list)
168
170
  if prefix not in final_mapping[type_name]:
169
171
  final_mapping[type_name][prefix] = []
@@ -33,8 +33,8 @@ def check_compare_result_name(file_name):
33
33
  """
34
34
  check whether the compare result name is as expected
35
35
  """
36
- single_rank_pattern = r"^compare_result_rank-rank_\d{14}.xlsx$"
37
- multi_ranks_pattern = r"^compare_result_rank(\d+)-rank\1_\d{14}.xlsx$"
36
+ single_rank_pattern = r"^compare_result_(rank|rank-rank)_\d{14}\.xlsx$"
37
+ multi_ranks_pattern = r"^compare_result_rank(\d+)(?:-rank\1)?_\d{14}\.xlsx$"
38
38
  if re.match(multi_ranks_pattern, file_name):
39
39
  return True
40
40
  if re.match(single_rank_pattern, file_name):
@@ -48,7 +48,7 @@ def reorder_path(compare_result_path_list):
48
48
  """
49
49
  reorder compare results by rank num
50
50
  """
51
- rank_pattern = r"compare_result_rank(\d+)-rank"
51
+ rank_pattern = r"compare_result_rank(\d+)"
52
52
  reorder_path_list = sorted(
53
53
  compare_result_path_list,
54
54
  key=lambda path: int(re.search(rank_pattern, os.path.basename(path)).group(1))
@@ -238,7 +238,7 @@ def handle_multi_process(func, func_args, lock):
238
238
  def err_call(args):
239
239
  logger.error('Multiprocess merge result failed! Reason: {}'.format(args))
240
240
  try:
241
- pool.terminate()
241
+ pool.close()
242
242
  except OSError:
243
243
  logger.error("Pool terminate failed")
244
244