mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
  3. msprobe/README.md +27 -22
  4. msprobe/core/common/const.py +129 -60
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/inplace_ops.yaml +1 -0
  9. msprobe/core/common/utils.py +43 -33
  10. msprobe/core/compare/acc_compare.py +43 -74
  11. msprobe/core/compare/check.py +2 -6
  12. msprobe/core/compare/highlight.py +2 -0
  13. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  14. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  15. msprobe/core/compare/merge_result/merge_result.py +16 -9
  16. msprobe/core/compare/merge_result/utils.py +81 -0
  17. msprobe/core/compare/multiprocessing_compute.py +19 -12
  18. msprobe/core/compare/npy_compare.py +30 -12
  19. msprobe/core/compare/utils.py +30 -10
  20. msprobe/core/data_dump/api_registry.py +176 -0
  21. msprobe/core/data_dump/data_collector.py +58 -13
  22. msprobe/core/data_dump/data_processor/base.py +94 -10
  23. msprobe/core/data_dump/data_processor/factory.py +3 -0
  24. msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
  25. msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
  26. msprobe/core/data_dump/json_writer.py +61 -40
  27. msprobe/core/grad_probe/constant.py +1 -0
  28. msprobe/core/grad_probe/grad_compare.py +1 -1
  29. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  30. msprobe/docs/01.installation.md +27 -1
  31. msprobe/docs/02.config_introduction.md +27 -23
  32. msprobe/docs/03.config_examples.md +24 -0
  33. msprobe/docs/05.data_dump_PyTorch.md +103 -16
  34. msprobe/docs/06.data_dump_MindSpore.md +76 -32
  35. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  36. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  37. msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
  38. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  39. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  40. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  41. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  42. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  43. msprobe/docs/18.online_dispatch.md +1 -1
  44. msprobe/docs/19.monitor.md +332 -273
  45. msprobe/docs/21.visualization_PyTorch.md +42 -13
  46. msprobe/docs/22.visualization_MindSpore.md +43 -13
  47. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  48. msprobe/docs/27.dump_json_instruction.md +301 -27
  49. msprobe/docs/28.debugger_save_instruction.md +94 -0
  50. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  51. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  52. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  53. msprobe/docs/FAQ.md +3 -11
  54. msprobe/docs/img/compare_result.png +0 -0
  55. msprobe/docs/img/merge_result.png +0 -0
  56. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  57. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  58. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  59. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  60. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  61. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  63. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  64. msprobe/mindspore/__init__.py +4 -2
  65. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
  66. msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
  67. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  68. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  69. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  70. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  71. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  72. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  73. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
  74. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  75. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  76. msprobe/mindspore/common/const.py +61 -0
  77. msprobe/mindspore/common/utils.py +48 -18
  78. msprobe/mindspore/compare/ms_compare.py +27 -19
  79. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  80. msprobe/mindspore/debugger/debugger_config.py +31 -6
  81. msprobe/mindspore/debugger/precision_debugger.py +45 -14
  82. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  83. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  86. msprobe/mindspore/dump/jit_dump.py +21 -15
  87. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  88. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  89. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  90. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  91. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  92. msprobe/mindspore/grad_probe/global_context.py +2 -0
  93. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  94. msprobe/mindspore/grad_probe/hook.py +2 -4
  95. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  96. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  97. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  98. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  99. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  100. msprobe/mindspore/monitor/features.py +63 -0
  101. msprobe/mindspore/monitor/module_hook.py +873 -0
  102. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  103. msprobe/mindspore/monitor/utils.py +309 -0
  104. msprobe/mindspore/ms_config.py +8 -2
  105. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  106. msprobe/mindspore/service.py +114 -34
  107. msprobe/pytorch/__init__.py +0 -1
  108. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  109. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
  110. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  111. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  112. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  116. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  117. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  118. msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
  119. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
  120. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  121. msprobe/pytorch/common/utils.py +97 -4
  122. msprobe/pytorch/debugger/debugger_config.py +19 -9
  123. msprobe/pytorch/debugger/precision_debugger.py +24 -1
  124. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  125. msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
  126. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  127. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  132. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  133. msprobe/pytorch/function_factory.py +8 -2
  134. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  135. msprobe/pytorch/hook_module/api_register.py +131 -0
  136. msprobe/pytorch/hook_module/hook_module.py +19 -14
  137. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  138. msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
  139. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  140. msprobe/pytorch/monitor/csv2tb.py +18 -14
  141. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  142. msprobe/pytorch/monitor/module_hook.py +238 -193
  143. msprobe/pytorch/monitor/module_metric.py +9 -6
  144. msprobe/pytorch/monitor/optimizer_collect.py +100 -67
  145. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  146. msprobe/pytorch/monitor/utils.py +76 -44
  147. msprobe/pytorch/online_dispatch/compare.py +0 -2
  148. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  149. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  150. msprobe/pytorch/online_dispatch/utils.py +3 -0
  151. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  152. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  153. msprobe/pytorch/pt_config.py +30 -29
  154. msprobe/pytorch/service.py +114 -32
  155. msprobe/visualization/builder/graph_builder.py +75 -10
  156. msprobe/visualization/builder/msprobe_adapter.py +7 -6
  157. msprobe/visualization/compare/graph_comparator.py +42 -38
  158. msprobe/visualization/compare/mode_adapter.py +0 -19
  159. msprobe/visualization/graph/base_node.py +11 -3
  160. msprobe/visualization/graph/distributed_analyzer.py +71 -3
  161. msprobe/visualization/graph/graph.py +0 -11
  162. msprobe/visualization/graph/node_op.py +4 -3
  163. msprobe/visualization/graph_service.py +4 -5
  164. msprobe/visualization/utils.py +12 -35
  165. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
  166. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  167. msprobe/pytorch/hook_module/api_registry.py +0 -166
  168. msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
  169. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  171. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  172. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  173. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  174. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  175. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  176. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  177. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,9 +16,10 @@
16
16
  import os
17
17
  import re
18
18
 
19
- from msprobe.core.common.const import Const
19
+ from msprobe.core.common.const import Const, FileCheckConst
20
20
  from msprobe.core.common.exceptions import MsprobeException
21
- from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, check_crt_valid
21
+ from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, check_crt_valid, \
22
+ FileChecker
22
23
  from msprobe.core.common.log import logger
23
24
  from msprobe.core.common.utils import is_int
24
25
  from msprobe.core.common_config import BaseConfig, CommonConfig
@@ -66,6 +67,7 @@ class TensorConfig(BaseConfig):
66
67
  check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
67
68
  check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
68
69
  check_crt_valid(os.path.join(self.tls_path, "client.crt"))
70
+ check_crt_valid(os.path.join(self.tls_path, "client.key"), True)
69
71
 
70
72
  if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
71
73
  raise Exception(f"host: {self.host} is invalid.")
@@ -95,6 +97,8 @@ class OverflowCheckConfig(BaseConfig):
95
97
  def check_overflow_config(self):
96
98
  if self.overflow_nums is not None and not is_int(self.overflow_nums):
97
99
  raise Exception("overflow_num is invalid")
100
+ if self.overflow_nums is not None and self.overflow_nums != -1 and self.overflow_nums <= 0:
101
+ raise Exception("overflow_nums should be -1 or positive integer")
98
102
  if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]:
99
103
  raise Exception("check_mode is invalid")
100
104
 
@@ -148,7 +152,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
148
152
  self.pert_mode in PytorchFreeBenchmarkConst.CPU_MODE_LIST
149
153
  ):
150
154
  msg = (
151
- f"You neet to and can only set fuzz_device as {DeviceType.CPU} "
155
+ f"You need to and can only set fuzz_device as {DeviceType.CPU} "
152
156
  f"when pert_mode in {PytorchFreeBenchmarkConst.CPU_MODE_LIST}"
153
157
  )
