mindstudio-probe 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (85)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/RECORD +85 -66
  3. msprobe/README.md +2 -2
  4. msprobe/core/common/const.py +34 -9
  5. msprobe/core/common/inplace_ops.yaml +1 -0
  6. msprobe/core/common/utils.py +14 -0
  7. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  8. msprobe/core/compare/merge_result/merge_result.py +8 -7
  9. msprobe/core/compare/merge_result/utils.py +81 -0
  10. msprobe/core/compare/utils.py +10 -0
  11. msprobe/core/data_dump/data_collector.py +58 -13
  12. msprobe/core/data_dump/data_processor/base.py +92 -8
  13. msprobe/core/data_dump/data_processor/factory.py +3 -0
  14. msprobe/core/data_dump/data_processor/mindspore_processor.py +17 -4
  15. msprobe/core/data_dump/data_processor/pytorch_processor.py +58 -7
  16. msprobe/core/data_dump/json_writer.py +26 -8
  17. msprobe/docs/01.installation.md +25 -0
  18. msprobe/docs/02.config_introduction.md +14 -12
  19. msprobe/docs/03.config_examples.md +24 -0
  20. msprobe/docs/05.data_dump_PyTorch.md +34 -15
  21. msprobe/docs/06.data_dump_MindSpore.md +45 -22
  22. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -2
  23. msprobe/docs/19.monitor.md +257 -260
  24. msprobe/docs/21.visualization_PyTorch.md +10 -0
  25. msprobe/docs/22.visualization_MindSpore.md +11 -0
  26. msprobe/docs/27.dump_json_instruction.md +24 -20
  27. msprobe/docs/28.debugger_save_instruction.md +94 -0
  28. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  29. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  30. msprobe/mindspore/__init__.py +1 -0
  31. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +26 -6
  32. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  33. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  34. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  35. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  36. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  37. msprobe/mindspore/common/utils.py +20 -2
  38. msprobe/mindspore/debugger/debugger_config.py +25 -2
  39. msprobe/mindspore/debugger/precision_debugger.py +25 -6
  40. msprobe/mindspore/dump/hook_cell/api_registry.py +2 -0
  41. msprobe/mindspore/dump/jit_dump.py +7 -6
  42. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  43. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  44. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  45. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  46. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  47. msprobe/mindspore/monitor/features.py +63 -0
  48. msprobe/mindspore/monitor/module_hook.py +821 -0
  49. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  50. msprobe/mindspore/monitor/utils.py +267 -0
  51. msprobe/mindspore/ms_config.py +8 -2
  52. msprobe/mindspore/service.py +95 -21
  53. msprobe/pytorch/__init__.py +0 -1
  54. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  55. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  56. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  57. msprobe/pytorch/bench_functions/mish.py +21 -0
  58. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  59. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  60. msprobe/pytorch/common/utils.py +71 -0
  61. msprobe/pytorch/debugger/debugger_config.py +19 -9
  62. msprobe/pytorch/debugger/precision_debugger.py +14 -0
  63. msprobe/pytorch/dump/module_dump/module_processer.py +10 -30
  64. msprobe/pytorch/function_factory.py +7 -1
  65. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  66. msprobe/pytorch/hook_module/wrap_distributed.py +4 -0
  67. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  68. msprobe/pytorch/monitor/csv2tb.py +10 -12
  69. msprobe/pytorch/monitor/module_hook.py +123 -104
  70. msprobe/pytorch/monitor/module_metric.py +6 -6
  71. msprobe/pytorch/monitor/optimizer_collect.py +45 -63
  72. msprobe/pytorch/monitor/utils.py +8 -43
  73. msprobe/pytorch/pt_config.py +19 -22
  74. msprobe/pytorch/service.py +103 -24
  75. msprobe/visualization/builder/graph_builder.py +31 -5
  76. msprobe/visualization/builder/msprobe_adapter.py +7 -5
  77. msprobe/visualization/graph/base_node.py +3 -2
  78. msprobe/visualization/graph/distributed_analyzer.py +80 -3
  79. msprobe/visualization/graph/node_op.py +4 -2
  80. msprobe/visualization/graph_service.py +3 -4
  81. msprobe/visualization/utils.py +10 -2
  82. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  83. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  84. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  85. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
@@ -26,6 +26,7 @@ from msprobe.visualization.utils import save_json_file, GraphConst


 class GraphBuilder:
     backward_pattern = re.compile(r"(\.backward\.)(\d+)$")
+    forward_pattern = re.compile(r"(\.forward\.)(\d+)$")
     # Matches names that start with an uppercase letter, followed by any letters, and end with "Template("
     template_pattern = re.compile(r'\b[A-Z][a-zA-Z]*Template\(')

@@ -113,12 +114,17 @@ class GraphBuilder:
         If the parent of a backward node is null, try to find a parent via the forward node of the same name
         """
         # Match names ending with ".backward." followed by one or more digits
-        backward_pattern = r"(\.backward\.)(\d+)$"
-        forward_pattern = r"(\.forward\.)(\d+)$"
-        if re.search(backward_pattern, subnode_id) and not upnode_id:
-            forward_upnode_id = construct_dict.get(re.sub(backward_pattern, r".forward.\2", subnode_id))
+        if GraphBuilder.backward_pattern.search(subnode_id) and not upnode_id:
+            forward_upnode_id = construct_dict.get(GraphBuilder.backward_pattern.sub(r".forward.\2", subnode_id))
             if forward_upnode_id:
-                new_upnode_id = re.sub(forward_pattern, r".backward.\2", forward_upnode_id)
+                new_upnode_id = GraphBuilder.forward_pattern.sub(r".backward.\2", forward_upnode_id)
+                if new_upnode_id in construct_dict:
+                    return new_upnode_id
+        # Match nodes ending with ".backward"
+        if subnode_id.endswith(Const.SEP + Const.BACKWARD) and not upnode_id:
+            forward_upnode_id = construct_dict.get(subnode_id.replace(Const.BACKWARD, Const.FORWARD))
+            if forward_upnode_id:
+                new_upnode_id = forward_upnode_id.replace(Const.FORWARD, Const.BACKWARD)
                 if new_upnode_id in construct_dict:
                     return new_upnode_id
         return upnode_id
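The parent lookup above relies on the dump's node-naming scheme, where a backward node id differs from its forward twin only in the `.backward.N` / `.forward.N` suffix. A minimal sketch of that substitution, with an invented `construct_dict` and node ids modelled on the naming examples quoted in this diff:

```python
import re

# The same patterns as the class attributes added above.
backward_pattern = re.compile(r"(\.backward\.)(\d+)$")
forward_pattern = re.compile(r"(\.forward\.)(\d+)$")

# Hypothetical construct.json content: child node id -> parent node id.
construct_dict = {
    "Tensor.permute.1.forward.0": "Module.module.GPTModel.forward.0",
    "Module.module.GPTModel.backward.0": None,
}


def find_backward_parent(subnode_id, upnode_id):
    """If a backward node has no parent, borrow the parent of its forward twin."""
    if backward_pattern.search(subnode_id) and not upnode_id:
        forward_upnode_id = construct_dict.get(backward_pattern.sub(r".forward.\2", subnode_id))
        if forward_upnode_id:
            new_upnode_id = forward_pattern.sub(r".backward.\2", forward_upnode_id)
            if new_upnode_id in construct_dict:
                return new_upnode_id
    return upnode_id


# "Tensor.permute.1.backward.0" -> forward twin "Tensor.permute.1.forward.0",
# whose parent "Module.module.GPTModel.forward.0" is rewritten to its backward form.
print(find_backward_parent("Tensor.permute.1.backward.0", None))
# Module.module.GPTModel.backward.0
```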
@@ -148,6 +154,8 @@ class GraphBuilder:
         input_data, output_data = get_input_output(node_data, node.id)
         # Update the data
         node.set_input_output(input_data, output_data)
+        if GraphConst.BATCH_P2P in name:
+            GraphBuilder._extract_batch_p2p_info(node, node_data)
         # Backward nodes use the stack information of the corresponding forward node
         # Module naming example: Module.module.module.GPTModel.backward.0; API naming example: Tensor.permute.1.backward
         if (not node_stack_info and
@@ -164,6 +172,24 @@ class GraphBuilder:
             node.add_upnode(upnode)
         return node

+    @staticmethod
+    def _is_valid_batch_p2p_output(param_list):
+        if not isinstance(param_list, list) or not param_list:
+            return False
+        if not isinstance(param_list[0], list) or not param_list[0]:
+            return False
+        return True
+
+    @staticmethod
+    def _extract_batch_p2p_info(node, node_data):
+        param_list = node_data.get(Const.OUTPUT, [])
+        # Data format: "output": [[{param1}, {param2}, ...]]
+        if GraphBuilder._is_valid_batch_p2p_output(param_list):
+            for param in param_list[0]:
+                info = {GraphConst.OP: param.get(GraphConst.OP), GraphConst.PEER: param.get(GraphConst.PEER),
+                        GraphConst.GROUP_ID: param.get(GraphConst.GROUP_ID)}
+                node.batch_p2p_info.append(info)
+
     @staticmethod
     def _collect_apis_between_modules(graph):
         """
@@ -23,7 +23,7 @@ from msprobe.core.compare.acc_compare import ModeConfig
 # Rules used to parse a node name into the corresponding NodeOp
 op_patterns = [
     # NodeOp.module
-    r'^(Module.|Cell.)',
+    r'^(Module.|Cell.|optimizer|clip_grad)',
     # NodeOp.function_api
     r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.)'
 ]
@@ -57,8 +57,8 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
         from msprobe.pytorch.compare.pt_compare import PTComparator
         return PTComparator(mode_config).do_multi_process(dump_path_param, csv_path)
     else:
-        from msprobe.mindspore.compare.ms_compare import MSComparator
-        ms_comparator = MSComparator(mode_config)
+        from msprobe.mindspore.compare.ms_compare import MSComparator, MappingConfig
+        ms_comparator = MSComparator(mode_config, MappingConfig())
         ms_comparator.cross_frame = is_cross_frame
         return ms_comparator.do_multi_process(dump_path_param, csv_path)

@@ -120,11 +120,13 @@ def compare_data_fuzzy(data_dict_list1, data_dict_list2):
     return True


-def format_node_data(data_dict):
+def format_node_data(data_dict, node_id=None):
     """
-    Output node data in batch
+    Remove fields that should not be displayed from the node data
     """
     del_list = ['requires_grad', 'full_op_name']
+    if node_id and GraphConst.BATCH_P2P in node_id:
+        del_list.extend(['op', 'peer', 'tag', 'group_id'])
     for _, value in data_dict.items():
         if not isinstance(value, dict):
             continue
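`format_node_data` now also strips the batch p2p bookkeeping fields when the node id contains `batch_isend_irecv`. A self-contained sketch of the intended behaviour; the removal loop and the sample data are assumptions, since the rest of the function is not shown in this hunk:

```python
BATCH_P2P = 'batch_isend_irecv'  # mirrors the GraphConst constant added later in this diff


def format_node_data(data_dict, node_id=None):
    """Drop fields that should not be rendered for a node."""
    del_list = ['requires_grad', 'full_op_name']
    if node_id and BATCH_P2P in node_id:
        del_list.extend(['op', 'peer', 'tag', 'group_id'])
    for _, value in data_dict.items():
        if not isinstance(value, dict):
            continue
        for field in del_list:  # assumed removal step; not part of the shown hunk
            value.pop(field, None)
    return data_dict


data_dict = {'output.0': {'Max': 1.0, 'requires_grad': False, 'op': 'isend', 'peer': 1}}
print(format_node_data(data_dict, 'Distributed.batch_isend_irecv.0.forward'))
# {'output.0': {'Max': 1.0}}
```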
@@ -34,6 +34,7 @@ class BaseNode:
         self.micro_step_id = None
         self.overflow_level = None
         self.matched_distributed = {}
+        self.batch_p2p_info = []

     def __str__(self):
         info = f'id:\t{self.id}'
@@ -92,8 +93,8 @@ class BaseNode:
         result = {
             'id': self.id,
             'node_type': self.op.value,
-            'output_data': format_node_data(self.output_data),
-            'input_data': format_node_data(self.input_data),
+            'output_data': format_node_data(self.output_data, self.id),
+            'input_data': format_node_data(self.input_data, self.id),
             'upnode': self.upnode.id if self.upnode else 'None',
             'subnodes': [node.id for node in self.subnodes],
             'matched_node_link': self.matched_node_link,
@@ -107,6 +107,15 @@ class DistributedAnalyzer:
             return None, None
         return group_ranks, group_id

+    @staticmethod
+    def _get_batch_group_info(node, rank):
+        for data in node.input_data.values():
+            group_id = data.get('group_id')
+            if group_id is not None:
+                return group_id
+        logger.warning(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
+        return None
+
     def distributed_match(self):
         for rank, graph in self.graphs.items():
             nodes = graph.node_map
@@ -115,7 +124,9 @@ class DistributedAnalyzer:
                 if not node_id.startswith(Const.DISTRIBUTED) or node.matched_distributed:
                     continue
                 api_name, distributed_type = self._get_distributed_name_and_type(node_id)
-                if distributed_type == DistributedType.P2P:
+                if api_name == GraphConst.BATCH_P2P:
+                    self._batch_p2p_match(node, rank)
+                elif distributed_type == DistributedType.P2P:
                     self._p2p_match(node, rank, api_name)
                 else:
                     self._collective_match(node, rank, api_name)
@@ -138,12 +149,16 @@ class DistributedAnalyzer:
         for rank, graph in self.graphs.items():
             group_count = {}
             group_info = {}
+            batch_p2p_count = {}
             nodes = graph.node_map
             for node_id, node in nodes.items():
                 if not node_id.startswith(Const.DISTRIBUTED):
                     continue
                 api_name, distributed_type = self._get_distributed_name_and_type(node_id)
-                if distributed_type == DistributedType.P2P:
+                if api_name == GraphConst.BATCH_P2P:
+                    self._make_batch_p2p_mapping(node, rank, batch_p2p_count)
+                    continue
+                elif distributed_type == DistributedType.P2P:
                     config_info = self.config.get(api_name)
                     target_rank = self._get_target_rank(node, rank, config_info[1])
                     if target_rank is None:
@@ -162,7 +177,32 @@ class DistributedAnalyzer:
                 unique_group_id = group_id + Const.REPLACEMENT_CHARACTER + str(group_count.get(group_id))
                 group_info[unique_group_id] = node_id
                 group_info[node_id] = unique_group_id
-            self.group_node_mapping[rank] = group_info
+            if rank not in self.group_node_mapping:
+                self.group_node_mapping[rank] = {}
+            self.group_node_mapping[rank].update(group_info)
+
+    def _make_batch_p2p_mapping(self, node, rank, batch_p2p_count):
+        """
+        Assign a unique identifier to each p2p entry of the batch_isend_irecv API
+        """
+        if rank not in self.group_node_mapping:
+            self.group_node_mapping[rank] = {}
+        params = []
+        for info_dict in node.batch_p2p_info:
+            op = info_dict.get(GraphConst.OP)
+            target_rank = info_dict.get(GraphConst.PEER)
+            if op is None or target_rank is None:
+                logger.warning('Cannot get param op or peer.')
+                continue
+            group_id = op + Const.REPLACEMENT_CHARACTER + Const.RANK + str(target_rank) + \
+                       Const.REPLACEMENT_CHARACTER + info_dict.get(GraphConst.GROUP_ID, '')
+            batch_p2p_count[group_id] = batch_p2p_count.get(group_id, 0) + 1
+            # For example: isend_rank0_5a4d31ad765260ba50eb190f1f9fd163_1
+            unique_group_id = group_id + Const.REPLACEMENT_CHARACTER + str(batch_p2p_count.get(group_id))
+            params.append(unique_group_id)
+            self.group_node_mapping.get(rank)[unique_group_id] = node.id
+        if params:
+            self.group_node_mapping.get(rank)[node.id] = params

     def _get_distributed_name_and_type(self, node_id):
         if Const.SEP not in node_id:
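`_make_batch_p2p_mapping` gives every p2p entry of a `batch_isend_irecv` call a unique id built from the op name, the peer rank, the communication group id, and a per-group counter. A small sketch of that scheme, using `'_'` and `'rank'` as stand-ins for `Const.REPLACEMENT_CHARACTER` and `Const.RANK`:

```python
SEP, RANK = '_', 'rank'


def make_unique_ids(batch_p2p_info, counts):
    ids = []
    for info in batch_p2p_info:
        group_id = info['op'] + SEP + RANK + str(info['peer']) + SEP + info.get('group_id', '')
        counts[group_id] = counts.get(group_id, 0) + 1
        # e.g. isend_rank0_5a4d31ad765260ba50eb190f1f9fd163_1
        ids.append(group_id + SEP + str(counts[group_id]))
    return ids


counts = {}
info = [{'op': 'isend', 'peer': 0, 'group_id': '5a4d31ad765260ba50eb190f1f9fd163'}]
print(make_unique_ids(info, counts))
# ['isend_rank0_5a4d31ad765260ba50eb190f1f9fd163_1']
```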
@@ -316,3 +356,40 @@ class DistributedAnalyzer:
         if nodes_info:
             matched_distributed['nodes_info'] = nodes_info
         node.matched_distributed = matched_distributed
+
+    def _batch_p2p_match(self, node, rank):
+        """
+        Batch point-to-point matching
+
+        For the torch.distributed.batch_isend_irecv API, the input is a collection of point-to-point communication entries, so the collection is traversed and each entry is matched individually
+        :param node: the current collective-communication node
+        :param rank: the rank the current node belongs to
+        :return:
+        """
+        unique_group_ids = self.group_node_mapping.get(rank, {}).get(node.id)
+        if not unique_group_ids:
+            return
+        matched_distributed = [] if len(unique_group_ids) > 1 else {}
+        for unique_group_id in unique_group_ids:
+            try:
+                id_info = unique_group_id.split(Const.REPLACEMENT_CHARACTER)
+                api_name = id_info[0]
+                target_api_name = self.config.get(api_name)[0]
+                target_rank = int(id_info[1].replace(Const.RANK, ''))
+            except Exception as e:
+                logger.warning(f'Failed to parsing batch p2p parameter with error info: {e}.')
+                continue
+            target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name)
+            if not target_node:
+                continue
+            communications_type = self.config.get(api_name)[2]
+            index = target_node.data.get(GraphConst.OVERFLOW_LEVEL, CompareConst.NAN) if self.overflow_check \
+                else target_node.data.get(GraphConst.JSON_INDEX_KEY, CompareConst.NAN)
+            matched_info = {
+                'communications_type': communications_type,
+                'nodes_info': {target_rank: [str(index), target_node.id]}
+            }
+            matched_distributed.append(matched_info) if isinstance(matched_distributed, list) \
+                else matched_distributed.update(matched_info)
+        if matched_distributed:
+            node.matched_distributed = matched_distributed
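`_batch_p2p_match` then works in the opposite direction: it splits a unique id back into the op name and peer rank and looks up the counterpart operation on that rank. A sketch of the parsing step; the `'_'` separator, the `'rank'` prefix, and the isend/irecv counterpart mapping are assumptions for illustration:

```python
SEP, RANK = '_', 'rank'
CONFIG = {'isend': ('irecv',), 'irecv': ('isend',)}  # api_name -> (target_api_name, ...)


def parse_unique_id(unique_group_id):
    id_info = unique_group_id.split(SEP)
    api_name = id_info[0]
    target_api_name = CONFIG[api_name][0]
    target_rank = int(id_info[1].replace(RANK, ''))
    return api_name, target_api_name, target_rank


# An isend towards rank 1 is matched against an irecv on rank 1.
print(parse_unique_id('isend_rank1_5a4d31ad765260ba50eb190f1f9fd163_1'))
# ('isend', 'irecv', 1)
```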
@@ -16,6 +16,7 @@
 from enum import Enum
 import re
 from msprobe.visualization.builder.msprobe_adapter import op_patterns
+from msprobe.core.common.log import logger


 class NodeOp(Enum):
@@ -32,8 +33,9 @@ class NodeOp(Enum):
         for op in NodeOp:
             index = op.value
             if index < 0 or index >= len(op_patterns):
-                raise Exception("NodeOp and op_patterns in MsprobeAdapter do not match")
+                continue
             pattern = op_patterns[index]
             if re.match(pattern, node_name):
                 return op
-        raise Exception(f"Cannot parse node_name {node_name} into NodeOp")
+        logger.warning(f"Cannot parsing node_name {node_name} into NodeOp, default parsing as module.")
+        return NodeOp.module
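Together with the widened module pattern in msprobe_adapter.py, node-name resolution is now more forgiving: optimizer and clip_grad nodes resolve to the module op, and anything unmatched falls back to module with a warning instead of raising. A condensed sketch; the enum values mirror the `# NodeOp.module` / `# NodeOp.function_api` indices in the pattern list, and the sample names are invented:

```python
import re
from enum import Enum

op_patterns = [
    r'^(Module.|Cell.|optimizer|clip_grad)',  # NodeOp.module
    r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.)',  # NodeOp.function_api
]


class NodeOp(Enum):
    module = 0
    function_api = 1


def get_node_op(node_name):
    for op in NodeOp:
        if re.match(op_patterns[op.value], node_name):
            return op
    return NodeOp.module  # fall back instead of raising


print(get_node_op('optimizer.AdamW.step.0'))    # NodeOp.module (new pattern)
print(get_node_op('Tensor.permute.1.forward'))  # NodeOp.function_api
print(get_node_op('SomeCustomName'))            # NodeOp.module (fallback)
```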
@@ -16,8 +16,8 @@
 import os
 import time
 import json
-from msprobe.core.common.file_utils import (FileOpen, check_file_type, create_directory, FileChecker,
-                                            check_file_or_directory_path)
+from msprobe.core.common.file_utils import (check_file_type, create_directory, FileChecker,
+                                            check_file_or_directory_path, load_json)
 from msprobe.core.common.const import FileCheckConst, Const
 from msprobe.core.common.utils import CompareException
 from msprobe.core.overflow_check.checker import AnomalyDetector
@@ -220,8 +220,7 @@ def _graph_service_parser(parser):


 def _graph_service_command(args):
-    with FileOpen(args.input_path, "r") as file:
-        input_param = json.load(file)
+    input_param = load_json(args.input_path)
     npu_path = input_param.get("npu_path")
     bench_path = input_param.get("bench_path")
     check_file_or_directory_path(npu_path, isdir=True)
@@ -155,6 +155,7 @@ class GraphConst:
     SUMMARY_COMPARE = 0
     MD5_COMPARE = 1
     REAL_DATA_COMPARE = 2
+    STRUCTURE_COMPARE = 3
     JSON_NPU_KEY = 'NPU'
     JSON_BENCH_KEY = 'Bench'
     JSON_TIP_KEY = 'ToolTip'
@@ -200,13 +201,15 @@ class GraphConst:
     DUMP_MODE_TO_GRAPHCOMPARE_MODE_MAPPING = {
         Const.ALL: REAL_DATA_COMPARE,
         Const.SUMMARY: SUMMARY_COMPARE,
-        Const.MD5: MD5_COMPARE
+        Const.MD5: MD5_COMPARE,
+        Const.STRUCTURE: STRUCTURE_COMPARE
     }

     GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING = {
         REAL_DATA_COMPARE: Const.ALL,
         SUMMARY_COMPARE: Const.SUMMARY,
-        MD5_COMPARE: Const.MD5
+        MD5_COMPARE: Const.MD5,
+        STRUCTURE_COMPARE: Const.STRUCTURE
     }

     RANKS = 'ranks'
@@ -215,3 +218,8 @@ class GraphConst:

     SRC = 'src'
     DST = 'dst'
+
+    BATCH_P2P = 'batch_isend_irecv'
+    OP = 'op'
+    PEER = 'peer'
+    GROUP_ID = 'group_id'