mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
@@ -14,26 +14,29 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import re
17
+ import copy
17
18
  from dataclasses import dataclass
18
19
 
19
20
  from msprobe.core.common.const import Const
20
- from msprobe.core.common.file_utils import load_json, save_json
21
+ from msprobe.core.common.file_utils import load_json, load_construct_json
21
22
  from msprobe.core.common.utils import load_stack_json
23
+ from msprobe.core.common.log import logger
22
24
  from msprobe.visualization.builder.msprobe_adapter import get_input_output
23
- from msprobe.visualization.builder.msprobe_adapter import op_patterns
24
25
  from msprobe.visualization.graph.graph import Graph
25
26
  from msprobe.visualization.graph.node_op import NodeOp
26
27
  from msprobe.visualization.utils import GraphConst
28
+ from msprobe.visualization.db_utils import node_to_db, config_to_db
27
29
 
28
30
 
29
31
  class GraphBuilder:
30
32
  backward_pattern = re.compile(r"(\.backward\.)(\d+)$")
31
33
  forward_pattern = re.compile(r"(\.forward\.)(\d+)$")
32
34
  # 匹配以大写字母开头,后接任意字母,并以Template(结尾,或包含api_template(的字符串
33
- template_pattern = re.compile(r'\b([A-Z][a-zA-Z]*Template|api_template)\(')
35
+ template_pattern = re.compile(r'\b([A-Z][a-zA-Z]*Template|api_template|api_instance)\(')
36
+ micro_step_dict = {}
34
37
 
35
38
  @staticmethod
36
- def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False):
39
+ def build(construct_path, data_path, stack_path, model_name='DefaultModel'):
37
40
  """
38
41
  GraphBuilder的对外提供的构图方法
39
42
  Args:
@@ -41,75 +44,38 @@ class GraphBuilder:
41
44
  data_path: dump.json路径
42
45
  stack_path: stack.json路径
43
46
  model_name: 模型名字,依赖外部输入
44
- complete_stack: 完整的堆栈信息
45
47
  Returns: Graph,代表图的数据结构
46
48
  """
47
- construct_dict = load_json(construct_path)
49
+ construct_dict, micro_step_dict = load_construct_json(construct_path)
50
+ if not construct_dict:
51
+ logger.error("The content of 'construct.json' is empty, failed to build graph. "
52
+ "When dumping data, it is necessary to select level L0 or mix in order to "
53
+ "collect model structure data, that is, the content of 'construct.json' is not empty.")
54
+ raise RuntimeError
55
+ GraphBuilder.micro_step_dict = micro_step_dict
48
56
  dump_dict = load_json(data_path)
49
57
  stack_dict = load_stack_json(stack_path)
50
- if not complete_stack:
51
- GraphBuilder._simplify_stack(stack_dict)
52
58
  data_dict = dump_dict.get(GraphConst.DATA_KEY, {})
53
- graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict)
59
+ graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict,
60
+ micro_step_num=micro_step_dict.get(Const.MEGATRON_MICRO_STEP_NUMBER))
54
61
  GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
62
+ GraphBuilder._handle_recompute(graph)
55
63
  GraphBuilder._collect_apis_between_modules(graph)
56
64
  GraphBuilder._add_parameters_grad(graph, data_dict)
57
65
  return graph
58
66
 
59
67
  @staticmethod
60
- def to_json(filename, config):
61
- """
62
- 将graph导出成.vis文件的接口
63
- """
64
- result = {}
68
+ def to_db(filename, config):
69
+ config.graph_n.step = config.step
70
+ config.graph_n.rank = config.rank
71
+ config.graph_n.compare_mode = config.compare_mode
72
+ node_to_db(config.graph_n, filename)
65
73
  if config.graph_b:
66
- result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict(config.compare_mode)
67
- result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict(config.compare_mode)
68
- else:
69
- result = config.graph_n.to_dict(config.compare_mode)
70
- if config.tool_tip:
71
- result[GraphConst.JSON_TIP_KEY] = config.tool_tip
72
- if config.node_colors:
73
- result[GraphConst.COLORS] = config.node_colors
74
- if config.micro_steps:
75
- result[GraphConst.MICRO_STEPS] = config.micro_steps
76
- if config.task:
77
- result[GraphConst.JSON_TASK_KEY] = config.task
78
- result[GraphConst.OVERFLOW_CHECK] = config.overflow_check
79
- save_json(filename, result, indent=4)
80
-
81
- @staticmethod
82
- def _simplify_stack(stack_dict):
83
- """
84
- 精简堆栈内容,模块级保留包含"模块名("的堆栈,api级保留"xxxTemplate("的下一行堆栈
85
-
86
- 例如模块 Module.layer3.0.bn2.BatchNorm2d.forward.0,模块名为bn2,匹配"bn2(",
87
- 保留堆栈"File /home/models/resnet.py, line 97, in forward, \n out = self.bn2(out)"
88
-
89
- 例如Api Tensor.__iadd__.4.forward,堆栈为:
90
- "File /home/wrap_tensor.py, line 61, return TensorOPTemplate(op_name, hook)(*args, **kwargs)",
91
- "File /home/torchvision/models/resnet.py, line 102, in forward, \n out += identity",
92
- 匹配到第一行的"TensorOPTemplate(",保留下一行堆栈
93
- """
94
- module_pattern = re.compile(op_patterns[0])
95
- for dump_name, stack_list in stack_dict.items():
96
- if not isinstance(stack_list, list):
97
- continue
98
- if module_pattern.match(dump_name):
99
- parts = dump_name.split(Const.SEP)
100
- if len(parts) < abs(Const.LAYER_NAME_INDEX):
101
- continue
102
- module_name = parts[Const.LAYER_NAME_INDEX]
103
- for stack in stack_list:
104
- if re.search(module_name + r'\(', stack):
105
- stack_list = [stack]
106
- break
107
- else:
108
- for index, stack in enumerate(stack_list):
109
- if GraphBuilder.template_pattern.search(stack) and index < len(stack_list) - 1:
110
- stack_list = [stack_list[index + 1]]
111
- break
112
- stack_dict[dump_name] = stack_list
74
+ config.graph_b.data_source = GraphConst.JSON_BENCH_KEY
75
+ config.graph_b.step = config.step
76
+ config.graph_b.rank = config.rank
77
+ node_to_db(config.graph_b, filename)
78
+ config_to_db(config, filename)
113
79
 
114
80
  @staticmethod
115
81
  def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id):
@@ -132,10 +98,47 @@ class GraphBuilder:
132
98
  return new_upnode_id
133
99
  return upnode_id
134
100
 
101
+ @staticmethod
102
+ def _handle_backward_inplace(construct_dict, sub_node_id, up_node_id):
103
+ """
104
+ 如果当前backward节点的父层级信息不等于其父级节点的层级信息,则尝试从同名的forward节点寻找父级节点
105
+ 主要针对的场景:inplace层会无法触发backward hook导致反向层级错误
106
+
107
+ example:
108
+ 正确的层级关系:
109
+ 父层:Module.layer4.1.BasicBlock.backward.0的层级信息为Module.layer4.1
110
+ 子层:Module.layer4.1.conv2.Conv2d.backward.0的父层级信息为Module.layer4.1
111
+
112
+ 错误的层级关系:
113
+ 父层:Module.layer4.1.relu.ReLU.backward.1的层级信息为Module.layer4.1.relu
114
+ 子层:Module.layer4.1.conv2.Conv2d.backward.0的父层级信息为Module.layer4.1
115
+ """
116
+ if GraphBuilder.backward_pattern.search(sub_node_id) and up_node_id:
117
+ sub_split = sub_node_id.split(Const.SEP)
118
+ if len(sub_split) < 5:
119
+ return up_node_id
120
+ up_split = up_node_id.split(Const.SEP)
121
+ if len(up_split) < 4:
122
+ return up_node_id
123
+ sub_node_prefix = Const.SEP.join(sub_split[:-4])
124
+ up_node_prefix = Const.SEP.join(up_split[:-3])
125
+ if sub_node_prefix != up_node_prefix:
126
+ forward_sub_node_id = GraphBuilder.backward_pattern.sub(r".forward.\2", sub_node_id)
127
+ if forward_sub_node_id in construct_dict:
128
+ forward_up_node_id = construct_dict.get(forward_sub_node_id)
129
+ # forward_up_node_id ---> null
130
+ if not forward_up_node_id:
131
+ return forward_up_node_id
132
+ new_up_node_id = GraphBuilder.forward_pattern.sub(r".backward.\2", forward_up_node_id)
133
+ if new_up_node_id in construct_dict:
134
+ return new_up_node_id
135
+ return up_node_id
136
+
135
137
  @staticmethod
136
138
  def _init_nodes(graph, construct_dict, data_dict, stack_dict):
137
139
  for subnode_id, upnode_id in construct_dict.items():
138
- upnode_id = GraphBuilder._handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id)
140
+ upnode_id = GraphBuilder._handle_backward_inplace(construct_dict, subnode_id, upnode_id) if upnode_id \
141
+ else GraphBuilder._handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id)
139
142
  if upnode_id:
140
143
  upnode_op = NodeOp.get_node_op(upnode_id)
141
144
  upnode = GraphBuilder._create_or_get_node(graph, [data_dict, stack_dict], upnode_op, upnode_id)
@@ -171,6 +174,8 @@ class GraphBuilder:
171
174
  node_stack_info = forward_node.stack_info if forward_node \
172
175
  else ['This backward node cannot find the forward node and cannot retrieve stack information.']
173
176
  node.stack_info = node_stack_info
177
+ if GraphBuilder.micro_step_dict:
178
+ node.micro_step_id = GraphBuilder.micro_step_dict.get(node.id, 0)
174
179
  # 添加节点
175
180
  node.add_upnode(upnode)
176
181
  return node
@@ -230,6 +235,8 @@ class GraphBuilder:
230
235
  node.upnode = api_collection_node
231
236
  api_collection_node.upnode = graph.root
232
237
  output.append(api_collection_node)
238
+ if temp_nodes[0].micro_step_id is not None:
239
+ api_collection_node.micro_step_id = temp_nodes[0].micro_step_id
233
240
  else:
234
241
  # 如果连续的api节点不足2个,将它们原样添加到输出列表
235
242
  output.extend(temp_nodes)
@@ -276,10 +283,128 @@ class GraphBuilder:
276
283
  # 更新数据
277
284
  graph.get_node(parameters_grad_node_id).set_input_output(input_data, output_data)
278
285
 
