mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -22,10 +22,10 @@ import re
22
22
 
23
23
  import torch
24
24
 
25
- from msprobe.core.common.const import MonitorConst, Const
25
+ from msprobe.core.common.const import MonitorConst
26
26
  from msprobe.pytorch.common.log import logger
27
27
  from msprobe.core.common.utils import is_int
28
- from msprobe.core.common.file_utils import check_file_or_directory_path
28
+ from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod
29
29
 
30
30
 
31
31
  device = "cpu"
@@ -43,7 +43,6 @@ DIRECTORY_MAX_LENGTH = 4096
43
43
 
44
44
  beijing_tz = timezone(timedelta(hours=8))
45
45
  MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio"))
46
- MVGradResult = namedtuple('MVGradResult', ("exp_avg", "exp_avg_sq", "update", "ratio", "grad"))
47
46
 
48
47
 
49
48
  class MsgConst:
@@ -102,9 +101,23 @@ def validate_ops(ops):
102
101
  default_op = MonitorConst.OP_LIST[0]
103
102
  valid_ops.append(default_op)
104
103
  logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used")
104
+ # 增加默认shape和dtype参数
105
+ if "shape" not in valid_ops:
106
+ valid_ops.append("shape")
107
+ if "dtype" not in valid_ops:
108
+ valid_ops.append("dtype")
105
109
  return valid_ops
106
110
 
107
111
 
112
+ def validate_ndigits(ndigits):
113
+ if not ndigits:
114
+ return
115
+ if not is_int(ndigits) or ndigits <= 0:
116
+ raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.")
117
+ if ndigits > MonitorConst.MAX_NDIGITS:
118
+ raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.")
119
+
120
+
108
121
  def validate_ranks(ranks):
109
122
  if not isinstance(ranks, list):
110
123
  raise TypeError("module_ranks should be a list")
@@ -190,7 +203,7 @@ def validate_alert(alert):
190
203
  args = rule.get("args")
191
204
  if args and isinstance(args, dict):
192
205
  threshold = args.get("threshold")
193
- if not isinstance(threshold, float) or threshold < 0:
206
+ if not isinstance(threshold, (float, int)) or threshold < 0:
194
207
  raise TypeError('threshold must be float and not less than 0')
195
208
  dump = alert.get('dump')
196
209
  if dump and not isinstance(dump, bool):
@@ -206,9 +219,24 @@ def validate_step_count_per_record(step_count_per_record):
206
219
  raise ValueError("step_count_per_record must smaller than 1e6")
207
220
 
208
221
 
222
+ def validate_dynamic_on(dynamic_on):
223
+ if not isinstance(dynamic_on, bool):
224
+ raise TypeError('dynamic_on should be a bool')
225
+
226
+
227
+ def validate_monitor_mbs_grad(monitor_mbs_grad):
228
+ if not isinstance(monitor_mbs_grad, bool):
229
+ logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.')
230
+ return False
231
+ return monitor_mbs_grad
232
+
233
+
209
234
  def validate_config(config):
210
235
  config['ops'] = validate_ops(config.get('ops', []))
211
236
 
237
+ ndigits = config.get('ndigits')
238
+ validate_ndigits(ndigits)
239
+
212
240
  eps = config.get('eps', 1e-8)
213
241
  if not isinstance(eps, float):
214
242
  raise TypeError("eps should be a float")
@@ -246,9 +274,22 @@ def validate_config(config):
246
274
  step_count_per_record = config.get('step_count_per_record', 1)
247
275
  validate_step_count_per_record(step_count_per_record)
248
276
 
277
+ config["start_step"] = validate_int_arg(config.get("start_step"), "start_step",
278
+ MonitorConst.DEFAULT_START_STEP, MonitorConst.DEFAULT_START_STEP)
279
+ config["collect_times"] = validate_int_arg(config.get("collect_times"), "collect_times",
280
+ MonitorConst.DEFAULT_MIN_COLLECT_TIMES,
281
+ MonitorConst.DEFAULT_MAX_COLLECT_TIMES)
282
+ config["step_interval"] = validate_int_arg(config.get("step_interval"), "step_interval",
283
+ MonitorConst.DEFAULT_STEP_INTERVAL, MonitorConst.DEFAULT_STEP_INTERVAL)
284
+
249
285
  squash_name = config.get('squash_name', True)
