mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
@@ -51,8 +51,6 @@ from msprobe.pytorch.pt_config import parse_json_config
51
51
  from msprobe.core.common.const import Const, FileCheckConst, CompareConst
52
52
  from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
53
53
  from msprobe.pytorch.common.utils import seed_all
54
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
55
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
56
54
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params, \
57
55
  ExecParams
58
56
 
@@ -90,27 +88,22 @@ seed_all()
90
88
 
91
89
  def run_ut(config):
92
90
  logger.info("start UT test")
93
- if config.online_config.is_online:
94
- logger.info(f"UT task result will be saved in {config.result_csv_path}".replace(".csv", "_rank*.csv"))
95
- logger.info(f"UT task details will be saved in {config.details_csv_path}".replace(".csv", "_rank*.csv"))
96
- else:
97
- logger.info(f"UT task result will be saved in {config.result_csv_path}")
98
- logger.info(f"UT task details will be saved in {config.details_csv_path}")
91
+
92
+ logger.info(f"UT task result will be saved in {config.result_csv_path}")
93
+ logger.info(f"UT task details will be saved in {config.details_csv_path}")
99
94
 
100
95
  if config.save_error_data:
101
96
  logger.info(f"UT task error_data will be saved in {config.error_data_path}")
102
97
  compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
103
98
 
104
- if config.online_config.is_online:
105
- run_api_online(config, compare)
106
- else:
107
- csv_df = read_csv(config.result_csv_path)
108
- try:
109
- api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
110
- except IndexError:
111
- logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
112
- api_name_set = set()
113
- run_api_offline(config, compare, api_name_set)
99
+
100
+ csv_df = read_csv(config.result_csv_path)
101
+ try:
102
+ api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
103
+ except IndexError:
104
+ logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
105
+ api_name_set = set()
106
+ run_api_offline(config, compare, api_name_set)
114
107
  for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
115
108
  change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
116
109
  change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -164,60 +157,6 @@ def run_api_offline(config, compare, api_name_set):
164
157
  gc.collect()
165
158
 
166
159
 
167
- def run_api_online(config, compare):
168
- attl = init_attl(config.online_config)
169
- dispatcher = ConsumerDispatcher(compare=compare)
170
- dispatcher.start(handle_func=run_torch_api_online, config=config)
171
-
172
- def tcp_communication_flow():
173
- while True:
174
- api_data = attl.recv()
175
- if api_data == 'STOP_':
176
- continue
177
- if api_data == 'KILL_':
178
- time.sleep(1)
179
- logger.info("==========接收到STOP信号==========")
180
- dispatcher.stop()
181
- attl.stop_serve()
182
- time.sleep(1)
183
- break
184
- if not isinstance(api_data, ApiData):
185
- continue
186
- api_full_name = api_data.name
187
- _, api_name = extract_basic_api_segments(api_full_name)
188
- if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
189
- continue
190
- if api_data.rank in config.online_config.rank_list:
191
- dispatcher.update_consume_queue(api_data)
192
-
193
- def shared_storage_communication_flow():
194
- flag_num = -1
195
- while True:
196
- api_data = attl.download()
197
- if api_data == "start":
198
- if flag_num == -1:
199
- flag_num += 1
200
- flag_num += 1
201
- if api_data == "end":
202
- flag_num -= 1
203
- if flag_num == 0:
204
- dispatcher.stop()
205
- break
206
- if not isinstance(api_data, ApiData):
207
- continue
208
- api_full_name = api_data.name
209
- _, api_name = extract_basic_api_segments(api_full_name)
210
- if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
211
- continue
212
- if api_data.rank in config.online_config.rank_list:
213
- dispatcher.update_consume_queue(api_data)
214
-
215
- if config.online_config.nfs_path:
216
- shared_storage_communication_flow()
217
- else:
218
- tcp_communication_flow()
219
-
220
-
221
160
  def blacklist_and_whitelist_filter(api_name, black_list, white_list):