154
158
  logger.error_log_with_exp(
@@ -271,13 +275,13 @@ class RunUTConfig(BaseConfig):
271
275
 
272
276
  @classmethod
273
277
  def check_nfs_path_config(cls, nfs_path):
274
- if nfs_path and not os.path.exists(nfs_path):
275
- raise Exception("nfs_path: %s does not exist" % nfs_path)
278
+ if nfs_path:
279
+ FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
276
280
 
277
281
  @classmethod
278
282
  def check_tls_path_config(cls, tls_path):
279
- if tls_path and not os.path.exists(tls_path):
280
- raise Exception("tls_path: %s does not exist" % tls_path)
283
+ if tls_path:
284
+ FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
281
285
 
282
286
  def check_run_ut_config(self):
283
287
  RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
@@ -303,28 +307,25 @@ class GradToolConfig(BaseConfig):
303
307
  check_bounds(self.bounds)
304
308
 
305
309
 
310
+ class StructureConfig(BaseConfig):
311
+ def __init__(self, json_config):
312
+ super().__init__(json_config)
313
+
314
+
315
+ TaskDict = {
316
+ Const.TENSOR: TensorConfig,
317
+ Const.STATISTICS: StatisticsConfig,
318
+ Const.OVERFLOW_CHECK: OverflowCheckConfig,
319
+ Const.FREE_BENCHMARK: FreeBenchmarkCheckConfig,
320
+ Const.RUN_UT: RunUTConfig,
321
+ Const.GRAD_PROBE: GradToolConfig,
322
+ Const.STRUCTURE: StructureConfig
323
+ }
324
+
325
+
306
326
  def parse_task_config(task, json_config):
307
- default_dic = {}
308
- if task == Const.TENSOR:
309
- config_dic = json_config.get(Const.TENSOR, default_dic)
310
- return TensorConfig(config_dic)
311
- elif task == Const.STATISTICS:
312
- config_dic = json_config.get(Const.STATISTICS, default_dic)
313
- return StatisticsConfig(config_dic)
314
- elif task == Const.OVERFLOW_CHECK:
315
- config_dic = json_config.get(Const.OVERFLOW_CHECK, default_dic)
316
- return OverflowCheckConfig(config_dic)
317
- elif task == Const.FREE_BENCHMARK:
318
- config_dic = json_config.get(Const.FREE_BENCHMARK, default_dic)
319
- return FreeBenchmarkCheckConfig(config_dic)
320
- elif task == Const.RUN_UT:
321
- config_dic = json_config.get(Const.RUN_UT, default_dic)
322
- return RunUTConfig(config_dic)
323
- elif task == Const.GRAD_PROBE:
324
- config_dic = json_config.get(Const.GRAD_PROBE, default_dic)
325
- return GradToolConfig(config_dic)
326
- else:
327
- return StatisticsConfig(default_dic)
327
+ task_map = json_config.get(task, dict())
328
+ return TaskDict.get(task)(task_map)
328
329
 
329
330
 
330
331
  def parse_json_config(json_file_path, task):
@@ -15,22 +15,22 @@
15
15
 
16
16
  import functools
17
17
  import os
18
- from collections import namedtuple
18
+ from collections import namedtuple, defaultdict
19
19
 
20
20
  import torch
21
21
  from msprobe.core.common.const import Const
22
22
  from msprobe.core.common.exceptions import DistributedNotInitializedError
23
23
  from msprobe.core.common.file_utils import create_directory
24
- from msprobe.core.common.utils import print_tools_ends_info
24
+ from msprobe.core.common.utils import print_tools_ends_info, DumpPathAggregation
25
25
  from msprobe.core.data_dump.data_collector import build_data_collector
26
26
  from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
27
27
  from msprobe.core.data_dump.scope import BaseScope
28
28
  from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
29
29
  from msprobe.pytorch.common.log import logger
30
- from msprobe.pytorch.common.utils import get_rank_if_initialized
30
+ from msprobe.pytorch.common.utils import get_rank_if_initialized, is_recomputation
31
31
  from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
32
32
  from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
33
- from msprobe.pytorch.hook_module.api_registry import api_register
33
+ from msprobe.pytorch.hook_module.api_register import get_api_register
34
34
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
35
35
  from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
36
36
 
@@ -50,19 +50,25 @@ class Service:
50
50
  self.switch = False
51
51
  self.inner_switch = False
52
52
  self.current_iter = 0
53
+ self.loop = 0
54
+ self.init_step = 0
53
55
  self.first_start = True
54
56
  self.current_rank = None
55
57
  self.dump_iter_dir = None
56
58
  self.should_stop_service = False
57
59
  self.attl = None
58
60
  self.params_grad_info = {}
61
+ self.hook_handle_dict = {}
59
62
  # 提前注册,确保注册尽可能多的API hook
63
+ self.api_register = get_api_register()
60
64
  self.register_api_hook()
65
+ self.init_for_debug_level()
61
66
 
62
67
  def build_hook(self, module_type, name):
63
68
  def pre_hook(api_or_module_name, module, args, kwargs):
64
69
  if not self.should_execute_hook(module_type, module, True):
65
70
  return args, kwargs
71
+ is_recompute = is_recomputation()
66
72
 
67
73
  self.inner_switch = True
68
74
  if module_type == BaseScope.Module_Type_Module:
@@ -77,7 +83,13 @@ class Service:
77
83
  return None, None
78
84
  if self.data_collector:
79
85
  module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
80
- self.data_collector.forward_input_data_collect(api_or_module_name, module, pid, module_input_output)
86
+ self.data_collector.forward_input_data_collect(
87
+ api_or_module_name,
88
+ module,
89
+ pid,
90
+ module_input_output,
91
+ is_recompute
92
+ )
81
93
 
82
94
  self.inner_switch = False
83
95
  return args, kwargs
@@ -101,7 +113,12 @@ class Service:
101
113
  if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
102
114
  for param_name, param in params_dict.items():
103
115
  if param.requires_grad:
104
- param.register_hook(grad_hook(module, ori_name, param_name))
116
+ name = ori_name + Const.SEP + param_name
117
+ old_handle = self.hook_handle_dict.get(name)
118
+ if old_handle and hasattr(old_handle, "remove"):
119
+ old_handle.remove()
120
+ handle = param.register_hook(grad_hook(module, ori_name, param_name))
121
+ self.hook_handle_dict[name] = handle
105
122
 
106
123
  def init_params_grad_info(module, params_dict):
107
124
  '''
@@ -125,6 +142,7 @@ class Service:
125
142
  def forward_hook(api_or_module_name, module, args, kwargs, output):
126
143
  if not self.should_execute_hook(module_type, module, True):
127
144
  return None
145
+ is_recompute = is_recomputation()
128
146
 
129
147
  self.inner_switch = True
130
148
  if self.config.online_run_ut:
@@ -147,10 +165,15 @@ class Service:
147
165
  if module_type == BaseScope.Module_Type_Module:
148
166
  api_or_module_name = module.mindstudio_reserved_name[-1]
149
167
  self.data_collector.update_api_or_module_name(api_or_module_name)
150
- params_dict = {key.split(Const.SEP)[-1]: value for key, value in module.named_parameters(recurse=False)}
151
- setattr(module_input_output, Const.PARAMS, params_dict)
168
+ params_dict = {}
169
+ if self.config.task != Const.STRUCTURE:
170
+ params_dict = {
171
+ key.split(Const.SEP)[-1]: value
172
+ for key, value in module.named_parameters(recurse=False)
173
+ }
174
+ setattr(module_input_output, Const.PARAMS, params_dict)
152
175
  # 判断是否需要注册参数hook
153
- if not hasattr(module, 'params_grad_name') and params_dict:
176
+ if params_dict:
154
177
  ori_name = api_or_module_name.rsplit(Const.SEP, 2)[0]
155
178
  grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
156
179
  # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook
@@ -160,7 +183,8 @@ class Service:
160
183
  api_or_module_name,
161
184
  module,
162
185
  pid,
163
- module_input_output
186
+ module_input_output,
187
+ is_recompute
164
188
  )
165
189
  init_params_grad_info(module, params_dict)
166
190
  else:
@@ -169,7 +193,8 @@ class Service:
169
193
  api_or_module_name,
170
194
  module,
171
195
  pid,
172
- module_input_output
196
+ module_input_output,
197
+ is_recompute
173
198
  )
174
199
 
175
200
  if self.data_collector.if_return_forward_new_output():
@@ -185,6 +210,7 @@ class Service:
185
210
  def backward_hook(api_or_module_name, module, grad_input, grad_output):
186
211
  if not self.should_execute_hook(module_type, module, False):
187
212
  return
213
+ is_recompute = is_recomputation()
188
214
 
189
215
  self.inner_switch = True
190
216
  if module_type == BaseScope.Module_Type_Module:
@@ -198,7 +224,13 @@ class Service:
198
224
  if self.data_collector:
199
225
  # 此处获取到的grad_input实际为反向过程的输出数据,grad_output为反向过程的输入数据,因此传入时调换顺序
200
226
  module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
201
- self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output)
227
+ self.data_collector.backward_data_collect(
228
+ api_or_module_name,
229
+ module,
230
+ pid,
231
+ module_input_output,
232
+ is_recompute
233
+ )
202
234
  self.inner_switch = False
203
235
 
204
236
  pid = os.getpid()
@@ -217,6 +249,10 @@ class Service:
217
249
  return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)
218
250
 
219
251
  def start(self, model):
252
+ self.current_iter = self.loop + self.init_step
253
+ self.data_collector.update_iter(self.current_iter)
254
+ if self.config.level == Const.LEVEL_DEBUG:
255
+ return
220
256
  if self.need_stop_service():
221
257
  return
222
258
 
@@ -231,6 +267,8 @@ class Service:
231
267
  if self.config.rank and self.current_rank not in self.config.rank:
232
268
  return
233
269
  self.register_module_hook()
270
+ if self.config.level == Const.LEVEL_MIX:
271
+ register_optimizer_hook(self.data_collector)
234
272
  self.first_start = False
235
273
  if self.config.online_run_ut and torch_version_above_or_equal_2:
236
274
  run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute)
@@ -241,6 +279,8 @@ class Service:
241
279
  logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")
242
280
 
243
281
  def stop(self):
282
+ if self.config.level == Const.LEVEL_DEBUG:
283
+ return
244
284
  if self.should_stop_service:
245
285
  return
246
286
  if self.config.step and self.current_iter not in self.config.step:
@@ -255,18 +295,21 @@ class Service:
255
295
  return
256
296
  if self.config.async_dump:
257
297
  self.data_collector.fill_stack_tensor_data()
258
- self.data_collector.data_processor.dump_async_data()
298
+ if self.config.task == Const.TENSOR:
299
+ self.data_collector.data_processor.dump_async_data()
259
300
  self.data_collector.write_json()
260
301
 
261
302
  def step(self):
303
+ if self.config.level == Const.LEVEL_DEBUG:
304
+ return
262
305
  if self.should_stop_service:
263
306
  return
264
307
  if self.config.async_dump:
265
308
  self.data_collector.fill_stack_tensor_data()
266
- self.data_collector.data_processor.dump_async_data()
309
+ if self.config.task == Const.TENSOR:
310
+ self.data_collector.data_processor.dump_async_data()
267
311
  self.data_collector.write_json()
268
- self.current_iter += 1
269
- self.data_collector.update_iter(self.current_iter)
312
+ self.loop += 1
270
313
  self.reset_status()
271
314
 
272
315
  def need_stop_service(self):
@@ -319,26 +362,22 @@ class Service:
319
362
  else:
320
363
  dump_data_dir = None
321
364
 
322
- dump_file_path = os.path.join(dump_dir, "dump.json")
323
- stack_file_path = os.path.join(dump_dir, "stack.json")
324
- construct_file_path = os.path.join(dump_dir, "construct.json")
325
- free_benchmark_file_path = os.path.join(self.config.dump_path, "free_benchmark.csv")
326
- self.data_collector.update_dump_paths(
327
- dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path
328
- )
365
+ dump_path_aggregation = DumpPathAggregation()
366
+ dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
367
+ dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
368
+ dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
369
+ dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
370
+ dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv")
371
+ self.data_collector.update_dump_paths(dump_path_aggregation)
329
372
  self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)
330
373
 
331
374
  def register_api_hook(self):
332
375
  if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
333
376
  logger.info_on_rank_0(f"The api {self.config.task} hook function is successfully mounted to the model.")
334
- api_register.initialize_hook(
335
- functools.partial(self.build_hook, BaseScope.Module_Type_API),
336
- self.config.online_run_ut
377
+ self.api_register.initialize_hook(
378
+ functools.partial(self.build_hook, BaseScope.Module_Type_API)
337
379
  )
338
- api_register.api_modularity()
339
-
340
- if self.config.level == Const.LEVEL_MIX:
341
- register_optimizer_hook(self.data_collector)
380
+ self.api_register.register_all_api()
342
381
 
343
382
  def register_module_hook(self):
344
383
  if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]:
@@ -373,13 +412,13 @@ class Service:
373
412
  if self.config.nfs_path:
374
413
  self.attl.upload("end")
375
414
  elif self.attl.socket_manager is not None:
376
- logger.info(f"pid: {os.getpid()} finished, start send STOP signal.")
415
+ logger.info(f"pid: {os.getpid()} finished, start sends STOP signal.")
377
416
  self.attl.socket_manager.send_stop_signal()
378
417
 
379
418
  def reset_status(self):
380
419
  ModuleProcesser.reset_module_stats()
381
420
  HOOKModule.reset_module_stats()
382
- self.data_collector.data_writer.reset_cache()
421
+ self.data_collector.reset_status()
383
422
  self.params_grad_info.clear()
384
423
 
385
424
  if self.config.level == Const.LEVEL_L2:
@@ -389,3 +428,46 @@ class Service:
389
428
  return
390
429
  if self.config.rank and self.current_rank not in self.config.rank:
391
430
  return
431
+
432
+ def init_for_debug_level(self):
433
+ if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]):
434
+ return
435
+ try:
436
+ self.current_rank = get_rank_if_initialized()
437
+ except DistributedNotInitializedError:
438
+ self.current_rank = None
439
+
440
+ # dir: dump_path -- rank{} -- debug.json
441
+ self.dump_iter_dir = self.config.dump_path
442
+ cur_rank = self.current_rank if self.current_rank is not None else ''
443
+ dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
444
+ create_directory(dump_dir)
445
+ if self.config.task in self.data_collector.tasks_need_tensor_data:
446
+ dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
447
+ create_directory(dump_data_dir)
448
+ else:
449
+ dump_data_dir = None
450
+
451
+ dump_path_aggregation = DumpPathAggregation()
452
+ dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
453
+ dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
454
+ self.data_collector.update_dump_paths(dump_path_aggregation)
455
+ self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)
456
+
457
+ self.debug_variable_counter = defaultdict(int)
458
+
459
+ def save(self, variable, name, save_backward):
460
+ if self.config.level != Const.LEVEL_DEBUG:
461
+ return
462
+ count = self.debug_variable_counter[name]
463
+ self.debug_variable_counter[name] += 1
464
+
465
+ name_with_count = f"{name}.{count}"
466
+ grad_name_with_count = f"{name}_grad.{count}"
467
+
468
+ # forward save
469
+ self.data_collector.debug_data_collect_forward(variable, name_with_count)
470
+
471
+ # backward save
472
+ if save_backward:
473
+ self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)
@@ -16,18 +16,19 @@
16
16
  import re
17
17
 
18
18
  from msprobe.core.common.const import Const
19
- from msprobe.core.common.file_utils import load_json
19
+ from msprobe.core.common.file_utils import load_json, save_json
20
20
  from msprobe.visualization.builder.msprobe_adapter import get_input_output
21
21
  from msprobe.visualization.builder.msprobe_adapter import op_patterns
22
22
  from msprobe.visualization.graph.graph import Graph
23
23
  from msprobe.visualization.graph.node_op import NodeOp
24
- from msprobe.visualization.utils import save_json_file, GraphConst
24
+ from msprobe.visualization.utils import GraphConst
25
25
 
26
26
 
27
27
  class GraphBuilder:
28
28
  backward_pattern = re.compile(r"(\.backward\.)(\d+)$")
29
- # 匹配以大写字母开头,后接任意字母,并以Template(结尾
30
- template_pattern = re.compile(r'\b[A-Z][a-zA-Z]*Template\(')
29
+ forward_pattern = re.compile(r"(\.forward\.)(\d+)$")
30
+ # 匹配以大写字母开头,后接任意字母,并以Template(结尾,或包含api_template(的字符串
31
+ template_pattern = re.compile(r'\b([A-Z][a-zA-Z]*Template|api_template)\(')
31
32
 
32
33
  @staticmethod
33
34
  def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False):
@@ -50,6 +51,7 @@ class GraphBuilder:
50
51
  graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict)
51
52
  GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
52
53
  GraphBuilder._collect_apis_between_modules(graph)
54
+ GraphBuilder._add_parameters_grad(graph, data_dict)
53
55
  return graph
54
56
 
55
57
  @staticmethod
@@ -72,7 +74,7 @@ class GraphBuilder:
72
74
  if config.task:
73
75
  result[GraphConst.JSON_TASK_KEY] = config.task
74
76
  result[GraphConst.OVERFLOW_CHECK] = config.overflow_check
75
- save_json_file(filename, result)
77
+ save_json(filename, result, indent=4)
76
78
 
77
79
  @staticmethod
78
80
  def _simplify_stack(stack_dict):
@@ -113,12 +115,17 @@ class GraphBuilder:
113
115
  如果backward节点的父级节点是null,则尝试从同名的forward节点寻找父级节点
114
116
  """
115
117
  # 匹配以.backward.后跟一个或多个数字结尾的模式
116
- backward_pattern = r"(\.backward\.)(\d+)$"
117
- forward_pattern = r"(\.forward\.)(\d+)$"
118
- if re.search(backward_pattern, subnode_id) and not upnode_id:
119
- forward_upnode_id = construct_dict.get(re.sub(backward_pattern, r".forward.\2", subnode_id))
118
+ if GraphBuilder.backward_pattern.search(subnode_id) and not upnode_id:
119
+ forward_upnode_id = construct_dict.get(GraphBuilder.backward_pattern.sub(r".forward.\2", subnode_id))
120
120
  if forward_upnode_id:
121
- new_upnode_id = re.sub(forward_pattern, r".backward.\2", forward_upnode_id)
121
+ new_upnode_id = GraphBuilder.forward_pattern.sub(r".backward.\2", forward_upnode_id)
122
+ if new_upnode_id in construct_dict:
123
+ return new_upnode_id
124
+ # 匹配以.backward结尾的节点
125
+ if subnode_id.endswith(Const.SEP + Const.BACKWARD) and not upnode_id:
126
+ forward_upnode_id = construct_dict.get(subnode_id.replace(Const.BACKWARD, Const.FORWARD))
127
+ if forward_upnode_id:
128
+ new_upnode_id = forward_upnode_id.replace(Const.FORWARD, Const.BACKWARD)
122
129
  if new_upnode_id in construct_dict:
123
130
  return new_upnode_id
124
131
  return upnode_id
@@ -148,6 +155,8 @@ class GraphBuilder:
148
155
  input_data, output_data = get_input_output(node_data, node.id)
149
156
  # 更新数据
150
157
  node.set_input_output(input_data, output_data)
158
+ if GraphConst.BATCH_P2P in name:
159
+ GraphBuilder._extract_batch_p2p_info(node, node_data)
151
160
  # 反向节点使用对应前向节点的堆栈信息
152
161
  # 模块命名举例:Module.module.module.GPTModel.backward.0; API命名举例:Tensor.permute.1.backward
153
162
  if (not node_stack_info and
@@ -164,6 +173,24 @@ class GraphBuilder:
164
173
  node.add_upnode(upnode)
165
174
  return node
166
175
 
176
@staticmethod
def _is_valid_batch_p2p_output(param_list):
    """
    Check that ``param_list`` has the nested shape ``[[...], ...]``.

    Returns:
        True only when ``param_list`` is a non-empty list whose first element is itself
        a non-empty list; False in every other case.
    """
    if isinstance(param_list, list) and param_list:
        head = param_list[0]
        return isinstance(head, list) and bool(head)
    return False
183
+
184
@staticmethod
def _extract_batch_p2p_info(node, node_data):
    """
    Append per-operation P2P info from a batch P2P node's dumped output to the node.

    The dumped output is expected in the form ``"output": [[{param1}, {param2}, ...]]``.
    For each dict in the inner list, a dict holding its op / peer / group_id values
    (None when absent) is appended to ``node.batch_p2p_info``.

    Args:
        node: graph node whose ``batch_p2p_info`` list is extended in place.
        node_data: dumped data for the node; its output entry is read.
    """
    param_list = node_data.get(Const.OUTPUT, [])
    if not GraphBuilder._is_valid_batch_p2p_output(param_list):
        return
    info_keys = (GraphConst.OP, GraphConst.PEER, GraphConst.GROUP_ID)
    for param in param_list[0]:
        # Skip malformed entries (e.g. a dumped null) instead of raising AttributeError
        # and aborting the whole graph build.
        if not isinstance(param, dict):
            continue
        node.batch_p2p_info.append({key: param.get(key) for key in info_keys})
193
+
167
194
  @staticmethod
168
195
  def _collect_apis_between_modules(graph):
169
196
  """
@@ -209,6 +236,44 @@ class GraphBuilder:
209
236
 
210
237
  graph.root.subnodes = output
211
238
 
239
@staticmethod
def _add_parameters_grad(graph, data_dict):
    """
    Attach parameters_grad entries from the dumped data to the graph.

    Every ``<prefix>.parameters_grad`` key in ``data_dict`` that is not already a graph
    node is added as a module node under the last backward node of the same module,
    i.e. ``<prefix>.backward.<N>`` with the largest N present in ``graph.node_map``.

    Example: if the graph has Module.a.backward.0, Module.a.backward.1 and
    Module.a.backward.2, then Module.a.parameters_grad is added as a child of
    Module.a.backward.2.

    Args:
        graph: the Graph being built; ``graph.node_map`` is scanned and new nodes are
            added via ``graph.add_node``.
        data_dict: dumped data keyed by node id; used to find parameters_grad entries
            and to fill their input/output data.
    """
    suffix = Const.SEP + Const.PARAMS_GRAD
    # Strip only the trailing suffix; str.replace would also drop any earlier occurrence.
    prefixes = [node_id[:-len(suffix)] for node_id in data_dict
                if node_id not in graph.node_map and node_id.endswith(suffix)]

    # Largest backward index found per prefix. It starts at 0, so a module with no
    # matching backward node falls through to the lookup/guard below.
    max_info = {prefix: 0 for prefix in prefixes}

    # Single pass over the graph: split each "<prefix>.backward.<N>" id into its prefix
    # and N. The prefix is uniquely determined because N must be digits up to the end,
    # so this matches exactly what a per-prefix anchored regex would.
    backward_sep = Const.SEP + Const.BACKWARD + Const.SEP
    backward_pattern = re.compile(r'^(.*)' + re.escape(backward_sep) + r'(\d+)$', re.DOTALL)
    for key in graph.node_map:
        match = backward_pattern.match(key)
        if not match:
            continue
        prefix = match.group(1)
        if prefix in max_info:
            max_info[prefix] = max(max_info[prefix], int(match.group(2)))

    for prefix, num in max_info.items():
        node_id = prefix + Const.SEP + Const.BACKWARD + Const.SEP + str(num)
        node = graph.get_node(node_id)
        # NOTE(review): relies on get_node returning a falsy value for unknown ids.
        if not node:
            continue
        parameters_grad_node_id = graph.add_node(NodeOp.module, prefix + suffix, up_node=node)
        # Fill in the input/output data of the new node from the dump.
        node_data = data_dict.get(parameters_grad_node_id, {})
        input_data, output_data = get_input_output(node_data, parameters_grad_node_id)
        graph.get_node(parameters_grad_node_id).set_input_output(input_data, output_data)
276
+
212
277
 
213
278
  class GraphExportConfig:
214
279
  def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
@@ -13,7 +13,6 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
  import re
16
- import math
17
16
  from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy
18
17
  from msprobe.core.common.utils import set_dump_path, get_dump_mode
19
18
  from msprobe.visualization.utils import GraphConst
@@ -23,7 +22,7 @@ from msprobe.core.compare.acc_compare import ModeConfig
23
22
  # 用于将节点名字解析成对应的NodeOp的规则
24
23
  op_patterns = [
25
24
  # NodeOp.module
26
- r'^(Module.|Cell.)',
25
+ r'^(Module.|Cell.|optimizer|clip_grad)',
27
26
  # NodeOp.function_api
28
27
  r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.)'
29
28
  ]
@@ -57,8 +56,8 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
57
56
  from msprobe.pytorch.compare.pt_compare import PTComparator
58
57
  return PTComparator(mode_config).do_multi_process(dump_path_param, csv_path)
59
58
  else:
60
- from msprobe.mindspore.compare.ms_compare import MSComparator
61
- ms_comparator = MSComparator(mode_config)
59
+ from msprobe.mindspore.compare.ms_compare import MSComparator, MappingConfig
60
+ ms_comparator = MSComparator(mode_config, MappingConfig())
62
61
  ms_comparator.cross_frame = is_cross_frame
63
62
  return ms_comparator.do_multi_process(dump_path_param, csv_path)
64
63
 
@@ -120,11 +119,13 @@ def compare_data_fuzzy(data_dict_list1, data_dict_list2):
120
119
  return True
121
120
 
122
121
 
123
- def format_node_data(data_dict):
122
+ def format_node_data(data_dict, node_id=None):
124
123
  """
125
- 批量进行节点数据的输出
124
+ 删除节点数据中不需要展示的字段
126
125
  """
127
126
  del_list = ['requires_grad', 'full_op_name']
127
+ if node_id and GraphConst.BATCH_P2P in node_id:
128
+ del_list.extend(['op', 'peer', 'tag', 'group_id'])
128
129
  for _, value in data_dict.items():
129
130
  if not isinstance(value, dict):
130
131
  continue