250
286
  validate_squash_name(squash_name)
251
287
 
288
+ config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False))
289
+
290
+ dynamic_on = config.get('dynamic_on', False)
291
+ validate_dynamic_on(dynamic_on)
292
+
252
293
  if not targets:
253
294
  if xy_distribution:
254
295
  config["all_xy"] = True
@@ -257,6 +298,8 @@ def validate_config(config):
257
298
 
258
299
  def time_str2time_digit(time_str):
259
300
  time_format = '%b%d_%H-%M-%S'
301
+ if not isinstance(time_str, str):
302
+ raise TypeError(f"time_str:{time_str} should be a str")
260
303
  try:
261
304
  time_digit = datetime.strptime(time_str, time_format)
262
305
  except Exception as e:
@@ -284,3 +327,40 @@ def get_target_output_dir(monitor_path, time_start, time_end):
284
327
  if start_ok and end_ok:
285
328
  result[rank] = os.path.join(monitor_path, dirname)
286
329
  return result
330
+
331
+
332
+ def chmod_tensorboard_dir(path):
333
+ """
334
+ format配置为tensorboard时,需要补充文件权限设置
335
+ """
336
+ try:
337
+ recursive_chmod(path)
338
+ except Exception as e:
339
+ logger.warning(f"chmod tensorboard dir wrong because {e}, not updated, please check!!!")
340
+
341
+
342
+ def validate_set_monitor(grad_acc_steps, start_iteration):
343
+ """
344
+ validate parameters of set_monitor.
345
+ """
346
+ grad_acc_steps = validate_int_arg(grad_acc_steps, "grad_acc_steps",
347
+ MonitorConst.DEFAULT_GRAD_ACC_STEPS, MonitorConst.DEFAULT_GRAD_ACC_STEPS)
348
+
349
+ start_iteration = validate_int_arg(start_iteration, "start_iteration",
350
+ MonitorConst.DEFAULT_START_ITERATION, MonitorConst.DEFAULT_START_ITERATION)
351
+ return grad_acc_steps, start_iteration
352
+
353
+
354
+ def validate_int_arg(value, name, minimum, default_value):
355
+ """Validate int args, if any exception occurs, use the default value."""
356
+ if value is None:
357
+ return default_value
358
+ try:
359
+ if not is_int(value):
360
+ raise TypeError(f"{name} must be int")
361
+ if value < minimum:
362
+ raise ValueError(f"{name} must greater than {minimum}")
363
+ except Exception as e:
364
+ value = default_value
365
+ logger.warning(f"Validate {name} failed, {e}, replaced with default value {value}.")
366
+ return value
@@ -125,8 +125,6 @@ class Saver:
125
125
 
126
126
  def write_summary_csv(self, test_result):
127
127
  test_rows = []
128
- if self.stack_info:
129
- test_rows[0].append(self.COLUMN_STACK_INFO)
130
128
 
131
129
  check_op_str_pattern_valid(test_result.api_name)
132
130
  df_row = [test_result.api_name, test_result.is_fwd_success, test_result.is_bwd_success]
@@ -16,6 +16,7 @@
16
16
  import json
17
17
  import os
18
18
  import time
19
+ import multiprocessing
19
20
  from multiprocessing import Pool
20
21
 
21
22
  import torch
@@ -52,6 +53,7 @@ class PtdbgDispatch(TorchDispatchMode):
52
53
  return
53
54
  if dump_path is None:
54
55
  logger.error("Please set dump_path when dump_mode is config!")
56
+ raise DispatchException("Please set dump_path when dump_mode is config!")
55
57
  check_file_or_directory_path(dump_path, True)
56
58
 
57
59
  self.device_id = torch_npu._C._npu_getDevice()
@@ -85,6 +87,11 @@ class PtdbgDispatch(TorchDispatchMode):
85
87
  self.get_ops(yaml_path)
