mindstudio-probe 8.2.0__py3-none-any.whl → 8.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/RECORD +90 -79
  3. msprobe/README.md +7 -5
  4. msprobe/core/common/const.py +6 -0
  5. msprobe/core/common/db_manager.py +35 -4
  6. msprobe/core/common/file_utils.py +105 -27
  7. msprobe/core/common/framework_adapter.py +7 -6
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/utils.py +14 -3
  10. msprobe/core/compare/find_first/analyzer.py +8 -7
  11. msprobe/core/compare/find_first/graph.py +11 -3
  12. msprobe/core/compare/find_first/utils.py +2 -1
  13. msprobe/core/compare/highlight.py +13 -6
  14. msprobe/core/compare/multiprocessing_compute.py +17 -10
  15. msprobe/core/compare/utils.py +14 -5
  16. msprobe/core/data_dump/data_collector.py +18 -21
  17. msprobe/core/data_dump/data_processor/pytorch_processor.py +43 -20
  18. msprobe/core/data_dump/json_writer.py +18 -8
  19. msprobe/core/data_dump/scope.py +4 -6
  20. msprobe/core/hook_manager.py +37 -3
  21. msprobe/core/service.py +18 -5
  22. msprobe/core/single_save/single_comparator.py +16 -3
  23. msprobe/docs/01.installation.md +7 -5
  24. msprobe/docs/02.config_introduction.md +14 -1
  25. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  26. msprobe/docs/06.data_dump_MindSpore.md +1 -1
  27. msprobe/docs/08.accuracy_checker_online_PyTorch.md +295 -0
  28. msprobe/docs/10.accuracy_compare_PyTorch.md +46 -5
  29. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  30. msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
  31. msprobe/docs/19.monitor.md +2 -0
  32. msprobe/docs/21.visualization_PyTorch.md +15 -80
  33. msprobe/docs/22.visualization_MindSpore.md +20 -104
  34. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  35. msprobe/docs/25.tool_function_introduction.md +1 -0
  36. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  37. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  38. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  39. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  40. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  41. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  42. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  43. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  44. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  45. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  46. msprobe/mindspore/cell_processor.py +33 -5
  47. msprobe/mindspore/compare/common_dir_compare.py +22 -26
  48. msprobe/mindspore/compare/utils.py +1 -2
  49. msprobe/mindspore/debugger/precision_debugger.py +1 -1
  50. msprobe/mindspore/dump/cell_dump_process.py +73 -62
  51. msprobe/mindspore/dump/graph_mode_cell_dump.py +21 -10
  52. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +2 -0
  53. msprobe/msprobe.py +6 -4
  54. msprobe/pytorch/api_accuracy_checker/common/config.py +36 -3
  55. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +24 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare.py +12 -2
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +6 -1
  58. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  59. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +132 -12
  60. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  61. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +205 -0
  62. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +378 -0
  63. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +239 -0
  64. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  65. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +250 -0
  66. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  67. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +198 -0
  68. msprobe/pytorch/attl_manager.py +65 -0
  69. msprobe/pytorch/common/utils.py +22 -2
  70. msprobe/pytorch/compare/utils.py +3 -3
  71. msprobe/pytorch/debugger/debugger_config.py +10 -0
  72. msprobe/pytorch/dump/module_dump/hook_wrapper.py +34 -7
  73. msprobe/pytorch/dump/module_dump/module_processer.py +23 -10
  74. msprobe/pytorch/hook_module/api_register.py +6 -1
  75. msprobe/pytorch/monitor/module_hook.py +28 -9
  76. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  77. msprobe/pytorch/pt_config.py +57 -2
  78. msprobe/pytorch/pytorch_service.py +11 -2
  79. msprobe/visualization/builder/graph_builder.py +170 -64
  80. msprobe/visualization/builder/graph_merger.py +0 -1
  81. msprobe/visualization/builder/msprobe_adapter.py +1 -1
  82. msprobe/visualization/db_utils.py +25 -2
  83. msprobe/visualization/graph/base_node.py +0 -24
  84. msprobe/visualization/graph/graph.py +5 -14
  85. msprobe/visualization/graph_service.py +29 -53
  86. msprobe/visualization/utils.py +11 -1
  87. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/LICENSE +0 -0
  88. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/WHEEL +0 -0
  89. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/entry_points.txt +0 -0
  90. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/top_level.txt +0 -0