286
+ @staticmethod
287
+ def _handle_recompute(graph):
288
+ """
289
+ 1. 通过_get_recompute_map获得重计算节点映射recompute_map: dict(node_id: node_id_prefix)
290
+ 2. 通过_get_no_recompute_map获得非重计算节点映射no_recompute_map: dict(node_id_prefix: list(node_id))
291
+ 3. 遍历recompute_map,通过node_id_prefix与no_recompute_map建立连接,通过非重计算节点找到自身的父节点
292
+ """
293
+ recompute_map, recompute_id_map = GraphBuilder._get_recompute_map(graph.root.subnodes)
294
+ if not recompute_map:
295
+ return
296
+ id_prefixes = set(recompute_map.values())
297
+ no_recompute_map = GraphBuilder._get_no_recompute_map(graph, id_prefixes)
298
+ if not no_recompute_map:
299
+ return
300
+ # 深拷贝非重计算节点字典用于反向模式
301
+ no_recompute_ids_b = copy.deepcopy(no_recompute_map)
302
+
303
+ del_indexes = []
304
+ for node_id, id_prefix in recompute_map.items():
305
+ if id_prefix not in no_recompute_map:
306
+ continue
307
+ node_list = no_recompute_map.get(id_prefix) if GraphBuilder.forward_pattern.search(node_id) else \
308
+ no_recompute_ids_b.get(id_prefix)
309
+ if not node_list:
310
+ continue
311
+ no_recompute_node = node_list.pop()
312
+ recompute_node = graph.node_map.get(node_id)
313
+ if not recompute_node:
314
+ continue
315
+ # 通过非重计算forward节点的父节点,找到对应的backward父节点
316
+ new_up_node = graph.node_map.get(
317
+ GraphBuilder.forward_pattern.sub(r".backward.\2", no_recompute_node.upnode.id))
318
+ if not new_up_node:
319
+ continue
320
+
321
+ # 更新节点连接关系
322
+ recompute_node.upnode = new_up_node
323
+ new_up_node.subnodes.append(recompute_node)
324
+
325
+ del_indexes.append(recompute_id_map.get(node_id))
326
+
327
+ # 从后往前删除graph首层中已更新父节点的重计算节点
328
+ del_indexes.sort(reverse=True)
329
+ for index in del_indexes:
330
+ if 0 <= index <= len(graph.root.subnodes):
331
+ del graph.root.subnodes[index]
332
+
333
+ @staticmethod
334
+ def _get_recompute_map(node_list: list):
335
+ """
336
+ 找到graph首层的重计算层
337
+
338
+ return: dict(node_id: node_id_prefix), dict(node_id: index)
339
+
340
+ example:
341
+ {Module.0.module.decoder.layers.0.TransformerLayer.forward.4: Module.0.module.decoder.layers.0.TransformerLayer}
342
+ """
343
+ recompute_map = {}
344
+ recompute_id_map = {}
345
+ node_id_set = set([node.id for node in node_list])
346
+ node_id_cache = set()
347
+ for i, node in enumerate(node_list):
348
+ if NodeOp.get_node_op(node.id) != NodeOp.module:
349
+ continue
350
+ id_segments = node.id.split(Const.SEP)
351
+ prefix = Const.SEP.join(id_segments[:-2])
352
+ if node.id in node_id_cache:
353
+ recompute_map[node.id] = prefix
354
+ recompute_id_map[node.id] = i
355
+ continue
356
+ is_recompute = GraphBuilder._is_recompute_node_id(id_segments)
357
+ if not is_recompute:
358
+ continue
359
+ # 重计算层必然是一组对应的前反向节点
360
+ id_segments[-2] = Const.BACKWARD if id_segments[-2] == Const.FORWARD else Const.FORWARD
361
+ relative_node_id = Const.SEP.join(id_segments)
362
+ if relative_node_id in node_id_set:
363
+ recompute_map[node.id] = prefix
364
+ recompute_id_map[node.id] = i
365
+ # 对应节点id放入缓存避免后续重复判断
366
+ node_id_cache.add(relative_node_id)
367
+ return recompute_map, recompute_id_map
368
+
369
+ @staticmethod
370
+ def _is_recompute_node_id(id_segments):
371
+ """
372
+ 非重计算首层节点命名必然是:Module/Cell.{number(可选)}.module_name.{number(可选)}.class_name.forward/backward.number
373
+ 如果不符合,则判断为重计算节点
374
+ """
375
+ if len(id_segments) > 7:
376
+ return True
377
+ if len(id_segments) == 7 and not (id_segments[1].isdigit() and id_segments[3].isdigit()):
378
+ return True
379
+ if len(id_segments) == 6 and not id_segments[1].isdigit():
380
+ return True
381
+ return False
382
+
383
+ @staticmethod
384
+ def _get_no_recompute_map(graph, recompute_id_prefixes):
385
+ """
386
+ 寻找与重计算层id前缀相同的非重计算forward层,按顺序排列,重计算层按照顺序使用非重计算forward层的父节点对应的backward节点
387
+
388
+ return: dict(node_id_prefix: list(node_id))
389
+ """
390
+ no_recompute_map = {}
391
+ for node_id, node in graph.node_map.items():
392
+ if NodeOp.get_node_op(node_id) == NodeOp.module and GraphBuilder.forward_pattern.search(node_id):
393
+ if not node.upnode or node.upnode.id == graph.root.id:
394
+ continue
395
+ id_prefix = GraphBuilder.forward_pattern.sub('', node_id)
396
+ if id_prefix not in recompute_id_prefixes:
397
+ continue
398
+ no_recompute_map.setdefault(id_prefix, []).append(node)
399
+ for node_list in no_recompute_map.values():
400
+ # 方便按顺序pop弹出
401
+ node_list.reverse()
402
+ return no_recompute_map
403
+
279
404
 
280
405
  class GraphExportConfig:
281
406
  def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
282
- overflow_check=False, compare_mode=None):
407
+ overflow_check=False, compare_mode=None, step=0, rank=0, step_list=None, rank_list=None):
283
408
  self.graph_n = graph_n
284
409
  self.graph_b = graph_b
285
410
  self.tool_tip = tool_tip
@@ -288,6 +413,10 @@ class GraphExportConfig:
288
413
  self.task = task
289
414
  self.overflow_check = overflow_check
290
415
  self.compare_mode = compare_mode
416
+ self.step = step
417
+ self.rank = rank
418
+ self.step_list = step_list
419
+ self.rank_list = rank_list
291
420
 
292
421
 
293
422
  @dataclass