86
88
 
87
89
  self.lock = None
90
+ max_process_num = max(int((multiprocessing.cpu_count() + 1) // Const.CPU_QUARTER), 1)
91
+ if process_num > max_process_num:
92
+ logger.error(f"process_num should be less than or equal to {max_process_num}, but got {process_num}!")
93
+ raise DispatchException(f'process_num should be less than or equal to {max_process_num}, '
94
+ f'but got {process_num}!')
88
95
  if process_num > 0:
89
96
  self.pool = Pool(process_num)
90
97
  if debug:
@@ -115,6 +122,8 @@ class PtdbgDispatch(TorchDispatchMode):
115
122
  if len(json_line_data) == 0:
116
123
  break
117
124
  msg = json.loads(json_line_data)
125
+ if len(msg) < 2:
126
+ raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.")
118
127
  self.all_summary[msg[0]] = msg[1]
119
128
  fp_handle.close()
120
129
 
@@ -199,8 +208,10 @@ class PtdbgDispatch(TorchDispatchMode):
199
208
  dispatch_workflow(run_param, data_info)
200
209
  else:
201
210
  self.lock.acquire()
202
- self.all_summary.append([])
203
- self.lock.release()
211
+ try:
212
+ self.all_summary.append([])
213
+ finally:
214
+ self.lock.release()
204
215
  run_param.process_flag = True
205
216
  if self.check_fun(func, run_param):
206
217
  data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out,
@@ -19,6 +19,8 @@ import os
19
19
  from datetime import datetime, timezone
20
20
 
21
21
  import torch
22
+ from msprobe.core.common.const import Const
23
+ from msprobe.core.common.decorator import recursion_depth_decorator
22
24
  from msprobe.core.common.file_utils import FileOpen, save_npy, save_json
23
25
  from msprobe.pytorch.common.log import logger
24
26
 
@@ -91,6 +93,7 @@ def support_basic_type(data):
91
93
  return False
92
94
 
93
95
 
96
+ @recursion_depth_decorator("dump_data")
94
97
  def dump_data(data, prefix, dump_path):
95
98
  if isinstance(data, (tuple, list)) and data:
96
99
  for i, item in enumerate(data):
@@ -107,8 +110,11 @@ def dump_data(data, prefix, dump_path):
107
110
  def save_temp_summary(api_index, single_api_summary, path, lock):
108
111
  summary_path = os.path.join(path, f'summary.json')
109
112
  lock.acquire()
110
- data = [api_index, single_api_summary]
111
- save_json(summary_path, data, mode='a')
113
+ try:
114
+ data = [api_index, single_api_summary]
115
+ save_json(summary_path, data, mode='a')
116
+ finally:
117
+ lock.release()
112
118
 
113
119
 
114
120
  def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo):
@@ -27,8 +27,10 @@ else:
27
27
  pta_cpu_device = torch.device("cpu")
28
28
 
29
29
  from msprobe.core.common.const import CompareConst
30
+ from msprobe.core.common.decorator import recursion_depth_decorator
30
31
  from msprobe.pytorch.common.log import logger
31
32
 
33
+
32
34
  cpu_device = torch._C.device("cpu")
33
35
  COLOR_RED = '\033[31m'
34
36
  COLOR_GREEN = '\033[32m'
@@ -85,6 +87,7 @@ def get_callstack():
85
87
  return callstack
86
88
 
87
89
 
90
+ @recursion_depth_decorator("data_to_cpu")
88
91
  def data_to_cpu(data, deep, data_cpu):
89
92
  global cpu_device
90
93
  list_cpu = []
@@ -45,12 +45,7 @@ class InteractiveCli(cmd.Cmd):
45
45
 
46
46
  @catch_exception
47
47
  def default(self, line=""):
48
- self.util.execute_command(line)
49
- return False
50
-
51
- @catch_exception
52
- def do_run(self, line=""):
53
- self.util.execute_command(line)
48
+ self.stdout.write("Command invalid, Only support command start with cad/vc/dc/pk/cn/pt\n")
54
49
 
55
50
  @catch_exception
56
51
  def do_vc(self, line=""):
@@ -13,12 +13,12 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import hashlib
17
16
  import os
18
17
  import re
19
18
  import subprocess
20
19
  import sys
21
20
  import time
21
+ import zlib
22
22
  from collections import namedtuple
23
23
 
24
24
  import numpy as np
@@ -114,11 +114,12 @@ class Util:
114
114
  @staticmethod
115
115
  def get_md5_for_numpy(obj):
116
116
  np_bytes = obj.tobytes()
117
- md5_hash = hashlib.md5(np_bytes)
118
- return md5_hash.hexdigest()
117
+ md5_crc = zlib.crc32(np_bytes)
118
+ return f"{md5_crc:08x}"
119
119
 
120
120
  @staticmethod
121
121
  def deal_with_dir_or_file_inconsistency(output_path):
122
+ logger.warning(f"Trying to delete {output_path}")
122
123
  remove_path(output_path)
123
124
  raise ParseException("Inconsistent directory structure or file.")
124
125
 
@@ -264,7 +265,7 @@ class Util:
264
265
  match = re_pattern.match(name)
265
266
  if not match:
266
267
  continue
267
- if extern_pattern != '' and re_pattern.match(extern_pattern) and not re.match(extern_pattern, name):
268
+ if extern_pattern != '' and re_pattern.match(extern_pattern) and not name.startswith(extern_pattern):
268
269
  continue
269
270
  file_list[name] = gen_info_func(name, match, file["root"])
270
271
  return file_list
@@ -16,9 +16,9 @@
16
16
  import os
17
17
  import re
18
18
 
19
- from msprobe.core.common.const import Const
19
+ from msprobe.core.common.const import Const, FileCheckConst
20
20
  from msprobe.core.common.exceptions import MsprobeException
21
- from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, check_crt_valid
21
+ from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, FileChecker
22
22
  from msprobe.core.common.log import logger
23
23
  from msprobe.core.common.utils import is_int
24
24
  from msprobe.core.common_config import BaseConfig, CommonConfig
@@ -42,6 +42,7 @@ class TensorConfig(BaseConfig):
42
42
  self.tls_path = json_config.get("tls_path", "./")
43
43
  self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
44
44
  self.check_config()
45
+ self._check_summary_mode()
45
46
  self._check_file_format()
46
47
  if self.online_run_ut:
47
48
  self._check_online_run_ut()
@@ -65,7 +66,10 @@ class TensorConfig(BaseConfig):
65
66
  check_file_or_directory_path(self.tls_path, isdir=True)
66
67
  check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
67
68
  check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
68
- check_crt_valid(os.path.join(self.tls_path, "client.crt"))
69
+ check_file_or_directory_path(os.path.join(self.tls_path, "ca.crt"))
70
+ crl_path = os.path.join(self.tls_path, "crl.pem")
71
+ if os.path.exists(crl_path):
72
+ check_file_or_directory_path(crl_path)
69
73
 
70
74
  if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
71
75
  raise Exception(f"host: {self.host} is invalid.")
@@ -80,9 +84,8 @@ class StatisticsConfig(BaseConfig):
80
84
  self.check_config()
81
85
  self._check_summary_mode()
82
86
 
83
- def _check_summary_mode(self):
84
- if self.summary_mode and self.summary_mode not in ["statistics", "md5"]:
85
- raise Exception("summary_mode is invalid")
87
+ self.tensor_list = json_config.get("tensor_list", [])
88
+ self._check_str_list_config(self.tensor_list, "tensor_list")
86
89
 
87
90
 
88
91
  class OverflowCheckConfig(BaseConfig):
@@ -95,6 +98,8 @@ class OverflowCheckConfig(BaseConfig):
95
98
  def check_overflow_config(self):
96
99
  if self.overflow_nums is not None and not is_int(self.overflow_nums):
97
100
  raise Exception("overflow_num is invalid")
101
+ if self.overflow_nums is not None and self.overflow_nums != -1 and self.overflow_nums <= 0:
102
+ raise Exception("overflow_nums should be -1 or positive integer")
98
103
  if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]:
99
104
  raise Exception("check_mode is invalid")
100
105
 
@@ -148,7 +153,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
148
153
  self.pert_mode in PytorchFreeBenchmarkConst.CPU_MODE_LIST
149
154
  ):
150
155
  msg = (
151
- f"You neet to and can only set fuzz_device as {DeviceType.CPU} "
156
+ f"You need to and can only set fuzz_device as {DeviceType.CPU} "
152
157
  f"when pert_mode in {PytorchFreeBenchmarkConst.CPU_MODE_LIST}"
153
158
  )
154
159
  logger.error_log_with_exp(
@@ -271,13 +276,13 @@ class RunUTConfig(BaseConfig):
271
276
 
272
277
  @classmethod
273
278
  def check_nfs_path_config(cls, nfs_path):
274
- if nfs_path and not os.path.exists(nfs_path):
275
- raise Exception("nfs_path: %s does not exist" % nfs_path)
279
+ if nfs_path:
280
+ FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
276
281
 
277
282
  @classmethod
278
283
  def check_tls_path_config(cls, tls_path):
279
- if tls_path and not os.path.exists(tls_path):
280
- raise Exception("tls_path: %s does not exist" % tls_path)
284
+ if tls_path:
285
+ FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
281
286
 
282
287
  def check_run_ut_config(self):
283
288
  RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
@@ -0,0 +1,70 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.core.common.utils import Const
17
+ from msprobe.core.service import BaseService
18
+ from msprobe.pytorch.attl_manager import ATTLManager
19
+ from msprobe.pytorch.common.log import logger
20
+ from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_above_or_equal_2
21
+ from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
22
+ from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate
23
+ from msprobe.pytorch.hook_module.hook_module import HOOKModule
24
+ from msprobe.pytorch.hook_module.jit_script_wrapper import wrap_jit_script_func
25
+ from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager
26
+ from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
27
+
28
+ if torch_version_above_or_equal_2:
29
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
30
+
31
+
32
class PytorchService(BaseService):
    """PyTorch implementation of BaseService.

    Supplies the PyTorch-specific collaborators (API register, module
    processor, ATTL manager, hook manager) and the PyTorch-only hook
    registration steps used by the shared service logic.
    """

    @property
    def _get_framework_type(self):
        """Framework identifier for this service (``Const.PT_FRAMEWORK``)."""
        return Const.PT_FRAMEWORK

    @staticmethod
    def _get_current_rank():
        """Return the rank as reported by ``get_rank_if_initialized()``."""
        rank = get_rank_if_initialized()
        return rank

    def _init_specific_components(self):
        """Create the PyTorch-side helpers stored on the service instance."""
        self.logger = logger
        self.api_register = get_api_register()
        # The module processor tracks scopes owned by the data collector.
        self.module_processor = ModuleProcesser(self.data_collector.scope)
        attl_manager = ATTLManager(self.config)
        self.attl_manager = attl_manager
        # The hook manager needs the ATTL manager created just above.
        self.hook_manager = PytorchHookManager(self.data_collector, self.config, attl_manager)
        self.api_template = ApiTemplate

    def _register_hook(self):
        """Initialise ATTL; in mix level, additionally hook the optimizer."""
        self.attl_manager.attl_init()
        if not self._is_mix_level:
            return
        register_optimizer_hook(self.data_collector)

    def _register_api_hook(self):
        """Register the generic API hooks, then wrap ``torch.jit.script``-related functions."""
        super()._register_api_hook()
        wrap_jit_script_func()

    def _register_module_hook(self):
        """Enable module dumping and mount the module hooks on ``self.model``."""
        ModuleProcesser.enable_module_dump = True
        processor = self.module_processor
        processor.register_module_hook(self.model, self.build_hook)
        self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.")

    def _run_ut_dispatch(self, status):
        """Forward ``status`` to the online run_ut dispatcher.

        ``run_ut_dispatch`` is only imported when ``torch_version_above_or_equal_2``
        is true (presumably torch >= 2.0), so the call is skipped otherwise.
        """
        if not torch_version_above_or_equal_2:
            return
        run_ut_dispatch(self.attl_manager.attl, status, self.config.online_run_ut_recompute)

    def _reset_status(self):
        """Reset shared service state plus the per-class module statistics."""
        super()._reset_status()
        for stats_owner in (ModuleProcesser, HOOKModule):
            stats_owner.reset_module_stats()
@@ -14,21 +14,23 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import re
17
+ from dataclasses import dataclass
17
18
 
18
19
  from msprobe.core.common.const import Const
19
- from msprobe.core.common.file_utils import load_json
20
+ from msprobe.core.common.file_utils import load_json, save_json
21
+ from msprobe.core.common.utils import load_stack_json
20
22
  from msprobe.visualization.builder.msprobe_adapter import get_input_output
21
23
  from msprobe.visualization.builder.msprobe_adapter import op_patterns
22
24
  from msprobe.visualization.graph.graph import Graph
23
25
  from msprobe.visualization.graph.node_op import NodeOp
24
- from msprobe.visualization.utils import save_json_file, GraphConst
26
+ from msprobe.visualization.utils import GraphConst
25
27
 
26
28
 
27
29
  class GraphBuilder:
28
30
  backward_pattern = re.compile(r"(\.backward\.)(\d+)$")
29
31
  forward_pattern = re.compile(r"(\.forward\.)(\d+)$")
30
- # 匹配以大写字母开头,后接任意字母,并以Template(结尾
31
- template_pattern = re.compile(r'\b[A-Z][a-zA-Z]*Template\(')
32
+ # 匹配以大写字母开头,后接任意字母,并以Template(结尾,或包含api_template(的字符串
33
+ template_pattern = re.compile(r'\b([A-Z][a-zA-Z]*Template|api_template)\(')
32
34
 
33
35
  @staticmethod
34
36
  def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False):
@@ -44,13 +46,14 @@ class GraphBuilder:
44
46
  """
45
47
  construct_dict = load_json(construct_path)
46
48
  dump_dict = load_json(data_path)
47
- stack_dict = load_json(stack_path)
49
+ stack_dict = load_stack_json(stack_path)
48
50
  if not complete_stack:
49
51
  GraphBuilder._simplify_stack(stack_dict)
50
52
  data_dict = dump_dict.get(GraphConst.DATA_KEY, {})
51
53
  graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict)
52
54
  GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
53
55
  GraphBuilder._collect_apis_between_modules(graph)
56
+ GraphBuilder._add_parameters_grad(graph, data_dict)
54
57
  return graph
55
58
 
56
59
  @staticmethod
@@ -60,10 +63,10 @@ class GraphBuilder:
60
63
  """
61
64
  result = {}
62
65
  if config.graph_b:
63
- result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict()
64
- result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict()
66
+ result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict(config.compare_mode)
67
+ result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict(config.compare_mode)
65
68
  else:
66
- result = config.graph_n.to_dict()
69
+ result = config.graph_n.to_dict(config.compare_mode)
67
70
  if config.tool_tip:
68
71
  result[GraphConst.JSON_TIP_KEY] = config.tool_tip
69
72
  if config.node_colors:
@@ -73,7 +76,7 @@ class GraphBuilder:
73
76
  if config.task:
74
77
  result[GraphConst.JSON_TASK_KEY] = config.task
75
78
  result[GraphConst.OVERFLOW_CHECK] = config.overflow_check
76
- save_json_file(filename, result)
79
+ save_json(filename, result, indent=4)
77
80
 
78
81
  @staticmethod
79
82
  def _simplify_stack(stack_dict):
@@ -186,6 +189,8 @@ class GraphBuilder:
186
189
  # 数据格式:"output": [[{param1}, {param2}, ...]]
187
190
  if GraphBuilder._is_valid_batch_p2p_output(param_list):
188
191
  for param in param_list[0]:
192
+ if not isinstance(param, dict):
193
+ continue
189
194
  info = {GraphConst.OP: param.get(GraphConst.OP), GraphConst.PEER: param.get(GraphConst.PEER),
190
195
  GraphConst.GROUP_ID: param.get(GraphConst.GROUP_ID)}
191
196
  node.batch_p2p_info.append(info)
@@ -235,10 +240,46 @@ class GraphBuilder:
235
240
 
236
241
  graph.root.subnodes = output
237
242
 
243
@staticmethod
def _add_parameters_grad(graph, data_dict):
    """
    Attach ``parameters_grad`` entries from the dump data to the graph.

    Every ``<module>.parameters_grad`` key in ``data_dict`` that is not already a
    graph node is added as a child of that module's last backward node, i.e. the
    ``<module>.backward.<N>`` node with the largest counter N.

    For example, if the graph has nodes Module.a.backward.0, Module.a.backward.1
    and Module.a.backward.2, then Module.a.parameters_grad is added as a child of
    Module.a.backward.2.

    Args:
        graph: the Graph being built; ``graph.node_map`` is scanned and
            ``get_node``/``add_node`` are used to attach the new nodes.
        data_dict: the data section of the dump json, keyed by node name.
    """
    suffix = Const.SEP + Const.PARAMS_GRAD
    # Keep data_dict order so nodes are added deterministically. Slice off only
    # the trailing suffix: str.replace would also drop any earlier occurrence of
    # the same text inside the module name.
    prefixes = [
        node_id[:-len(suffix)]
        for node_id in data_dict
        if node_id.endswith(suffix) and node_id not in graph.node_map
    ]
    if not prefixes:
        return
    wanted = set(prefixes)

    # Largest backward counter seen per module prefix (absent => no backward node).
    max_info = {}
    for key in graph.node_map:
        parts = key.split(Const.SEP)
        if len(parts) <= 2 or parts[-2] != Const.BACKWARD:
            continue
        prefix = Const.SEP.join(parts[:-2])
        if prefix not in wanted:
            continue
        try:
            num = int(parts[-1])
        except ValueError:
            # "backward" used as a plain name segment, not followed by a counter.
            continue
        if num > max_info.get(prefix, -1):
            max_info[prefix] = num

    for prefix in prefixes:
        if prefix not in max_info:
            continue
        node_id = prefix + Const.SEP + Const.BACKWARD + Const.SEP + str(max_info[prefix])
        node = graph.get_node(node_id)
        if not node:
            continue
        parameters_grad_node_id = graph.add_node(NodeOp.module, prefix + suffix, up_node=node)
        # Attach the dumped input/output data for the new node.
        node_data = data_dict.get(parameters_grad_node_id, {})
        input_data, output_data = get_input_output(node_data, parameters_grad_node_id)
        graph.get_node(parameters_grad_node_id).set_input_output(input_data, output_data)
+ graph.get_node(parameters_grad_node_id).set_input_output(input_data, output_data)
278
+
238
279
 
239
280
  class GraphExportConfig:
240
281
  def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
241
- overflow_check=False):
282
+ overflow_check=False, compare_mode=None):
242
283
  self.graph_n = graph_n
243
284
  self.graph_b = graph_b
244
285
  self.tool_tip = tool_tip
@@ -246,3 +287,21 @@ class GraphExportConfig:
246
287
  self.micro_steps = micro_steps
247
288
  self.task = task
248
289
  self.overflow_check = overflow_check
290
+ self.compare_mode = compare_mode
291
+
292
+
293
@dataclass
class GraphInfo:
    """A built graph bundled with the dump files it was built from."""

    # The built Graph object.
    graph: Graph
    # Presumably the construct/data/stack json paths passed to GraphBuilder.build.
    construct_path: str
    data_path: str
    stack_path: str
299
+
300
+
301
@dataclass
class BuildGraphTaskInfo:
    """Inputs for one graph-building task pairing an NPU-side and a bench-side graph."""

    # NPU-side graph and its source files ("_n" mirrors graph_n / JSON_NPU_KEY).
    graph_info_n: GraphInfo
    # Bench-side graph and its source files ("_b" mirrors graph_b / JSON_BENCH_KEY).
    graph_info_b: GraphInfo
    npu_rank: str
    bench_rank: str
    # NOTE(review): presumably a timestamp string used when naming outputs - confirm.
    time_str: str