@@ -14,14 +14,14 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import re
17
+ import copy
17
18
  from dataclasses import dataclass
18
19
 
19
20
  from msprobe.core.common.const import Const
20
- from msprobe.core.common.file_utils import load_json, save_json
21
+ from msprobe.core.common.file_utils import load_json, load_construct_json
21
22
  from msprobe.core.common.utils import load_stack_json
22
23
  from msprobe.core.common.log import logger
23
24
  from msprobe.visualization.builder.msprobe_adapter import get_input_output
24
- from msprobe.visualization.builder.msprobe_adapter import op_patterns
25
25
  from msprobe.visualization.graph.graph import Graph
26
26
  from msprobe.visualization.graph.node_op import NodeOp
27
27
  from msprobe.visualization.utils import GraphConst
@@ -33,9 +33,10 @@ class GraphBuilder:
33
33
  forward_pattern = re.compile(r"(\.forward\.)(\d+)$")
34
34
  # 匹配以大写字母开头,后接任意字母,并以Template(结尾,或包含api_template(的字符串
35
35
  template_pattern = re.compile(r'\b([A-Z][a-zA-Z]*Template|api_template|api_instance)\(')
36
+ micro_step_dict = {}
36
37
 
37
38
  @staticmethod
38
- def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False):
39
+ def build(construct_path, data_path, stack_path, model_name='DefaultModel'):
39
40
  """
40
41
  GraphBuilder的对外提供的构图方法
41
42
  Args:
@@ -43,48 +44,26 @@ class GraphBuilder:
43
44
  data_path: dump.json路径
44
45
  stack_path: stack.json路径
45
46
  model_name: 模型名字,依赖外部输入
46
- complete_stack: 完整的堆栈信息
47
47
  Returns: Graph,代表图的数据结构
48
48
  """
49
- construct_dict = load_json(construct_path)
49
+ construct_dict, micro_step_dict = load_construct_json(construct_path)
50
50
  if not construct_dict:
51
51
  logger.error("The content of 'construct.json' is empty, failed to build graph. "
52
52
  "When dumping data, it is necessary to select level L0 or mix in order to "
53
53
  "collect model structure data, that is, the content of 'construct.json' is not empty.")
54
54
  raise RuntimeError
55
+ GraphBuilder.micro_step_dict = micro_step_dict
55
56
  dump_dict = load_json(data_path)
56
57
  stack_dict = load_stack_json(stack_path)
57
- if not complete_stack:
58
- GraphBuilder._simplify_stack(stack_dict)
59
58
  data_dict = dump_dict.get(GraphConst.DATA_KEY, {})
60
- graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict)
59
+ graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict,
60
+ micro_step_num=micro_step_dict.get(Const.MEGATRON_MICRO_STEP_NUMBER))
61
61
  GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
62
+ GraphBuilder._handle_recompute(graph)
62
63
  GraphBuilder._collect_apis_between_modules(graph)
63
64
  GraphBuilder._add_parameters_grad(graph, data_dict)
64
65
  return graph
65
66
 
66
- @staticmethod
67
- def to_json(filename, config):
68
- """
69
- 将graph导出成.vis文件的接口
70
- """
71
- result = {}
72
- if config.graph_b:
73
- result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict(config.compare_mode)
74
- result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict(config.compare_mode)
75
- else:
76
- result = config.graph_n.to_dict(config.compare_mode)
77
- if config.tool_tip:
78
- result[GraphConst.JSON_TIP_KEY] = config.tool_tip
79
- if config.node_colors:
80
- result[GraphConst.COLORS] = config.node_colors
81
- if config.micro_steps:
82
- result[GraphConst.MICRO_STEPS] = config.micro_steps
83
- if config.task:
84
- result[GraphConst.JSON_TASK_KEY] = config.task
85
- result[GraphConst.OVERFLOW_CHECK] = config.overflow_check
86
- save_json(filename, result, indent=4)
87
-
88
67
  @staticmethod
89
68
  def to_db(filename, config):
90
69
  config.graph_n.step = config.step
@@ -95,42 +74,10 @@ class GraphBuilder:
95
74
  config.graph_b.data_source = GraphConst.JSON_BENCH_KEY
96
75
  config.graph_b.step = config.step
97
76
  config.graph_b.rank = config.rank
77
+ config.graph_b.compare_mode = config.compare_mode
98
78
  node_to_db(config.graph_b, filename)
99
79
  config_to_db(config, filename)
100
80
 
101
- @staticmethod
102
- def _simplify_stack(stack_dict):
103
- """
104
- 精简堆栈内容,模块级保留包含"模块名("的堆栈,api级保留"xxxTemplate("的下一行堆栈
105
-
106
- 例如模块 Module.layer3.0.bn2.BatchNorm2d.forward.0,模块名为bn2,匹配"bn2(",
107
- 保留堆栈"File /home/models/resnet.py, line 97, in forward, \n out = self.bn2(out)"
108
-
109
- 例如Api Tensor.__iadd__.4.forward,堆栈为:
110
- "File /home/wrap_tensor.py, line 61, return TensorOPTemplate(op_name, hook)(*args, **kwargs)",
111
- "File /home/torchvision/models/resnet.py, line 102, in forward, \n out += identity",
112
- 匹配到第一行的"TensorOPTemplate(",保留下一行堆栈
113
- """
114
- module_pattern = re.compile(op_patterns[0])
115
- for dump_name, stack_list in stack_dict.items():
116
- if not isinstance(stack_list, list):
117
- continue
118
- if module_pattern.match(dump_name):
119
- parts = dump_name.split(Const.SEP)
120
- if len(parts) < abs(Const.LAYER_NAME_INDEX):
121
- continue
122
- module_name = parts[Const.LAYER_NAME_INDEX]
123
- for stack in stack_list:
124
- if re.search(module_name + r'\(', stack):
125
- stack_list = [stack]
126
- break
127
- else:
128
- for index, stack in enumerate(stack_list):
129
- if GraphBuilder.template_pattern.search(stack) and index < len(stack_list) - 1:
130
- stack_list = [stack_list[index + 1]]
131
- break
132
- stack_dict[dump_name] = stack_list
133
-
134
81
  @staticmethod
135
82
  def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id):
136
83
  """
@@ -152,10 +99,47 @@ class GraphBuilder:
152
99
  return new_upnode_id
153
100
  return upnode_id
154
101
 
102
+ @staticmethod
103
+ def _handle_backward_inplace(construct_dict, sub_node_id, up_node_id):
104
+ """
105
+ 如果当前backward节点的父层级信息不等于其父级节点的层级信息,则尝试从同名的forward节点寻找父级节点
106
+ 主要针对的场景:inplace层会无法触发backward hook导致反向层级错误
107
+
108
+ example:
109
+ 正确的层级关系:
110
+ 父层:Module.layer4.1.BasicBlock.backward.0的层级信息为Module.layer4.1
111
+ 子层:Module.layer4.1.conv2.Conv2d.backward.0的父层级信息为Module.layer4.1
112
+
113
+ 错误的层级关系:
114
+ 父层:Module.layer4.1.relu.ReLU.backward.1的层级信息为Module.layer4.1.relu
115
+ 子层:Module.layer4.1.conv2.Conv2d.backward.0的父层级信息为Module.layer4.1
116
+ """
117
+ if GraphBuilder.backward_pattern.search(sub_node_id) and up_node_id:
118
+ sub_split = sub_node_id.split(Const.SEP)
119
+ if len(sub_split) < 5:
120
+ return up_node_id
121
+ up_split = up_node_id.split(Const.SEP)
122
+ if len(up_split) < 4:
123
+ return up_node_id
124
+ sub_node_prefix = Const.SEP.join(sub_split[:-4])
125
+ up_node_prefix = Const.SEP.join(up_split[:-3])
126
+ if sub_node_prefix != up_node_prefix:
127
+ forward_sub_node_id = GraphBuilder.backward_pattern.sub(r".forward.\2", sub_node_id)
128
+ if forward_sub_node_id in construct_dict:
129
+ forward_up_node_id = construct_dict.get(forward_sub_node_id)
130
+ # forward_up_node_id ---> null
131
+ if not forward_up_node_id:
132
+ return forward_up_node_id
133
+ new_up_node_id = GraphBuilder.forward_pattern.sub(r".backward.\2", forward_up_node_id)
134
+ if new_up_node_id in construct_dict:
135
+ return new_up_node_id
136
+ return up_node_id
137
+
155
138
  @staticmethod
156
139
  def _init_nodes(graph, construct_dict, data_dict, stack_dict):
157
140
  for subnode_id, upnode_id in construct_dict.items():
158
- upnode_id = GraphBuilder._handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id)
141
+ upnode_id = GraphBuilder._handle_backward_inplace(construct_dict, subnode_id, upnode_id) if upnode_id \
142
+ else GraphBuilder._handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id)
159
143
  if upnode_id:
160
144
  upnode_op = NodeOp.get_node_op(upnode_id)
161
145
  upnode = GraphBuilder._create_or_get_node(graph, [data_dict, stack_dict], upnode_op, upnode_id)
@@ -191,6 +175,8 @@ class GraphBuilder:
191
175
  node_stack_info = forward_node.stack_info if forward_node \
192
176
  else ['This backward node cannot find the forward node and cannot retrieve stack information.']
193
177
  node.stack_info = node_stack_info
178
+ if GraphBuilder.micro_step_dict:
179
+ node.micro_step_id = GraphBuilder.micro_step_dict.get(node.id, 0)
194
180
  # 添加节点
195
181
  node.add_upnode(upnode)
196
182
  return node
@@ -250,6 +236,8 @@ class GraphBuilder:
250
236
  node.upnode = api_collection_node
251
237
  api_collection_node.upnode = graph.root
252
238
  output.append(api_collection_node)
239
+ if temp_nodes[0].micro_step_id is not None:
240
+ api_collection_node.micro_step_id = temp_nodes[0].micro_step_id
253
241
  else:
254
242
  # 如果连续的api节点不足2个,将它们原样添加到输出列表
255
243
  output.extend(temp_nodes)
@@ -296,6 +284,124 @@ class GraphBuilder:
296
284
  # 更新数据
297
285
  graph.get_node(parameters_grad_node_id).set_input_output(input_data, output_data)
298
286
 
287
+ @staticmethod
288
+ def _handle_recompute(graph):
289
+ """
290
+ 1. 通过_get_recompute_map获得重计算节点映射recompute_map: dict(node_id: node_id_prefix)
291
+ 2. 通过_get_no_recompute_map获得非重计算节点映射no_recompute_map: dict(node_id_prefix: list(node_id))
292
+ 3. 遍历recompute_map,通过node_id_prefix与no_recompute_map建立连接,通过非重计算节点找到自身的父节点
293
+ """
294
+ recompute_map, recompute_id_map = GraphBuilder._get_recompute_map(graph.root.subnodes)
295
+ if not recompute_map:
296
+ return
297
+ id_prefixes = set(recompute_map.values())
298
+ no_recompute_map = GraphBuilder._get_no_recompute_map(graph, id_prefixes)
299
+ if not no_recompute_map:
300
+ return
301
+ # 深拷贝非重计算节点字典用于反向模式
302
+ no_recompute_ids_b = copy.deepcopy(no_recompute_map)
303
+
304
+ del_indexes = []
305
+ for node_id, id_prefix in recompute_map.items():
306
+ if id_prefix not in no_recompute_map:
307
+ continue
308
+ node_list = no_recompute_map.get(id_prefix) if GraphBuilder.forward_pattern.search(node_id) else \
309
+ no_recompute_ids_b.get(id_prefix)
310
+ if not node_list:
311
+ continue
312
+ no_recompute_node = node_list.pop()
313
+ recompute_node = graph.node_map.get(node_id)
314
+ if not recompute_node:
315
+ continue
316
+ # 通过非重计算forward节点的父节点,找到对应的backward父节点
317
+ new_up_node = graph.node_map.get(
318
+ GraphBuilder.forward_pattern.sub(r".backward.\2", no_recompute_node.upnode.id))
319
+ if not new_up_node:
320
+ continue
321
+
322
+ # 更新节点连接关系
323
+ recompute_node.upnode = new_up_node
324
+ new_up_node.subnodes.append(recompute_node)
325
+
326
+ del_indexes.append(recompute_id_map.get(node_id))
327
+
328
+ # 从后往前删除graph首层中已更新父节点的重计算节点
329
+ del_indexes.sort(reverse=True)
330
+ for index in del_indexes:
331
+ if 0 <= index <= len(graph.root.subnodes):
332
+ del graph.root.subnodes[index]
333
+
334
+ @staticmethod
335
+ def _get_recompute_map(node_list: list):
336
+ """
337
+ 找到graph首层的重计算层
338
+
339
+ return: dict(node_id: node_id_prefix), dict(node_id: index)
340
+
341
+ example:
342
+ {Module.0.module.decoder.layers.0.TransformerLayer.forward.4: Module.0.module.decoder.layers.0.TransformerLayer}
343
+ """
344
+ recompute_map = {}
345
+ recompute_id_map = {}
346
+ node_id_set = set([node.id for node in node_list])
347
+ node_id_cache = set()
348
+ for i, node in enumerate(node_list):
349
+ if NodeOp.get_node_op(node.id) != NodeOp.module:
350
+ continue
351
+ id_segments = node.id.split(Const.SEP)
352
+ prefix = Const.SEP.join(id_segments[:-2])
353
+ if node.id in node_id_cache:
354
+ recompute_map[node.id] = prefix
355
+ recompute_id_map[node.id] = i
356
+ continue
357
+ is_recompute = GraphBuilder._is_recompute_node_id(id_segments)
358
+ if not is_recompute:
359
+ continue
360
+ # 重计算层必然是一组对应的前反向节点
361
+ id_segments[-2] = Const.BACKWARD if id_segments[-2] == Const.FORWARD else Const.FORWARD
362
+ relative_node_id = Const.SEP.join(id_segments)
363
+ if relative_node_id in node_id_set:
364
+ recompute_map[node.id] = prefix
365
+ recompute_id_map[node.id] = i
366
+ # 对应节点id放入缓存避免后续重复判断
367
+ node_id_cache.add(relative_node_id)
368
+ return recompute_map, recompute_id_map
369
+
370
+ @staticmethod
371
+ def _is_recompute_node_id(id_segments):
372
+ """
373
+ 非重计算首层节点命名必然是:Module/Cell.{number(可选)}.module_name.{number(可选)}.class_name.forward/backward.number
374
+ 如果不符合,则判断为重计算节点
375
+ """
376
+ if len(id_segments) > 7:
377
+ return True
378
+ if len(id_segments) == 7 and not (id_segments[1].isdigit() and id_segments[3].isdigit()):
379
+ return True
380
+ if len(id_segments) == 6 and not id_segments[1].isdigit():
381
+ return True
382
+ return False
383
+
384
+ @staticmethod
385
+ def _get_no_recompute_map(graph, recompute_id_prefixes):
386
+ """
387
+ 寻找与重计算层id前缀相同的非重计算forward层,按顺序排列,重计算层按照顺序使用非重计算forward层的父节点对应的backward节点
388
+
389
+ return: dict(node_id_prefix: list(node_id))
390
+ """
391
+ no_recompute_map = {}
392
+ for node_id, node in graph.node_map.items():
393
+ if NodeOp.get_node_op(node_id) == NodeOp.module and GraphBuilder.forward_pattern.search(node_id):
394
+ if not node.upnode or node.upnode.id == graph.root.id:
395
+ continue
396
+ id_prefix = GraphBuilder.forward_pattern.sub('', node_id)
397
+ if id_prefix not in recompute_id_prefixes:
398
+ continue
399
+ no_recompute_map.setdefault(id_prefix, []).append(node)
400
+ for node_list in no_recompute_map.values():
401
+ # 方便按顺序pop弹出
402
+ node_list.reverse()
403
+ return no_recompute_map
404
+
299
405
 
300
406
  class GraphExportConfig:
301
407
  def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
@@ -212,7 +212,6 @@ class BaseGraphMerger:
212
212
  if compare_data and not self.compare_param_same(main_param, other_param, has_uncertainty=True):
213
213
  same_flag = False
214
214
  if not same_flag:
215
- # {input.0: [{"Max": 0, "Min": 0, ...}, {"Max": 0.1, "Min": 0, ...}, ...]}
216
215
  data_dict[key.replace(main_node.id + Const.SEP, '')] = tp_need_merge_params
217
216
  return data_types.get('input_data'), data_types.get('output_data')
218
217
 
@@ -28,7 +28,7 @@ op_patterns = [
28
28
  # NodeOp.module
29
29
  r'^(Module.|Cell.|optimizer|clip_grad)',
30
30
  # NodeOp.function_api
31
- r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.)'
31
+ r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.|MindSpeed.)'
32
32
  ]
33
33
 
34
34
 
@@ -41,7 +41,7 @@ node_columns = {
41
41
  'overflow_level': TEXT,
42
42
  'micro_step_id': INTEGER_NOT_NULL,
43
43
  'matched_node_link': TEXT,
44
- 'stack_info': TEXT,
44
+ 'stack_id': TEXT,
45
45
  'parallel_merge_info': TEXT,
46
46
  'matched_distributed': TEXT,
47
47
  'modified': INTEGER_NOT_NULL,
@@ -65,6 +65,11 @@ config_columns = {
65
65
  'step_list': TEXT_NOT_NULL
66
66
  }
67
67
 
68
+ stack_columns = {
69
+ 'id': TEXT_PRIMARY_KEY,
70
+ 'stack_info': TEXT
71
+ }
72
+
68
73
  indexes = {
69
74
  "index1": ["step", "rank", "data_source", "up_node", "node_order"],
70
75
  "index2": ["step", "rank", "data_source", "node_name"],
@@ -197,19 +202,24 @@ def node_to_db(graph, db_name):
197
202
  create_table_sql = create_table_sql_from_dict('tb_nodes', node_columns)
198
203
  insert_sql = create_insert_sql_from_dict('tb_nodes', node_columns)
199
204
  data = []
205
+ stack_dict = {}
200
206
  for i, node in enumerate(graph.get_sorted_nodes()):
207
+ stack_info_text = json.dumps(node.stack_info)
208
+ if stack_info_text not in stack_dict:
209
+ stack_dict[stack_info_text] = get_stack_unique_id(graph, stack_dict)
201
210
  data.append((get_node_unique_id(graph, node), get_graph_unique_id(graph), i, node.id, node.op.value,
202
211
  node.upnode.id if node.upnode else '',
203
212
  json.dumps([node.id for node in node.subnodes]) if node.subnodes else '',
204
213
  node.data.get(GraphConst.JSON_INDEX_KEY), node.data.get(GraphConst.OVERFLOW_LEVEL),
205
214
  node.micro_step_id if node.micro_step_id is not None else 0, json.dumps(node.matched_node_link),
206
- json.dumps(node.stack_info),
215
+ stack_dict.get(stack_info_text),
207
216
  json.dumps(node.parallel_merge_info) if node.parallel_merge_info else '',
208
217
  json.dumps(node.matched_distributed), 0,
209
218
  json.dumps(format_node_data(node.input_data, node.id, graph.compare_mode)),
210
219
  json.dumps(format_node_data(node.output_data, node.id, graph.compare_mode)),
211
220
  graph.data_source, graph.data_path, graph.step, graph.rank))
212
221
  to_db(db_name, create_table_sql, insert_sql, data)
222
+ stack_to_db(stack_dict, db_name)
213
223
 
214
224
 
215
225
  def config_to_db(config, db_name):
@@ -221,9 +231,22 @@ def config_to_db(config, db_name):
221
231
  to_db(db_name, create_table_sql, insert_sql, data)
222
232
 
223
233
 
234
+ def stack_to_db(stack_dict, db_name):
235
+ create_table_sql = create_table_sql_from_dict('tb_stack', stack_columns)
236
+ insert_sql = create_insert_sql_from_dict('tb_stack', stack_columns)
237
+ data = []
238
+ for stack_info_text, unique_id in stack_dict.items():
239
+ data.append((unique_id, stack_info_text))
240
+ to_db(db_name, create_table_sql, insert_sql, data)
241
+
242
+
224
243
  def get_graph_unique_id(graph):
225
244
  return f'{graph.data_source}_{graph.step}_{graph.rank}'
226
245
 
227
246
 
228
247
  def get_node_unique_id(graph, node):
229
248
  return f'{get_graph_unique_id(graph)}_{node.id}'
249
+
250
+
251
+ def get_stack_unique_id(graph, stack_dict):
252
+ return f'{get_graph_unique_id(graph)}_{len(stack_dict)}'
@@ -89,30 +89,6 @@ class BaseNode:
89
89
  self.matched_node_link = ancestors
90
90
  node.matched_node_link = ancestors
91
91
 
92
- def to_dict(self, compare_mode=None):
93
- """
94
- 输出数据
95
- """
96
- result = {
97
- 'id': self.id,
98
- 'node_type': self.op.value,
99
- 'output_data': format_node_data(self.output_data, self.id, compare_mode),
100
- 'input_data': format_node_data(self.input_data, self.id, compare_mode),
101
- 'upnode': self.upnode.id if self.upnode else 'None',
102
- 'subnodes': [node.id for node in self.subnodes],
103
- 'matched_node_link': self.matched_node_link,
104
- 'suggestions': self.suggestions,
105
- 'stack_info': self.stack_info
106
- }
107
- if self.micro_step_id is not None:
108
- result['micro_step_id'] = self.micro_step_id
109
- result['data'] = self.data
110
- if self.matched_distributed:
111
- result[GraphConst.MATCHED_DISTRIBUTED] = self.matched_distributed
112
- if self.parallel_merge_info:
113
- result['parallel_merge_info'] = self.parallel_merge_info
114
- return result
115
-
116
92
  def get_ancestors(self):
117
93
  """
118
94
  获取节点所有祖先的列表
@@ -22,7 +22,7 @@ from msprobe.core.common.decorator import recursion_depth_decorator
22
22
 
23
23
 
24
24
  class Graph:
25
- def __init__(self, model_name, data_path='', dump_data=None):
25
+ def __init__(self, model_name, data_path='', dump_data=None, micro_step_num=None):
26
26
  self.node_map = {}
27
27
  self.node_id_map = {}
28
28
  self.add_node(NodeOp.module, model_name)
@@ -33,6 +33,7 @@ class Graph:
33
33
  self.step = 0
34
34
  self.rank = 0
35
35
  self.compare_mode = GraphConst.SUMMARY_COMPARE
36
+ self.micro_step_num = micro_step_num
36
37
 
37
38
  def __str__(self):
38
39
  infos = [f'{str(self.node_map.get(node_id))}' for node_id in self.node_map]
@@ -172,19 +173,6 @@ class Graph:
172
173
  """
173
174
  return self.node_map.get(node_id, None)
174
175
 
175
- def to_dict(self, compare_mode=None):
176
- """
177
- 用于数据输出
178
- """
179
- result = {}
180
- result[GraphConst.JSON_ROOT_KEY] = self.root.id if self.root else 'None'
181
- result[GraphConst.JSON_DATA_KEY] = self.data_path
182
- result[GraphConst.JSON_NODE_KEY] = {}
183
- for node_id in self.node_map:
184
- info = self.node_map.get(node_id).to_dict(compare_mode)
185
- result[GraphConst.JSON_NODE_KEY][node_id] = info
186
- return result
187
-
188
176
  def paging_by_micro_step(self, graph_other=None):
189
177
  """
190
178
  给graph首层节点增加micro step标记,供前端分页展示,有助于在处理大规模图数据时进行优化和管理
@@ -203,6 +191,9 @@ class Graph:
203
191
  for sub_node in node.subnodes:
204
192
  propagate_micro_step_id(sub_node)
205
193
 
194
+ if self.micro_step_num is not None:
195
+ return self.micro_step_num + 1
196
+
206
197
  batches_n = Graph.split_nodes_by_micro_step(self.root.subnodes)
207
198
  for batch_number, nodes in batches_n.items():
208
199
  for node in nodes: