mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/pytorch/dump/module_dump/module_dump.py

```diff
@@ -13,75 +13,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
-from msprobe.core.common.const import Const
-from msprobe.core.data_dump.scope import BaseScope
 from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
 from msprobe.pytorch.hook_module.api_register import get_api_register
 
-torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
-
 
 class ModuleDumper:
     def __init__(self, service):
         self.service = service
-        self.hook_handle_list = []
         self.api_register = get_api_register()
 
     def start_module_dump(self, module, dump_name):
+        if hasattr(module, 'msprobe_hook') and not hasattr(module, 'msprobe_module_dump'):
+            logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.")
+            return
+
+        ModuleProcesser.enable_module_dump = True
         self.api_register.restore_all_api()
-        self.register_hook(module, dump_name)
+        if not hasattr(module, 'msprobe_module_dump'):
+            self.service.module_processor.register_module_hook(module, self.service.build_hook,
+                                                               recursive=False, module_names=[dump_name])
+            setattr(module, 'msprobe_module_dump', True)
 
     def stop_module_dump(self):
+        ModuleProcesser.enable_module_dump = False
         self.api_register.register_all_api()
-        for hook_handle in self.hook_handle_list:
-            if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
-                hook_handle.remove()
-        self.hook_handle_list.clear()
-
-    def register_hook(self, module, dump_name):
-        prefix_name = (
-            BaseScope.Module_Type_Module + Const.SEP +
-            dump_name + Const.SEP +
-            module.__class__.__name__ + Const.SEP
-        )
-        module_processor = self.service.module_processor
-        _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.service.build_hook(
-            BaseScope.Module_Type_Module,
-            prefix_name
-        )
-
-        if module_processor.has_register_backward_hook(module):
-            logger.warning(
-                f"The {dump_name} module has registered deprecated register_backward_hook,"
-                f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
-            )
-        if torch_version_above_or_equal_2:
-            forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
-        else:
-            if not module_processor.has_register_backward_hook(module):
-                backward_hook_handle = module.register_full_backward_hook(
-                    module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
-                )
-                self.hook_handle_list.append(backward_hook_handle)
-            forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
-        self.hook_handle_list.append(forward_hook_handle)
-        if not module_processor.has_register_backward_hook(module):
-            backward_hook_handle = module.register_full_backward_hook(backward_hook)
-            self.hook_handle_list.append(backward_hook_handle)
-
-        forward_pre_hook_handle = module.register_forward_pre_hook(
-            module_processor.node_hook(prefix_name + Const.FORWARD, Const.START)
-        )
-        forward_hook_handle = module.register_forward_hook(
-            module_processor.node_hook(prefix_name + Const.FORWARD, Const.STOP)
-        )
-        self.hook_handle_list.extend([forward_pre_hook_handle, forward_hook_handle])
-        if torch_version_above_or_equal_2 and not module_processor.has_register_backward_hook(module):
-            backward_pre_hook_handle = module.register_full_backward_pre_hook(
-                module_processor.node_hook(prefix_name + Const.BACKWARD, Const.START)
-            )
-            backward_hook_handle = module.register_full_backward_hook(
-                module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
-            )
-            self.hook_handle_list.extend([backward_pre_hook_handle, backward_hook_handle])
```
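The net effect: instead of registering hooks on every start_module_dump and removing them through stored RemovableHandles on stop, hooks are now registered once and switched on and off through the class-level ModuleProcesser.enable_module_dump flag. A minimal standalone sketch of that flag-gating pattern (toy names, not msprobe's actual classes):

```python
import torch

class DumpController:
    enabled = False  # plays the role of ModuleProcesser.enable_module_dump

def make_forward_hook(name):
    def forward_hook(module, args, output):
        if not DumpController.enabled:
            return  # hook stays registered but does nothing
        print(f"[dump] {name}: output shape {tuple(output.shape)}")
    return forward_hook

model = torch.nn.Linear(4, 2)
model.register_forward_hook(make_forward_hook("fc"))  # registered once

model(torch.randn(1, 4))        # flag off: silent
DumpController.enabled = True   # analogue of start_module_dump
model(torch.randn(1, 4))        # flag on: prints the dump line
DumpController.enabled = False  # analogue of stop_module_dump
```

Gating on a flag avoids the bookkeeping and ordering hazards of repeatedly adding and removing handles on live modules.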
msprobe/pytorch/dump/module_dump/module_processer.py

```diff
@@ -13,16 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from functools import wraps
+from collections import OrderedDict
 
 import torch
-from torch.utils.hooks import BackwardHook
+from torch.utils.hooks import BackwardHook, RemovableHandle
 
 from msprobe.core.common.const import Const
-from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import replace_last_occurrence, is_float8_tensor
+from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
+from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_output_hook
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if torch_version_above_or_equal_2:
@@ -39,43 +39,40 @@ def replace_checkpoint():
     torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop
 
 
+def wrap_megatron_deallocate(func):
+    def wrapper_func(out, deallocate_pipeline_outputs=False):
+        if deallocate_pipeline_outputs and isinstance(out, torch.Tensor) and getattr(out, "_base") is not None:
+            out_clone = out.clone()
+            out.data = torch.empty((1,), device=out.device, dtype=out.dtype, )
+            return func(out_clone, deallocate_pipeline_outputs)
+        return func(out, deallocate_pipeline_outputs)
+    return wrapper_func
+
+
 class ModuleProcesser:
     module_count = {}
     module_stack = []
     api_parent_node = ""
     module_node = {}
+    module_bw_hook_kernels = {}
+    module_with_backward_hook = {}
+    enable_module_dump = False
 
     def __init__(self, scope):
        self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
-        BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
-        BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
+        wrap_setup_input_output_hook()
        replace_checkpoint()
+        try:
+            from megatron.core.pipeline_parallel import schedules
+            schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor)
+            logger.info_on_rank_0("Patch megatron method success.")
+        except ImportError:
+            logger.info_on_rank_0("No megatron find.")
+        except Exception as e:
+            logger.info_on_rank_0(f"Patch megatron method failed, detail:{str(e)}")
 
     @staticmethod
-    def clone_return_value(func):
-        @wraps(func)
-        def clone_return_value_func(*args, **kwargs):
-            result = func(*args, **kwargs)
-            return ModuleProcesser.clone_if_tensor(result)
-
-        return clone_return_value_func
-
-    @staticmethod
-    @recursion_depth_decorator("ModuleDump: ModuleProcesser.clone_if_tensor", max_depth=Const.DUMP_MAX_DEPTH)
-    def clone_if_tensor(result):
-        if isinstance(result, torch.Tensor) and not is_float8_tensor(result):
-            return result.clone()
-        elif type(result) is tuple:
-            return tuple(ModuleProcesser.clone_if_tensor(x) for x in result)
-        elif type(result) is list:
-            return list(ModuleProcesser.clone_if_tensor(x) for x in result)
-        elif type(result) is dict:
-            return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()}
-        else:
-            return result
-
-    @staticmethod
-    def module_count_func(module_name):
+    def set_and_get_calls_number(module_name):
        if module_name not in ModuleProcesser.module_count:
            ModuleProcesser.module_count[module_name] = 0
        else:
```
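The wrap_megatron_deallocate patch above exists because Megatron's deallocate_output_tensor shrinks a pipeline output's `.data` to a one-element placeholder to free memory; any dump hook still holding that tensor would then read garbage. Cloning first, as wrapper_func does, preserves a valid copy. A toy demonstration of the hazard (illustrative only, no Megatron required):

```python
import torch

base = torch.arange(6.0)
view = base.view(2, 3)        # a view, so view._base is not None,
assert view._base is not None  # just like Megatron pipeline outputs

keep = view.clone()            # what wrapper_func forwards for a safe dump
view.data = torch.empty((1,), device=view.device, dtype=view.dtype)

print(keep.shape)              # torch.Size([2, 3]) -- the copy is intact
print(view.shape)              # torch.Size([1])    -- the original is now a stub
```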
msprobe/pytorch/dump/module_dump/module_processer.py (continued)

```diff
@@ -89,13 +86,19 @@ class ModuleProcesser:
             module._is_full_backward_hook is False
 
     @staticmethod
-    def get_modules_and_names(models):
+    def get_modules_and_names(models, recursive, module_names):
         modules_and_names_with_index = {}
         if isinstance(models, (list, tuple)):
+            if not recursive and len(module_names) != len(models):
+                return modules_and_names_with_index
             for index, model in enumerate(models):
-                modules_and_names_with_index[str(index)] = model.named_modules()
+                modules_and_names_with_index[str(index)] = model.named_modules() if recursive else \
+                    [(module_names[index], model)]
         else:
-            modules_and_names_with_index["-1"] = models.named_modules()
+            if not recursive and len(module_names) != 1:
+                return modules_and_names_with_index
+            modules_and_names_with_index["-1"] = models.named_modules() if recursive else \
+                [(module_names[0], models)]
         return modules_and_names_with_index
 
     @classmethod
@@ -104,107 +107,134 @@ class ModuleProcesser:
         cls.module_stack = []
         cls.api_parent_node = ""
         cls.module_node = {}
+        cls.module_bw_hook_kernels = {}
+        cls.enable_module_dump = False
 
-    def register_module_hook(self, models, build_hook):
-        logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.")
-        modules_and_names_with_index = self.get_modules_and_names(models)
+    def register_module_hook(self, models, build_hook, recursive=True, module_names=None):
+        if module_names is None:
+            module_names = []
+
+        modules_and_names_with_index = self.get_modules_and_names(models, recursive, module_names)
         for index, modules_and_names in modules_and_names_with_index.items():
             model = models if index == "-1" else models[int(index)]
             for name, module in modules_and_names:
-                if module == model:
+                if recursive and module == model:
+                    continue
+                if not is_torch_nn_module(module):
+                    logger.warning(
+                        f"The module dump does not support {type(module)} type. "
+                        f"The data dump for this module will be skipped."
+                    )
                     continue
                 if module.__class__.__name__ == "FullyShardedDataParallel":
                     continue
+                setattr(module, 'msprobe_hook', True)
                 module_index = (index + Const.SEP) if index != "-1" else ""
-                prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
-                               name + Const.SEP + module.__class__.__name__ + Const.SEP)
-                pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook(
-                    BaseScope.Module_Type_Module,
-                    prefix_name
-                )
+                prefix_name = f'{BaseScope.Module_Type_Module}{Const.SEP}{module_index}{name}{Const.SEP}' + \
+                              f'{module.__class__.__name__}{Const.SEP}'
+
+                forward_pre_hook = self.build_module_hook(prefix_name, build_hook)
 
                 if self.has_register_backward_hook(module):
                     logger.warning(
                         f"The {prefix_name[:-1]} has registered deprecated register_backward_hook,"
                         f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
                     )
-                if torch_version_above_or_equal_2:
-                    module.register_forward_hook(forward_hook, with_kwargs=True)
-                else:
-                    if not self.has_register_backward_hook(module):
-                        module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
-                    module.register_forward_hook(forward_hook_torch_version_below_2)
-                if not self.has_register_backward_hook(module):
-                    module.register_full_backward_hook(backward_hook)
-
-                module.register_forward_pre_hook(self.node_hook(prefix_name + Const.FORWARD, Const.START))
-                module.register_forward_hook(self.node_hook(prefix_name + Const.FORWARD, Const.STOP))
-                if torch_version_above_or_equal_2 and not self.has_register_backward_hook(module):
-                    module.register_full_backward_pre_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.START))
-                    module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
-
-    def node_hook(self, name_prefix, start_or_stop, **kwargs):
-
-        def pre_hook(module, input, output=None):
-            try:
-                index = ModuleProcesser.module_count_func(name_prefix)
-            except IndexError as e:
-                index = None
-                pass
-            full_name = name_prefix + Const.SEP + str(index)
-            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
-                module.mindstudio_reserved_name = []
-            module.mindstudio_reserved_name.append(full_name)
-            if self.module_stack:
-                ModuleProcesser.module_node[full_name] = self.module_stack[-1]
-            else:
-                ModuleProcesser.module_node[full_name] = None
-
-            ModuleProcesser.module_stack.append(full_name)
-            if self.module_stack:
-                ModuleProcesser.api_parent_node = self.module_stack[-1]
-            if self.scope:
-                self.scope.begin_module(full_name)
-
-        def end_hook(module, input, output=None):
-            if self.module_stack:
-                ModuleProcesser.module_stack.pop()
-            if self.module_stack:
-                ModuleProcesser.api_parent_node = self.module_stack[-1]
-            else:
-                ModuleProcesser.api_parent_node = None
-            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
-                raise RuntimeError(f"module reserve name is None when pop")
-            current_name = module.mindstudio_reserved_name.pop()
-            if self.scope:
-                self.scope.end_module(current_name)
-
-        def backward_hook(module, input, output=None):
-            try:
-                index = ModuleProcesser.module_count_func(name_prefix)
-            except IndexError as e:
-                index = None
-                pass
-            full_name = name_prefix + Const.SEP + str(index)
-            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
-                module.mindstudio_reserved_name = []
-            module.mindstudio_reserved_name.append(full_name)
-            forward_full_name = replace_last_occurrence(full_name, Const.BACKWARD, Const.FORWARD)
-            ModuleProcesser.module_node[full_name] = replace_last_occurrence(
-                ModuleProcesser.module_node.get(forward_full_name), Const.FORWARD, Const.BACKWARD)
-            ModuleProcesser.api_parent_node = None
-            if self.scope:
-                self.scope.begin_module(full_name)
-
-        if torch_version_above_or_equal_2:
-            if Const.START in start_or_stop:
-                return pre_hook
-            else:
-                return end_hook
-        else:
-            if Const.FORWARD in name_prefix and Const.START in start_or_stop:
-                return pre_hook
-            elif Const.BACKWARD in name_prefix:
-                return backward_hook
-            else:
-                return end_hook
+                    ModuleProcesser.module_with_backward_hook[prefix_name] = True
+                register_forward_pre_hook(module, forward_pre_hook)
+
+    def build_module_hook(self, module_name, build_data_hook):
+        def forward_pre_hook(module, args, kwargs=None):
+            if kwargs is None:
+                kwargs = {}
+
+            if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump:
+                return (args, kwargs) if torch_version_above_or_equal_2 else args
+
+            index = ModuleProcesser.set_and_get_calls_number(module_name)
+            full_forward_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}'
+            full_backward_name = f'{module_name}{Const.BACKWARD}{Const.SEP}{index}'
+
+            self.set_construct_info_in_pre_hook(full_forward_name)
+
+            if not hasattr(module, 'msprobe_forward_hook'):
+                forward_hooks_dict = getattr(module, '_forward_hooks', OrderedDict())
+                handle = RemovableHandle(forward_hooks_dict)
+                forward_hooks_dict[handle.id] = forward_hook
+                forward_hooks_dict.move_to_end(handle.id, last=False)
+                if torch_version_above_or_equal_2:
+                    forward_hooks_with_kwargs_dict = getattr(module, '_forward_hooks_with_kwargs', OrderedDict())
+                    forward_hooks_with_kwargs_dict[handle.id] = True
+
+                setattr(module, 'msprobe_forward_hook', True)
+
+            hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
+
+            def get_backward_pre_hook(full_backward_name):
+                def backward_pre_hook_fn(module, grad_output):
+                    self.set_construct_info_in_pre_hook(full_backward_name)
+                return backward_pre_hook_fn
+
+            def get_backward_hook(backward_data_hook, full_backward_name):
+                def backward_hook_fn(module, grad_input, grad_output):
+                    new_output = backward_data_hook(module, grad_input, grad_output)
+                    self.set_construct_info_in_hook(full_backward_name, is_forward=False)
+                    return new_output
+                return backward_hook_fn
+
+            if not ModuleProcesser.module_with_backward_hook.get(module_name):
+                backward_pre_hook = get_backward_pre_hook(full_backward_name)
+                backward_hook = get_backward_hook(hook_set.backward_hook, full_backward_name)
+                if torch_version_above_or_equal_2:
+                    bw_hook = BackwardHook(module, [backward_hook], [backward_pre_hook])
+                else:
+                    bw_hook = BackwardHook(module, [backward_hook])
+                ModuleProcesser.module_bw_hook_kernels[full_forward_name] = bw_hook
+                args = bw_hook.setup_input_hook(args)
+            return (args, kwargs) if torch_version_above_or_equal_2 else args
+
+        def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
+            if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump:
+                return output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
+
+            index = ModuleProcesser.module_count.get(module_name)
+            full_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}'
+
+            hook_set = build_data_hook(BaseScope.Module_Type_Module, full_name)
+            hook_result = hook_set.forward_hook(module, args, kwargs_or_output, output_or_kwargs)
+            self.set_construct_info_in_hook(full_name)
+
+            if hook_result is not None:
+                result = hook_result
+            else:
+                result = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
+
+            bw_hook = ModuleProcesser.module_bw_hook_kernels.get(full_name)
+            if bw_hook:
+                result = bw_hook.setup_output_hook(result)
+
+            return result
+
+        return forward_pre_hook
+
+    def set_construct_info_in_pre_hook(self, full_name):
+        if self.module_stack:
+            ModuleProcesser.module_node[full_name] = self.module_stack[-1]
+        else:
+            ModuleProcesser.module_node[full_name] = None
+        ModuleProcesser.module_stack.append(full_name)
+        ModuleProcesser.api_parent_node = full_name
+        if self.scope:
+            self.scope.begin_module(full_name)
+
+    def set_construct_info_in_hook(self, full_name, is_forward=True):
+        if torch_version_above_or_equal_2 or is_forward:
+            if self.module_stack:
+                ModuleProcesser.module_stack.pop()
+            ModuleProcesser.api_parent_node = ModuleProcesser.module_stack[-1] if self.module_stack else None
+            if self.scope:
+                self.scope.end_module(full_name)
+        else:
+            if self.scope:
+                self.scope.begin_module(full_name)
+            ModuleProcesser.api_parent_node = full_name
```
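A note on the hook-prepending trick inside forward_pre_hook: torch keeps a module's forward hooks in an OrderedDict, so allocating a RemovableHandle over the private `_forward_hooks` dict and moving the new entry to the front guarantees the data-collection hook fires before any user-registered hook. A standalone sketch of the same pattern (it relies on the private attribute, exactly as the diff does):

```python
import torch
from torch.utils.hooks import RemovableHandle

module = torch.nn.ReLU()
module.register_forward_hook(lambda m, args, out: print("user hook"))

def msprobe_style_hook(module, args, output):
    print("prepended hook, runs first")

handle = RemovableHandle(module._forward_hooks)            # allocates a hook id
module._forward_hooks[handle.id] = msprobe_style_hook
module._forward_hooks.move_to_end(handle.id, last=False)   # move to the front

module(torch.randn(3))  # prints "prepended hook, runs first", then "user hook"
```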
msprobe/pytorch/grad_probe/grad_stat_csv.py

```diff
@@ -17,6 +17,7 @@ from abc import ABC, abstractmethod
 from collections import namedtuple
 import hashlib
 from functools import wraps
+import zlib
 
 import torch
 from msprobe.core.grad_probe.constant import GradConst
 
@@ -74,8 +75,8 @@ class CsvMd5(CsvItem):
     def generate_csv_content(csv_content_input):
         grad = csv_content_input.grad
         tensor_bytes = grad.cpu().detach().float().numpy().tobytes()
-        md5_hash = hashlib.md5(tensor_bytes)
-        return [md5_hash.hexdigest()]
+        md5_hash = f"{zlib.crc32(tensor_bytes):08x}"
+        return [md5_hash]
 
 
 @register_csv_item(GradConst.DISTRIBUTION)
```
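The gradient fingerprint column switches from MD5 to the much cheaper CRC32 (the variable keeps its old md5_hash name). zlib.crc32 returns an unsigned 32-bit integer, hence the `:08x` zero-padded hex formatting. A quick standalone comparison of the two variants:

```python
import hashlib
import zlib

import torch

grad = torch.randn(8, 8)
tensor_bytes = grad.cpu().detach().float().numpy().tobytes()

md5_digest = hashlib.md5(tensor_bytes).hexdigest()   # 32 hex chars (old code)
crc_digest = f"{zlib.crc32(tensor_bytes):08x}"       # 8 hex chars  (new code)
print(md5_digest, crc_digest)
```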
msprobe/pytorch/hook_module/api_register.py

```diff
@@ -15,21 +15,36 @@
 
 import functools
 import os
+import inspect
 
 import torch
 import torch.distributed as dist
 
 from msprobe.core.common.const import Const
 from msprobe.core.data_dump.api_registry import ApiRegistry
+from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import (
     torch_without_guard_version, is_gpu, torch_device_guard, parameter_adapter
 )
 from msprobe.pytorch.function_factory import npu_custom_functions
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
+from msprobe.pytorch.hook_module.utils import dynamic_import_op
+from msprobe.core.common.file_utils import load_yaml
+
+try:
+    import mindspeed.ops
+except ImportError:
+    mindspeed_enable = False
+else:
+    mindspeed_enable = True
 
 
 torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
 
+_inner_used_api = {}
+_supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
+_cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
+
 _api_types = {
     Const.PT_FRAMEWORK: {
         Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
@@ -57,10 +72,11 @@ if not is_gpu:
                                                      torch_npu.distributed.distributed_c10d))
         }
     )
-
-    _inner_used_api = {}
-    _supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
-    _cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
+    if mindspeed_enable:
+        _api_types.get(Const.PT_FRAMEWORK).update({Const.PT_API_TYPE_MINDSPEED: (mindspeed.ops, (mindspeed.ops,))})
+        mindspeed_op_list = load_yaml(_supported_api_list_path[0]).get(Const.PT_API_TYPE_MINDSPEED)
+        mindspeed_op_file_list = [op.split(Const.SEP)[0] + Const.PY_SUFFIX for op in mindspeed_op_list]
+        dynamic_import_op(mindspeed.ops, mindspeed_op_file_list)
 
 
 @parameter_adapter
@@ -70,7 +86,15 @@ def tensor_module_forward(module, *args, **kwargs):
 
 def dist_module_forward(module, *args, **kwargs):
     handle = module.api_func(*args, **kwargs)
-    if kwargs.get("async_op") or module.api_name in ["isend", "irecv"]:
+    try:
+        bound = inspect.signature(module.api_func).bind(*args, **kwargs)
+        bound.apply_defaults()
+        use_async_op_flag = bound.arguments.get("async_op", False)
+    except Exception as e:
+        use_async_op_flag = False
+        logger.warning(f"fail to get dist api's func signature because {e}, no wait")
+
+    if use_async_op_flag or module.api_name in ["isend", "irecv"]:
         if handle and hasattr(handle, 'wait'):
             handle.wait()
         if module.api_name == "batch_isend_irecv":
```
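The old check, `kwargs.get("async_op")`, missed the flag whenever a caller passed it positionally; binding the call against the function's real signature recovers it either way. A standalone demonstration with a stand-in function (the names here are illustrative, not torch.distributed's):

```python
import inspect

def fake_all_reduce(tensor, op="sum", group=None, async_op=False):
    return None

def get_async_flag(func, *args, **kwargs):
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()                    # fill in unsupplied defaults
    return bound.arguments.get("async_op", False)

print(get_async_flag(fake_all_reduce, "t", "sum", None, True))  # True (positional)
print(get_async_flag(fake_all_reduce, "t", async_op=True))      # True (keyword)
print(get_async_flag(fake_all_reduce, "t"))                     # False (default)
```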
msprobe/pytorch/hook_module/hook_module.py

```diff
@@ -21,9 +21,8 @@ import torch
 import torch.nn as nn
 import torch.utils.hooks as full_hooks
 
-from msprobe.pytorch.common.utils import is_float8_tensor
-
-torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
+from msprobe.core.common.runtime import Runtime
+from msprobe.pytorch.common.utils import is_float8_tensor, register_forward_pre_hook, register_forward_hook
 
 
 class HOOKModule(nn.Module):
@@ -41,16 +40,14 @@
         if not self.stop_hook:
             self.forward_data_collected = False
 
+            if not Runtime.is_running:
+                return
             prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
             if callable(hook_build_func):
-                forward_pre_hook, forward_hook, backward_hook, _ = hook_build_func(prefix)
-                if torch_version_above_or_equal_2:
-                    self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
-                    self.register_forward_hook(forward_hook, with_kwargs=True)
-                else:
-                    self.register_forward_pre_hook(forward_pre_hook)
-                    self.register_forward_hook(forward_hook)
-                self.register_backward_hook(backward_hook)
+                hook_set = hook_build_func(prefix)
+                register_forward_pre_hook(self, hook_set.forward_pre_hook)
+                register_forward_hook(self, hook_set.forward_hook)
+                self.register_backward_hook(hook_set.backward_hook)
 
     def __call__(self, *args, **kwargs):
         changed = False
@@ -79,13 +76,7 @@
         if len(self._backward_hooks) > 0:
             full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
         for hook in self._forward_pre_hooks.values():
-            result_args, result_kwargs = hook(self, args, kwargs)
-            if result_args is not None:
-                if not isinstance(result_args, tuple):
-                    result_args = (result_args,)
-                args = result_args
-            if result_kwargs is not None:
-                kwargs = result_kwargs
+            hook(self, args, kwargs)
         bw_hook = None
         if len(full_backward_hooks) > 0:
             bw_hook = full_hooks.BackwardHook(self, full_backward_hooks)
```
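The register_forward_pre_hook and register_forward_hook helpers imported here live in msprobe.pytorch.common.utils and are not part of this diff. A plausible sketch of the version bridging they presumably encapsulate (an assumption for illustration, not the shipped implementation):

```python
import torch

torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'

def register_forward_pre_hook(module, hook):
    # torch >= 2.0 pre-hooks can receive and return kwargs; older torch cannot
    if torch_version_above_or_equal_2:
        module.register_forward_pre_hook(hook, with_kwargs=True)
    else:
        module.register_forward_pre_hook(hook)

def register_forward_hook(module, hook):
    if torch_version_above_or_equal_2:
        module.register_forward_hook(hook, with_kwargs=True)
    else:
        module.register_forward_hook(hook)
```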
msprobe/pytorch/hook_module/jit_script_wrapper.py (new file)

```diff
@@ -0,0 +1,33 @@
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from msprobe.pytorch.hook_module.api_register import get_api_register
+
+
+def wrap_jit_script_func():
+    def patched_script(*args, **kwargs):
+        all_api_registered = api_register.all_api_registered
+        if all_api_registered:
+            api_register.restore_all_api()
+        result = original_script(*args, **kwargs)
+        if all_api_registered:
+            api_register.register_all_api()
+        return result
+
+    original_script = torch.jit.script
+    api_register = get_api_register()
+    torch.jit.script = patched_script
```
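This wrapper temporarily unhooks msprobe's instrumented APIs so torch.jit.script compiles against the original functions, then re-applies the hooks. The same save/patch/restore idea in generic form (a toy stand-in for msprobe's ApiRegistry; try/finally is added here for robustness, whereas the shipped code restores unconditionally after the call):

```python
class ToyRegistry:
    """Hypothetical stand-in for msprobe's ApiRegistry."""
    def __init__(self):
        self.all_api_registered = True

    def restore_all_api(self):
        self.all_api_registered = False   # put original APIs back

    def register_all_api(self):
        self.all_api_registered = True    # re-apply instrumentation

def with_instrumentation_suspended(func, registry):
    def patched(*args, **kwargs):
        was_registered = registry.all_api_registered
        if was_registered:
            registry.restore_all_api()
        try:
            return func(*args, **kwargs)  # runs against original APIs
        finally:
            if was_registered:
                registry.register_all_api()
    return patched

script = with_instrumentation_suspended(lambda x: x + 1, ToyRegistry())
print(script(41))  # 42, instrumentation suspended only around the call
```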
msprobe/pytorch/hook_module/pt_hook_manager.py (new file)

```diff
@@ -0,0 +1,68 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from contextlib import nullcontext
+
+from msprobe.core.common.const import Const
+from msprobe.core.common.utils import replace_last_occurrence
+from msprobe.core.hook_manager import BaseHookManager, HookSet
+from msprobe.pytorch.common.utils import is_recomputation, torch_version_above_or_equal_2
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
+
+
+class PytorchHookManager(BaseHookManager):
+    @property
+    def _is_recompute(self):
+        return is_recomputation()
+
+    @staticmethod
+    def _no_grad_context():
+        return nullcontext()
+
+    @staticmethod
+    def _add_count(name):
+        HOOKModule.add_module_count(name)
+
+    @staticmethod
+    def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs):
+        kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {}
+        output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
+        return kwargs, output
+
+    def build_hook(self, hook_type, name):
+        if hook_type == Const.API:
+            full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
+        else:
+            full_forward_name = name
+        full_backward_name = replace_last_occurrence(full_forward_name, Const.FORWARD, Const.BACKWARD)
+        hookset = HookSet(
+            forward_hook=self._build_forward_hook(hook_type, full_forward_name),
+            forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name),
+            backward_hook=self._build_backward_hook(hook_type, full_backward_name)
+        )
+        return hookset
+
+    def _need_exchange(self, module):
+        return True
+
+    def _get_params_dict(self, module):
+        params_dict = {}
+        if self.config.task != Const.STRUCTURE:
+            params_dict = {
+                key.split(Const.SEP)[-1]: value
+                for key, value in module.named_parameters(recurse=False)
+            }
+        return params_dict
```
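build_hook derives the backward dump name from the forward one through replace_last_occurrence, whose body lives in msprobe.core.common.utils and is outside this diff. A minimal implementation consistent with that usage (hypothetical, for illustration only):

```python
def replace_last_occurrence(text, old, new):
    # Replace only the final occurrence of `old`, leaving earlier ones alone,
    # so nested names such as "Module.forward_block.forward.0" map correctly.
    if text is None:
        return text
    index = text.rfind(old)
    if index == -1:
        return text
    return text[:index] + new + text[index + len(old):]

print(replace_last_occurrence("Module.layer1.forward.0", "forward", "backward"))
# -> "Module.layer1.backward.0"
```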