222
161
  """
223
162
  run api(api_name) if api_name not in black_list and in white_list.
@@ -315,21 +254,6 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
315
254
  return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
316
255
 
317
256
 
318
- def run_torch_api_online(api_full_name, api_data, backward_content):
319
- in_fwd_data_list = []
320
- api_type, api_name = extract_basic_api_segments(api_full_name)
321
- args, kwargs, out = api_data.args, api_data.kwargs, api_data.result
322
- in_fwd_data_list.append(args)
323
- in_fwd_data_list.append(kwargs)
324
- if kwargs.get("device"):
325
- del kwargs["device"]
326
-
327
- device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None)
328
- device_out = exec_api(device_exec_params)
329
- device_out = move2device_exec(device_out, "cpu")
330
- return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
331
-
332
-
333
257
  def check_need_grad(api_info_dict):
334
258
  need_grad = True
335
259
  if api_info_dict.get(Const.INPUT_KWARGS) and "out" in api_info_dict.get(Const.INPUT_KWARGS):
@@ -389,16 +313,6 @@ def initialize_save_error_data(error_data_path):
389
313
  return error_data_path
390
314
 
391
315
 
392
- def init_attl(config):
393
- """config: OnlineConfig"""
394
- attl = ATTL('gpu', ATTLConfig(is_benchmark_device=True,
395
- connect_ip=config.host,
396
- connect_port=config.port,
397
- nfs_path=config.nfs_path,
398
- tls_path=config.tls_path))
399
- return attl
400
-
401
-
402
316
  def _run_ut_parser(parser):
403
317
  parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
404
318
  help="<Optional> The api param tool result file: generate from api param tool, "
@@ -481,38 +395,6 @@ def _run_ut(parser=None):
481
395
  _run_ut_parser(parser)
482
396
  args = parser.parse_args(sys.argv[1:])
483
397
  run_ut_command(args)
484
-
485
-
486
- def checked_online_config(online_config):
487
- if not online_config.is_online:
488
- return
489
- if not isinstance(online_config.is_online, bool):
490
- raise ValueError("is_online must be bool type")
491
- # rank_list
492
- if not isinstance(online_config.rank_list, list):
493
- raise ValueError("rank_list must be a list")
494
- if online_config.rank_list and not all(isinstance(rank, int) for rank in online_config.rank_list):
495
- raise ValueError("All elements in rank_list must be integers")
496
-
497
- # nfs_path
498
- if online_config.nfs_path:
499
- check_file_or_directory_path(online_config.nfs_path, isdir=True)
500
- return
501
- # tls_path
502
- if online_config.tls_path:
503
- check_file_or_directory_path(online_config.tls_path, isdir=True)
504
- check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
505
- check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
506
- check_file_or_directory_path(os.path.join(online_config.tls_path, "ca.crt"))
507
- crl_path = os.path.join(online_config.tls_path, "crl.pem")
508
- if os.path.exists(crl_path):
509
- check_file_or_directory_path(crl_path)
510
-
511
- # host and port
512
- if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
513
- raise Exception(f"host: {online_config.host} is invalid.")
514
- if not isinstance(online_config.port, int) or not (0 < online_config.port <= 65535):
515
- raise Exception(f"port: {online_config.port} is invalid, port range 0-65535.")
516
398
 
517
399
 
518
400
  def run_ut_command(args):
@@ -525,7 +407,7 @@ def run_ut_command(args):
525
407
  else:
526
408
  checker_config = CheckerConfig()
527
409
 
528
- if not checker_config.is_online and not args.api_info_file:
410
+ if not args.api_info_file:
529
411
  logger.error("Please provide api_info_file for offline run ut.")
530
412
  raise Exception("Please provide api_info_file for offline run ut.")
531
413
 
@@ -588,8 +470,6 @@ def run_ut_command(args):
588
470
  global UT_ERROR_DATA_DIR
589
471
  UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
590
472
  error_data_path = initialize_save_error_data(error_data_path)
591
- online_config = checker_config.get_online_config()
592
- checked_online_config(online_config)
593
473
  config_params = {
594
474
  'forward_content': forward_content,
595
475
  'backward_content': backward_content,
@@ -150,7 +150,7 @@ def remove_dropout():
150
150
  F.dropout3d = function_dropout3d
151
151
 
152
152
 
153
- def seed_all(seed=1234, mode=False, rm_dropout=True):
153
+ def seed_all(seed=1234, mode=False, rm_dropout=False):
154
154
  check_seed_all(seed, mode, rm_dropout)
155
155
  try:
156
156
  random.seed(seed)
@@ -388,26 +388,6 @@ def load_pkl(pt_path):
388
388
  return pt
389
389
 
390
390
 
391
- def save_api_data(api_data):
392
- """Save data to io stream"""
393
- try:
394
- io_buff = io.BytesIO()
395
- torch.save(api_data, io_buff)
396
- except Exception as e:
397
- raise RuntimeError("save api_data to io_buff failed") from e
398
- return io_buff
399
-
400
-
401
- def load_api_data(api_data_bytes):
402
- """Load data from bytes stream"""
403
- try:
404
- buffer = io.BytesIO(api_data_bytes)
405
- buffer = torch.load(buffer, map_location="cpu", weights_only=False)
406
- except Exception as e:
407
- raise RuntimeError("load api_data from bytes failed") from e
408
- return buffer
409
-
410
-
411
391
  def is_recomputation():
412
392
  """Check if the current operation is in the re-computation phase.
413
393
 
@@ -31,8 +31,16 @@ def compare(input_param, output_path, **kwargs):
31
31
  raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
32
32
  config = setup_comparison(input_param, output_path, **kwargs)
33
33
 
34
- mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match,
35
- config.dump_mode, config.compared_file_type)
34
+ config_dict = {
35
+ 'stack_mode': config.stack_mode,
36
+ 'auto_analyze': config.auto_analyze,
37
+ 'fuzzy_match': config.fuzzy_match,
38
+ 'highlight': config.highlight,
39
+ 'dump_mode': config.dump_mode,
40
+ 'first_diff_analyze': config.first_diff_analyze,
41
+ 'compared_file_type': config.compared_file_type
42
+ }
43
+ mode_config = ModeConfig(**config_dict)
36
44
  mapping_config = MappingConfig(data_mapping=config.data_mapping)
37
45
  pt_comparator = Comparator(read_real_data, mode_config, mapping_config)
38
46
  pt_comparator.compare_core(input_param, output_path, suffix=config.suffix)
@@ -13,21 +13,9 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import torch
17
16
 
18
- from msprobe.pytorch.hook_module.api_register import get_api_register
17
+ from msprobe.pytorch.compare.distributed_compare import compare_distributed
19
18
 
20
19
 
21
- def wrap_jit_script_func():
22
- def patched_script(*args, **kwargs):
23
- all_api_registered = api_register.all_api_registered
24
- if all_api_registered:
25
- api_register.restore_all_api()
26
- result = original_script(*args, **kwargs)
27
- if all_api_registered:
28
- api_register.register_all_api()
29
- return result
30
-
31
- original_script = torch.jit.script
32
- api_register = get_api_register()
33
- torch.jit.script = patched_script
20
+ def pt_diff_analyze(npu_dump_dir, bench_dump_dir, output_path, first_diff_analyze):
21
+ compare_distributed(npu_dump_dir, bench_dump_dir, output_path, first_diff_analyze=first_diff_analyze)
@@ -35,7 +35,8 @@ def read_pt_data(dir_path, file_name):
35
35
  data_value = load_pt(data_path, to_cpu=True).detach()
36
36
  except RuntimeError as e:
37
37
  # 这里捕获 load_pt 中抛出的异常
38
- logger.error(f"Failed to load the .pt file at {data_path}.")
38
+ data_path_file_name = os.path.basename(data_path)
39
+ logger.error(f"Failed to load the .pt file at {data_path_file_name}.")
39
40
  raise CompareException(CompareException.INVALID_FILE_ERROR) from e
40
41
  except AttributeError as e:
41
42
  # 这里捕获 detach 方法抛出的异常
@@ -34,6 +34,7 @@ class DebuggerConfig:
34
34
  self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
35
35
  self.framework = Const.PT_FRAMEWORK
36
36
  self.async_dump = common_config.async_dump if common_config.async_dump else False
37
+ self.precision = common_config.precision if common_config.precision else Const.DUMP_PRECISION_LOW
37
38
 
38
39
  if self.task == Const.FREE_BENCHMARK:
39
40
  self.fuzz_device = task_config.fuzz_device
@@ -47,16 +48,6 @@ class DebuggerConfig:
47
48
  "max_sample": task_config.max_sample
48
49
  }
49
50
 
50
- self.online_run_ut = False
51
- if self.task == Const.TENSOR:
52
- # dump api tensor and collaborate with online run_ut
53
- self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
54
- self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
55
- self.tls_path = task_config.tls_path if task_config.tls_path else ""
56
- self.host = task_config.host if task_config.host else ""
57
- self.port = task_config.port if task_config.port else -1
58
- self.online_run_ut_recompute = task_config.online_run_ut_recompute \
59
- if isinstance(task_config.online_run_ut_recompute, bool) else False
60
51
 
61
52
  self.check()
62
53
  self._check_statistics_config(task_config)
@@ -65,7 +56,7 @@ class DebuggerConfig:
65
56
  self.is_backward_kernel_dump = False
66
57
  self._check_and_adjust_config_with_l2()
67
58
 
68
- def check_kwargs(self):
59
+ def check(self):
69
60
  if self.task and self.task not in Const.TASK_LIST:
70
61
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
71
62
  f"The task <{self.task}> is not in the {Const.TASK_LIST}.")
@@ -78,22 +69,26 @@ class DebuggerConfig:
78
69
  if not isinstance(self.async_dump, bool):
79
70
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
80
71
  f"The parameters async_dump should be bool.")
81
- if self.async_dump and self.task == Const.TENSOR:
82
- if self.level == Const.LEVEL_DEBUG:
83
- self.list = [] # async_dump + debug level case ignore list
84
- if not self.list and self.level != Const.LEVEL_DEBUG:
85
- raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
86
- f"The parameters async_dump is true in tensor task, the parameters list cannot be "
87
- f"empty.")
88
72
  if self.task == Const.STRUCTURE and self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
89
73
  logger.warning_on_rank_0(
90
74
  f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
91
75
  f"If not, the default level is {Const.LEVEL_MIX}."
92
76
  )
93
77
  self.level = Const.LEVEL_MIX
94
-
95
- def check(self):
96
- self.check_kwargs()
78
+ if self.async_dump:
79
+ if self.task == Const.TENSOR:
80
+ if self.level == Const.LEVEL_DEBUG:
81
+ self.list = [] # async_dump + debug level case ignore list
82
+ if not self.list and self.level != Const.LEVEL_DEBUG:
83
+ raise MsprobeException(
84
+ MsprobeException.INVALID_PARAM_ERROR,
85
+ f"The parameters async_dump is true in tensor task, the parameters list cannot be empty."
86
+ )
87
+ if self.summary_mode == Const.MD5:
88
+ raise MsprobeException(
89
+ MsprobeException.INVALID_PARAM_ERROR,
90
+ f"The parameters async_dump is true, the parameters summary_mode cannot be md5."
91
+ )
97
92
  return True
98
93
 
99
94
  def check_model(self, instance, start_model, token_range=None):
@@ -102,7 +97,7 @@ class DebuggerConfig:
102
97
  if token_range and not instance.model:
103
98
  error_info = "The 'model' parameter must be provided when token_range is not None"
104
99
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, error_info)
105
-
100
+
106
101
  if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX] and token_range is None:
107
102
  return
108
103
 
@@ -123,7 +118,7 @@ class DebuggerConfig:
123
118
  break
124
119
  if error_model is not None:
125
120
  error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
126
- f"type, currently there is an unsupported {type(error_model)} type.")
121
+ f"type, currently there is an unsupported {type(error_model)} type.")
127
122
  raise MsprobeException(
128
123
  MsprobeException.INVALID_PARAM_ERROR, error_info)
129
124
  else:
@@ -24,8 +24,11 @@ from msprobe.pytorch.common.log import logger
24
24
 
25
25
 
26
26
  def wrap_setup_backward_hook(func):
27
- def requires_clone(tensor):
28
- return isinstance(tensor, torch.Tensor) and tensor.requires_grad and torch.is_grad_enabled()
27
+ def requires_clone(tensor, need_check_leaf=False):
28
+ need_clone = isinstance(tensor, torch.Tensor) and tensor.requires_grad and torch.is_grad_enabled()
29
+ if need_check_leaf:
30
+ need_clone &= tensor.grad_fn is not None
31
+ return need_clone
29
32
 
30
33
  @recursion_depth_decorator("Dump: wrap_setup_backward_hook.parse_tensor", max_depth=Const.DUMP_MAX_DEPTH)
31
34
  def parse_tensor(item, tensor_list):
@@ -39,20 +42,20 @@ def wrap_setup_backward_hook(func):
39
42
  parse_tensor(value, tensor_list)
40
43
 
41
44
  @recursion_depth_decorator("Dump: wrap_setup_backward_hook.rebuild_args", max_depth=Const.DUMP_MAX_DEPTH)
42
- def rebuild_args(item, tensor_iter):
43
- if requires_clone(item):
45
+ def rebuild_args(item, tensor_iter, need_check_leaf=False):
46
+ if requires_clone(item, need_check_leaf):
44
47
  result = next(tensor_iter)
45
48
  if hasattr(result, "_base") and result._base is not None:
46
49
  if torch._C._autograd._get_creation_meta(result) != torch._C._autograd.CreationMeta(0):
47
50
  torch._C._autograd._set_creation_meta(result, torch._C._autograd.CreationMeta(0))
48
- return result
51
+ return result
49
52
  if isinstance(item, list):
50
53
  for index, value in enumerate(item):
51
- item[index] = rebuild_args(value, tensor_iter)
54
+ item[index] = rebuild_args(value, tensor_iter, need_check_leaf=True)
52
55
  return item
53
56
  if isinstance(item, dict):
54
57
  for key, value in item.items():
55
- item[key] = rebuild_args(value, tensor_iter)
58
+ item[key] = rebuild_args(value, tensor_iter, need_check_leaf=True)
56
59
  return item
57
60
  if isinstance(item, tuple):
58
61
  if hasattr(item, '_fields'):
@@ -21,25 +21,18 @@ import torch
21
21
  from torch.utils.hooks import BackwardHook, RemovableHandle
22
22
 
23
23
  from msprobe.core.common.const import Const
24
+ from msprobe.core.common.runtime import Runtime
24
25
  from msprobe.core.common.utils import ModuleQueue, ThreadSafe
26
+ from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
25
27
  from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
26
28
  from msprobe.pytorch.common.log import logger
27
29
  from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
28
30
  from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_output_hook
29
31
 
30
32
  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
31
- if torch_version_above_or_equal_2:
32
- from torch.utils.checkpoint import checkpoint as origin_checkpoint, set_checkpoint_early_stop
33
-
34
-
35
- def checkpoint_without_early_stop(*args, **kwargs):
36
- with set_checkpoint_early_stop(False):
37
- return origin_checkpoint(*args, **kwargs)
38
-
39
-
40
- def replace_checkpoint():
41
- if torch_version_above_or_equal_2:
42
- torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop
33
+ torch_version_above_or_equal_21 = torch.__version__.split('+')[0] >= '2.1'
34
+ if torch_version_above_or_equal_21:
35
+ from torch.utils.checkpoint import _StopRecomputationError
43
36
 
44
37
 
45
38
  def wrap_megatron_deallocate(func):
@@ -53,6 +46,27 @@ def wrap_megatron_deallocate(func):
53
46
  return wrapper_func
54
47
 
55
48
 
49
def wrap_forward_with_hook_safety(module):
    """
    Wrap ``module.forward`` so a forward hook still runs when forward is aborted.

    If the original forward raises ``_StopRecomputationError``, the first
    registered forward hook (if any) is called with ``None`` as the output,
    and the same exception is then re-raised to the caller.

    The wrapper is installed only when ``torch_version_above_or_equal_21`` is
    true; otherwise the module is left unchanged.

    Args:
        module: object whose ``forward`` attribute is replaced in place.
    """
    original_forward = module.forward

    if not torch_version_above_or_equal_21:
        return

    def wrapped_forward(*args, **kwargs):
        try:
            return original_forward(*args, **kwargs)
        except _StopRecomputationError as e:
            registered_hooks = list(module._forward_hooks.values())
            if registered_hooks:
                # msprobe's forward_hook appears first in the hook table,
                # so only that first hook is executed here.
                msprobe_hook = registered_hooks[0]
                msprobe_hook(module, args, kwargs, None)
            raise e

    module.forward = wrapped_forward
68
+
69
+
56
70
  class ModuleProcesser:
57
71
  module_queue = ModuleQueue()
58
72
  module_count = {}
@@ -66,11 +80,12 @@ class ModuleProcesser:
66
80
  def __init__(self, scope):
67
81
  self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
68
82
  wrap_setup_input_output_hook()
69
- replace_checkpoint()
70
83
  try:
71
84
  from megatron.core.pipeline_parallel import schedules
72
85
  origin_func_id = id(schedules.deallocate_output_tensor)
73
86
  schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor)
87
+ schedules.forward_step = wrap_megatron_step(schedules.forward_step)
88
+ schedules.backward_step = wrap_megatron_step(schedules.backward_step, is_forward=False)
74
89
  for module in list(sys.modules.values()):
75
90
  if module.__name__ == 'schedules':
76
91
  continue
@@ -155,6 +170,7 @@ class ModuleProcesser:
155
170
  f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
156
171
  )
157
172
  ModuleProcesser.module_with_backward_hook[prefix_name] = True
173
+ wrap_forward_with_hook_safety(module)
158
174
  register_forward_pre_hook(module, forward_pre_hook)
159
175
 
160
176
  def build_module_hook(self, module_name, build_data_hook):
@@ -163,6 +179,9 @@ class ModuleProcesser:
163
179
  if kwargs is None:
164
180
  kwargs = {}
165
181
 
182
+ if not Runtime.is_running:
183
+ return (args, kwargs) if torch_version_above_or_equal_2 else args
184
+
166
185
  if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump:
167
186
  return (args, kwargs) if torch_version_above_or_equal_2 else args
168
187
 
@@ -243,14 +262,16 @@ class ModuleProcesser:
243
262
  ModuleProcesser.module_stack[tid] = []
244
263
 
245
264
  if self.module_stack[tid]:
246
- ModuleProcesser.module_node[full_name] = self.module_stack[tid][-1]
265
+ ModuleProcesser.module_node[full_name] = self.module_stack[tid][-1] if not is_megatron() \
266
+ else [self.module_stack[tid][-1], get_micro_step()]
247
267
  else:
248
268
  parent_name = ModuleProcesser.module_queue.find_last(full_name)
249
- ModuleProcesser.module_node[full_name] = parent_name
269
+ ModuleProcesser.module_node[full_name] = parent_name if not is_megatron() \
270
+ else [parent_name, get_micro_step()]
250
271
 
251
272
  ModuleProcesser.module_queue.add_name(full_name)
252
273
  ModuleProcesser.module_stack[tid].append(full_name)
253
- ModuleProcesser.api_parent_node[tid] = full_name
274
+ ModuleProcesser.api_parent_node[tid] = full_name if not is_megatron() else [full_name, get_micro_step()]
254
275
  if self.scope:
255
276
  self.scope.begin_module(full_name)
256
277
 
@@ -258,14 +279,15 @@ class ModuleProcesser:
258
279
  tid = threading.get_ident()
259
280
  if torch_version_above_or_equal_2 or is_forward:
260
281
  ModuleProcesser.module_queue.remove_name(full_name)
261
- ModuleProcesser.api_parent_node[tid] = None
282
+ ModuleProcesser.api_parent_node[tid] = None if not is_megatron() else [None, get_micro_step()]
262
283
  if self.module_stack.get(tid):
263
284
  ModuleProcesser.module_stack[tid].pop()
264
285
  if self.module_stack.get(tid):
265
- ModuleProcesser.api_parent_node[tid] = ModuleProcesser.module_stack[tid][-1]
286
+ ModuleProcesser.api_parent_node[tid] = ModuleProcesser.module_stack[tid][-1] if not is_megatron() \
287
+ else [ModuleProcesser.module_stack[tid][-1], get_micro_step()]
266
288
  if self.scope:
267
289
  self.scope.end_module(full_name)
268
290
  else:
269
291
  if self.scope:
270
292
  self.scope.begin_module(full_name)
271
- ModuleProcesser.api_parent_node[tid] = full_name
293
+ ModuleProcesser.api_parent_node[tid] = full_name if not is_megatron() else [full_name, get_micro_step()]
@@ -17,8 +17,8 @@ from abc import ABC
17
17
 
18
18
  import torch
19
19
  from msprobe.core.common.const import Const
20
+ from msprobe.core.common.utils import replace_last_occurrence
20
21
  from msprobe.pytorch.free_benchmark import logger
21
- from msprobe.pytorch.free_benchmark.common.constant import CommonField
22
22
  from msprobe.pytorch.free_benchmark.common.enums import (
23
23
  DeviceType,
24
24
  FuzzLevel,
@@ -37,6 +37,7 @@ from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
37
37
 
38
38
 
39
39
  class FreeBenchmarkCheck(ABC):
40
+ grad_saver_dict = {}
40
41
 
41
42
  def __init__(self, config) -> None:
42
43
  super().__init__()
@@ -68,7 +69,9 @@ class FreeBenchmarkCheck(ABC):
68
69
  grad_saver.kwargs = kwargs
69
70
  grad_saver.register_compare_func_for_inputs(args, data_processor)
70
71
  grad_saver.cache_backward_input(args)
71
- setattr(module, CommonField.GRADSAVER, grad_saver)
72
+
73
+ backward_name = replace_last_occurrence(name, Const.FORWARD, Const.BACKWARD)
74
+ FreeBenchmarkCheck.grad_saver_dict[backward_name] = grad_saver
72
75
 
73
76
  def forward(self, name, module, args, kwargs, output):
74
77
  if not self.config.fuzz_stage == Const.FORWARD:
@@ -92,16 +95,16 @@ class FreeBenchmarkCheck(ABC):
92
95
  return perturbed_output, handler.get_unequal_rows()
93
96
 
94
97
  def backward(self, name, module, grad_output):
95
-
96
98
  if not self.config.fuzz_stage == Const.BACKWARD:
97
99
  return
98
100
  try:
99
- grad_saver = getattr(module, CommonField.GRADSAVER)
101
+ grad_saver = FreeBenchmarkCheck.grad_saver_dict[name]
100
102
  except AttributeError:
101
103
  logger.warning_on_rank_0(
102
104
  f"[msprobe] Free benchmark: get grad saver failed. api_name:{name}"
103
105
  )
104
106
  return
107
+ del FreeBenchmarkCheck.grad_saver_dict[name]
105
108
 
106
109
  _new_grad_output = grad_output
107
110
  try: