mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
msprobe/pytorch/hook_module/api_register.py

@@ -14,22 +14,25 @@
 # limitations under the License.
 
 import functools
-import os
 import inspect
+import os
 
 import torch
 import torch.distributed as dist
 
 from msprobe.core.common.const import Const
+from msprobe.core.common.file_utils import load_yaml
 from msprobe.core.data_dump.api_registry import ApiRegistry
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import (
-    torch_without_guard_version, is_gpu, torch_device_guard, parameter_adapter
+    torch_without_guard_version,
+    is_gpu,
+    torch_device_guard,
+    parameter_adapter
 )
 from msprobe.pytorch.function_factory import npu_custom_functions
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.hook_module.utils import dynamic_import_op
-from msprobe.core.common.file_utils import load_yaml
 
 try:
     import mindspeed.ops
@@ -38,42 +41,46 @@ except ImportError:
 else:
     mindspeed_enable = True
 
-
 torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
 
 _inner_used_api = {}
 _supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
 _cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
+dist_data_collect_func = {}
+dist_batch_data_collect_func = []
 
 _api_types = {
     Const.PT_FRAMEWORK: {
-        Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
-        Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)),
-        Const.PT_API_TYPE_TORCH: (torch, (torch,)),
-        Const.PT_API_TYPE_VF: (torch._C._VariableFunctionsClass, (torch._VF,)),
-        Const.PT_API_TYPE_DIST: (dist, (dist, dist.distributed_c10d))
+        Const.PT_API_TYPE_FUNCTIONAL: ((torch.nn.functional,), (torch.nn.functional,)),
+        Const.PT_API_TYPE_TENSOR: ((torch.Tensor,), (torch.Tensor,)),
+        Const.PT_API_TYPE_TORCH: ((torch,), (torch,)),
+        Const.PT_API_TYPE_VF: ((torch._C._VariableFunctionsClass,), (torch._VF,)),
+        Const.PT_API_TYPE_DIST: ((dist,), (dist, dist.distributed_c10d))
     }
 }
 if not is_gpu:
     import torch_npu
+
     if torch_without_guard_version:
         _api_types.get(Const.PT_FRAMEWORK).update(
             {
-                Const.PT_API_TYPE_NPU: (torch.ops.npu, (torch_npu, torch.ops.npu))
+                Const.PT_API_TYPE_NPU: ((torch.ops.npu, torch_npu), (torch_npu, torch.ops.npu)),
             }
         )
     else:
        _api_types.get(Const.PT_FRAMEWORK).update(
-            {Const.PT_API_TYPE_NPU: (torch_npu._C._VariableFunctionsClass, (torch_npu,))}
+            {Const.PT_API_TYPE_NPU: ((torch_npu._C._VariableFunctionsClass,), (torch_npu,))}
        )
     _api_types.get(Const.PT_FRAMEWORK).update(
         {
-            Const.PT_API_TYPE_NPU_DIST: (torch_npu.distributed, (torch_npu.distributed,
-                                                                 torch_npu.distributed.distributed_c10d))
+            Const.PT_API_TYPE_NPU_DIST: (
+                (torch_npu.distributed,),
+                (torch_npu.distributed, torch_npu.distributed.distributed_c10d)
+            )
         }
     )
 if mindspeed_enable:
-    _api_types.get(Const.PT_FRAMEWORK).update({Const.PT_API_TYPE_MINDSPEED: (mindspeed.ops, (mindspeed.ops,))})
+    _api_types.get(Const.PT_FRAMEWORK).update({Const.PT_API_TYPE_MINDSPEED: ((mindspeed.ops,), (mindspeed.ops,))})
     mindspeed_op_list = load_yaml(_supported_api_list_path[0]).get(Const.PT_API_TYPE_MINDSPEED)
     mindspeed_op_file_list = [op.split(Const.SEP)[0] + Const.PY_SUFFIX for op in mindspeed_op_list]
     dynamic_import_op(mindspeed.ops, mindspeed_op_file_list)
@@ -94,16 +101,48 @@ def dist_module_forward(module, *args, **kwargs):
         use_async_op_flag = False
         logger.warning(f"fail to get dist api's func signature because {e}, no wait")
 
-    if use_async_op_flag or module.api_name in ["isend", "irecv"]:
-        if handle and hasattr(handle, 'wait'):
-            handle.wait()
-    if module.api_name == "batch_isend_irecv":
-        if isinstance(handle, list):
-            for req in handle:
-                req.wait()
+    def create_async_callback_func(catch_func):
+        full_name = module.full_forward_name if hasattr(module, "full_forward_name") else None
+
+        def store_data():
+            catch_func(module, full_name, args, kwargs, handle)
+
+        return store_data
+
+    if use_async_op_flag or module.api_name in ['isend', 'irecv']:
+        dist_data_collect_func[handle] = create_async_callback_func(module.distributed_forward_hook)
+    if module.api_name == 'batch_isend_irecv':
+        dist_batch_data_collect_func.append([handle, create_async_callback_func(module.distributed_forward_hook)])
     return handle
 
 
+def redirect_wait():
+    if hasattr(dist, "Work"):
+        from torch.distributed import Work
+    else:
+        from torch._C._distributed_c10d import Work
+    origin_wait = Work.wait
+
+    def wrapped_wait(work):
+        def wrapped_wait(*args, **kwargs):
+            origin_wait(*args, **kwargs)
+            if args[0] in dist_data_collect_func:
+                store_func = dist_data_collect_func.pop(args[0])
+                store_func()
+                return
+            for value in dist_batch_data_collect_func:
+                if args[0] in value[0]:
+                    value[0].remove(args[0])
+                    if len(value[0]) == 0:
+                        store_func = value[1]
+                        store_func()
+                    return
+
+        return wrapped_wait
+
+    Work.wait = wrapped_wait(Work)
+
+
 def npu_module_forward(module, *args, **kwargs):
     if not module.need_hook:
         if module.api_name not in npu_custom_functions:
@@ -125,15 +164,14 @@ forward_methods = {
 class ApiTemplate(HOOKModule):
     def __init__(self, api_name, api_func, prefix, hook_build_func, need_hook=True, device=Const.CPU_LOWERCASE):
         self.api_name = api_name
-        self.api_func = api_func
         self.prefix = prefix
         self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
         self.need_hook = need_hook
         self.device = device
+        self.op_is_distributed = prefix == Const.DIST_API_TYPE_PREFIX
         if self.need_hook:
             super().__init__(hook_build_func)
-        if prefix == Const.DIST_API_TYPE_PREFIX:
-            self.op_is_distributed = True
+        self.api_func = api_func
 
     @torch_device_guard
     def forward(self, *args, **kwargs):
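
The change above replaces eager handle.wait() calls for async collectives (isend/irecv/batch_isend_irecv) with deferred collection: a callback keyed by the returned work handle is stored at issue time, and redirect_wait() patches Work.wait so the callback fires only once user code actually waits on the handle. A minimal standalone sketch of that defer-until-wait pattern (FakeWork and collect are hypothetical stand-ins, not msprobe APIs):

pending = {}

class FakeWork:
    def wait(self):
        pass

origin_wait = FakeWork.wait

def wrapped_wait(self, *args, **kwargs):
    # run the real wait first, then fire the deferred collection callback
    origin_wait(self, *args, **kwargs)
    callback = pending.pop(self, None)
    if callback is not None:
        callback()

FakeWork.wait = wrapped_wait

def collect():
    print("async op completed, data collected")

work = FakeWork()        # stands in for the handle returned by dist.isend
pending[work] = collect  # registered at issue time
work.wait()              # prints: async op completed, data collected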
msprobe/pytorch/hook_module/hook_module.py

@@ -14,50 +14,30 @@
 # limitations under the License.
 
 import functools
-import threading
 from collections import defaultdict
 
 import torch
 import torch.nn as nn
 import torch.utils.hooks as full_hooks
 
-from msprobe.core.common.runtime import Runtime
-from msprobe.core.common.utils import ThreadSafe
-from msprobe.pytorch.common.utils import register_forward_pre_hook, register_forward_hook
+from msprobe.pytorch.common.utils import register_forward_pre_hook
 
 
 class HOOKModule(nn.Module):
     module_count = defaultdict(int)
-    inner_stop_hook = defaultdict(bool)
 
     def __init__(self, hook_build_func) -> None:
         super(HOOKModule, self).__init__()
-        self.has_overflow = False
-        self.tid = threading.get_ident()
-        self.stop_hook = HOOKModule.inner_stop_hook.get(self.tid, False)
-
-        if not self.stop_hook:
-            self.forward_data_collected = False
-
-            if not Runtime.is_running:
-                return
-            prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
-            ThreadSafe.acquire()
-            if callable(hook_build_func):
-                hook_set = hook_build_func(prefix)
-                register_forward_pre_hook(self, hook_set.forward_pre_hook)
-                register_forward_hook(self, hook_set.forward_hook)
-                self.register_backward_hook(hook_set.backward_hook)
+        prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
+        op_is_distributed = self.op_is_distributed if hasattr(self, "op_is_distributed") else False
+        if callable(hook_build_func):
+            hook_set = hook_build_func(prefix)
+            register_forward_pre_hook(self, hook_set.forward_pre_hook)
+            if op_is_distributed:
+                self.distributed_forward_hook = hook_set.distributed_forward_hook
 
     def __call__(self, *args, **kwargs):
-        changed = False
-        if not self.stop_hook:
-            HOOKModule.inner_stop_hook[self.tid] = True
-            changed = True
-        result = self._call_func(*args, **kwargs)
-        if changed:
-            HOOKModule.inner_stop_hook[self.tid] = False
-        return result
+        return self._call_func(*args, **kwargs)
 
     @staticmethod
     def reset_module_stats():
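
After this change HOOKModule registers only a forward pre-hook at construction; forward and backward hooks are attached later by the hook manager. For reference, this is how a plain nn.Module forward pre-hook behaves, using only public PyTorch APIs:

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

def forward_pre_hook(module, args):
    # runs before Linear.forward; args is the positional-argument tuple
    print(f"pre-hook: input shape {args[0].shape}")

layer.register_forward_pre_hook(forward_pre_hook)
layer(torch.randn(3, 4))  # the pre-hook fires first, then the linear op runs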
msprobe/pytorch/hook_module/pt_hook_manager.py

@@ -13,13 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import functools
+import threading
 from contextlib import nullcontext
 
+import torch
+
 from msprobe.core.common.const import Const
-from msprobe.core.common.utils import replace_last_occurrence
+from msprobe.core.common.runtime import Runtime
+from msprobe.core.common.utils import replace_last_occurrence, ThreadSafe
+from msprobe.core.data_dump.data_processor.base import (ModuleForwardInputsOutputs)
 from msprobe.core.hook_manager import BaseHookManager, HookSet
-from msprobe.pytorch.common.utils import is_recomputation, torch_version_above_or_equal_2
+from msprobe.pytorch.common.utils import is_recomputation, torch_version_above_or_equal_2, register_forward_hook
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 
 
@@ -37,23 +42,65 @@ class PytorchHookManager(BaseHookManager):
            HOOKModule.add_module_count(name)
 
     @staticmethod
-    def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs):
-        kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {}
-        output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
+    def _get_count(name):
+        return HOOKModule.get_module_count(name)
+
+    @staticmethod
+    def _process_kwargs_and_output(module, tid, hook_type, kwargs_or_output, output_or_kwargs):
+        if hook_type == Const.API:
+            kwargs = kwargs_or_output
+            output = output_or_kwargs
+        else:
+            kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {}
+            output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
         return kwargs, output
 
     def build_hook(self, hook_type, name):
         if hook_type == Const.API:
-            full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
+            hook_set = HookSet(
+                forward_pre_hook=self._build_forward_pre_hook(hook_type, name),
+                distributed_forward_hook=self._build_distributed_forward_hook()
+            )
         else:
-            full_forward_name = name
-        full_backward_name = replace_last_occurrence(full_forward_name, Const.FORWARD, Const.BACKWARD)
-        hookset = HookSet(
-            forward_hook=self._build_forward_hook(hook_type, full_forward_name),
-            forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name),
-            backward_hook=self._build_backward_hook(hook_type, full_backward_name)
-        )
-        return hookset
+            full_backward_name = replace_last_occurrence(name, Const.FORWARD, Const.BACKWARD)
+            hook_set = HookSet(
+                forward_hook=self._build_forward_hook(hook_type, name),
+                backward_hook=self._build_backward_hook(hook_type, full_backward_name)
+            )
+        return hook_set
+
+    def _register_forward_hook(self, module, api_name):
+        if not hasattr(module, 'msprobe_forward_hook'):
+            register_forward_hook(module, self._build_forward_hook(Const.API, api_name))
+            setattr(module, 'msprobe_forward_hook', True)
+
+    def _register_backward_hook(self, module, full_backward_name, args):
+        pass
+
+    def _register_backward_pre_hook(self, module, full_backward_name, output):
+        var = output
+        while not isinstance(var, torch.Tensor):
+            if isinstance(var, dict):
+                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
+            elif isinstance(var, (list, tuple)):
+                if var:
+                    var = var[0]
+                else:
+                    return output
+            else:
+                return output
+
+        if not (var.requires_grad and torch.is_grad_enabled()):
+            return output
+
+        grad_fn = var.grad_fn
+        if grad_fn is not None:
+            backward_hook = self._build_backward_hook(Const.API, full_backward_name)
+            wrapper = functools.partial(backward_hook, module)
+            functools.update_wrapper(wrapper, backward_hook)
+            grad_fn.register_hook(wrapper)
+
+        return output
 
     def _need_exchange(self, module):
         return True
@@ -66,3 +113,25 @@ class PytorchHookManager(BaseHookManager):
             for key, value in module.named_parameters(recurse=False)
         }
         return params_dict
+
+    def _build_distributed_forward_hook(self):
+        def distributed_forward_hook(module, full_name, args, kwargs, output):
+            if not full_name or not Runtime.is_running:
+                return
+
+            tid = threading.get_ident()
+            with ThreadSafe():
+                BaseHookManager.inner_switch[tid] = True
+                self.data_collector.update_api_or_module_name(full_name)
+                module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
+                with self._no_grad_context():
+                    self.data_collector.forward_output_data_collect(
+                        full_name,
+                        module,
+                        self._pid,
+                        module_input_output,
+                        self._is_recompute
+                    )
+                BaseHookManager.inner_switch[tid] = False
+
+        return distributed_forward_hook
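
The new _register_backward_pre_hook walks the forward output until it finds a tensor, then attaches the backward hook to that tensor's grad_fn. A minimal sketch of the underlying mechanism (grad_fn.register_hook is a public autograd Node API; the print stands in for the data-collection hook):

import torch

x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()

def node_hook(grad_inputs, grad_outputs):
    # fires when autograd executes this node during the backward pass
    print("backward reached the sum node")

if y.grad_fn is not None:
    y.grad_fn.register_hook(node_hook)

y.backward()  # prints: backward reached the sum node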
msprobe/pytorch/hook_module/script_wrapper.py (new file)

@@ -0,0 +1,140 @@
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import importlib
+import types
+
+import torch
+
+from msprobe.core.common.log import logger
+from msprobe.pytorch.common.utils import torch_version_above_or_equal_2
+from msprobe.pytorch.hook_module.api_register import get_api_register
+
+if torch_version_above_or_equal_2:
+    from torch._dynamo.convert_frame import convert_frame as _orig_convert_frame, Hooks
+
+
+def wrap_jit_script_func():
+    def patched_script(*args, **kwargs):
+        all_api_registered = api_register.all_api_registered
+        if all_api_registered:
+            api_register.restore_all_api()
+        result = original_script(*args, **kwargs)
+        if all_api_registered:
+            api_register.register_all_api()
+        return result
+
+    original_script = torch.jit.script
+    api_register = get_api_register()
+    torch.jit.script = patched_script
+
+
+def wrap_compile_script_func():
+    def _patched_convert_frame(compiler_fn, hooks):
+        """
+        Restore the APIs before calling the _convert_frame produced by the
+        original convert_frame, then re-register all APIs after the call.
+        """
+        # Get the original inner _convert_frame
+        inner_convert = _orig_convert_frame(compiler_fn, hooks)
+
+        def _wrapped(frame: types.FrameType, cache_size: int, hooks: Hooks, frame_state):
+            reg = get_api_register()
+            # Restore before entering
+            reg.restore_all_api()
+            try:
+                result = inner_convert(frame, cache_size, hooks, frame_state)
+            except Exception:
+                # Make sure the APIs are re-registered even on exception
+                reg.register_all_api()
+                raise
+            # Re-register after a normal return
+            reg.register_all_api()
+            return result
+
+        # Keep the original attributes for compatibility
+        _wrapped._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
+        _wrapped._clone_with_backend = lambda backend: _patched_convert_frame(backend,
+                                                                              hooks)  # type: ignore[attr-defined]
+        return _wrapped
+
+    import torch._dynamo.convert_frame as _cf_mod
+    _cf_mod.convert_frame = _patched_convert_frame
+
+
+def patch_dynamo_compile():
+    cf = importlib.import_module("torch._dynamo.convert_frame")
+    if not hasattr(cf, "_compile"):
+        logger.warning("No found torch._dynamo.convert_frame._compile")
+
+    original = cf._compile
+    if getattr(original, "__msprobe_patched__", False):
+        return
+
+    @functools.wraps(original)
+    def wrapped(*args, **kwargs):
+        result = None
+        try:
+            reg = get_api_register()
+            reg.restore_all_api()
+        except Exception as e:
+            logger.warning(f"[msprobe] Pre restore_all_api failed: {e}")
+            return result
+
+        try:
+            result = original(*args, **kwargs)
+        except Exception:
+            logger.warning("[msprobe] _compile execution failed (returning None)")
+            result = None
+        finally:
+            try:
+                reg = get_api_register()
+                reg.register_all_api()  # re-register the hooks
+            except Exception as e:
+                logger.warning(f"[msprobe] Post register_all_api failed: {e}")
+        return result
+    wrapped.__msprobe_patched__ = True
+    wrapped.__msprobe_original__ = original
+    cf._compile = wrapped
+
+
+def unpatch_dynamo_compile() -> bool:
+    # Reserved interface for removing the patch
+    cf = importlib.import_module("torch._dynamo.convert_frame")
+    current = getattr(cf, "_compile", None)
+    if current is None:
+        return False
+    original = getattr(current, "__msprobe_original__", None)
+    if original is None:
+        return False
+    cf._compile = original
+    return True
+
+
+def preprocess_func():
+    try:
+        from torch.utils._device import _device_constructors
+        _device_constructors()
+    except ImportError:
+        pass
+    except Exception as e:
+        logger.warning(f"Failed to execute _device_constructors. Error Details: {str(e)}")
+
+
+def wrap_script_func():
+    wrap_jit_script_func()
+    if torch_version_above_or_equal_2:
+        patch_dynamo_compile()
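
wrap_jit_script_func swaps torch.jit.script for a wrapper that removes the tool's API hooks before scripting (hooked functions would otherwise be visible to the compiler) and re-installs them afterwards. A standalone sketch of the same save/patch/restore shape, with hypothetical disable_hooks/enable_hooks standing in for api_register.restore_all_api/register_all_api; the sketch uses try/finally for safety, whereas the shipped wrapper restores only when the APIs were registered to begin with:

import torch

def disable_hooks():
    # hypothetical stand-in for api_register.restore_all_api()
    print("hooks removed")

def enable_hooks():
    # hypothetical stand-in for api_register.register_all_api()
    print("hooks re-installed")

original_script = torch.jit.script

def patched_script(*args, **kwargs):
    disable_hooks()
    try:
        return original_script(*args, **kwargs)
    finally:
        enable_hooks()

torch.jit.script = patched_script

@torch.jit.script
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1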
msprobe/pytorch/hook_module/support_wrap_ops.yaml

@@ -1260,6 +1260,12 @@ torch_npu:
   - npu_scatter_nd_update
   - npu_prefetch
   - npu_dynamic_block_quant
+  - npu_add_rms_norm
+  - _npu_flash_attention
+  - _npu_rotary_embedding
+  - _npu_reshape_and_cache
+  - _npu_paged_attention
+  - npu_moe_gating_top_k
 
 aten:
   - signbit
msprobe/pytorch/monitor/csv2tb.py

@@ -79,7 +79,7 @@ def write_step(output_dirpath, parse_step_result, rank, data_type):
             for op, value in ops.items():
                 tag = f"{vpp_name}/{op}"
                 writer.add_scalar(tag, value, step)
-    writer.flush()
+    writer.close()
 
 
 @recursion_depth_decorator("update_dict", max_depth=50)
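
The one-line change swaps flush() for close() on the SummaryWriter once write_step is done with it: close() both flushes pending events and releases the event-file handle. For reference (requires the tensorboard package; the path below is illustrative):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/example")
writer.add_scalar("loss", 0.5, 1)
writer.close()  # flushes pending events and closes the event file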
msprobe/pytorch/monitor/features.py

@@ -111,3 +111,97 @@ def cal_histc(tensor_cal, bins_total, min_val, max_val):
 @torch.no_grad()
 def get_nans(t):
     return torch.isnan(t).sum()
+
+
+def check_tensor_dim(tensor, n):
+    """Check that the tensor has at least n dimensions."""
+    if not isinstance(tensor, torch.Tensor):
+        raise TypeError(
+            f"Input must be a PyTorch tensor. Got {type(tensor)} instead. "
+            f"Consider using torch.tensor() for conversion."
+        )
+
+    if tensor.dim() < n:
+        raise ValueError(
+            f"Tensor must have at least {n} dimensions. "
+            f"Got shape: {tuple(tensor.shape)} with {tensor.dim()} dims."
+        )
+
+
+@torch.no_grad()
+def max_eigenvalue(input_tensor: torch.Tensor, num_iterations=3):
+    input_tensor = input_tensor.float()
+    try:
+        check_tensor_dim(input_tensor, 2)
+    except (TypeError, ValueError) as e:
+        logger.warning(f"Calculate max eigenvalue failed: {e}")
+        return torch.tensor(0)
+    in_features = input_tensor.shape[1]
+    u_tensor = torch.randn(in_features).to(input_tensor.device)
+    u_norm = u_tensor.norm()
+    if u_norm.item() == 0:
+        return torch.tensor(0)
+    u_tensor = u_tensor / u_tensor.norm()
+    input_seq = torch.matmul(input_tensor.T, input_tensor)
+    for _ in range(num_iterations):
+        v_tensor = torch.matmul(input_seq, u_tensor)
+        spectral_norm = torch.matmul(v_tensor.T, u_tensor)
+        v_norm = v_tensor.norm()
+        if v_norm > 0:
+            u_tensor = v_tensor / v_norm
+        else:
+            spectral_norm = torch.tensor(0)
+            break
+    return spectral_norm.sqrt()
+
+
+@torch.no_grad()
+def cal_entropy(qk_tensor, mask=None):
+    try:
+        check_tensor_dim(qk_tensor, 2)
+    except (TypeError, ValueError) as e:
+        logger.warning(f"Calculate max eigenvalue failed: {e}")
+        return torch.tensor(0), torch.tensor(0)
+    if mask is None:
+        mask = torch.tril(torch.ones(qk_tensor.shape[1], qk_tensor.shape[1])).to(
+            qk_tensor.device)
+    qk_tensor = qk_tensor - torch.amax(qk_tensor, dim=1, keepdim=True)
+    qk_tensor = qk_tensor.masked_fill(mask == 0, float('-inf'))
+    softmax_qkt = torch.nn.functional.softmax(qk_tensor.float(), dim=1)
+    # take the row-wise max of the softmaxed QK matrix
+    softmax_max = torch.mean(torch.amax(softmax_qkt, dim=1))
+    entropy = torch.mean(-torch.nansum(softmax_qkt *
+                                       torch.log(softmax_qkt), dim=1))
+    return entropy, softmax_max
+
+
+@torch.no_grad()
+def cal_qkt(q_h, k_h, order="s,b,h,d"):
+    # q_h shape is [s, b, h, d]
+    try:
+        check_tensor_dim(q_h, 4)
+        check_tensor_dim(k_h, 4)
+    except (TypeError, ValueError) as e:
+        logger.warning(f"Calculate qk tensor failed: {e}")
+        return torch.tensor(0)
+
+    if order == "s,b,h,d":
+        qkt = torch.matmul(
+            q_h[:, 0, 0, :], k_h[:, 0, 0, :].t()) / q_h.shape[-1] ** 0.5
+    elif order == "b,s,h,d":
+        qkt = torch.matmul(
+            q_h[0, :, 0, :], k_h[0, :, 0, :].t()) / q_h.shape[-1] ** 0.5
+    else:
+        logger.warning("Calculate qk tensor failed: Order unsupported.")
+        qkt = torch.tensor(0)
+    return qkt
+
+
+@torch.no_grad()
+def cal_stable_rank(weight: torch.Tensor):
+    eig = max_eigenvalue(weight)
+    if eig == torch.tensor(0):
+        return torch.tensor(0), torch.tensor(0)
+    f_norm = torch.norm(weight, p="fro")
+    return f_norm / eig, eig
+ return f_norm / eig, eig