mindstudio-probe 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/RECORD +85 -66
  3. msprobe/README.md +2 -2
  4. msprobe/core/common/const.py +34 -9
  5. msprobe/core/common/inplace_ops.yaml +1 -0
  6. msprobe/core/common/utils.py +14 -0
  7. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  8. msprobe/core/compare/merge_result/merge_result.py +8 -7
  9. msprobe/core/compare/merge_result/utils.py +81 -0
  10. msprobe/core/compare/utils.py +10 -0
  11. msprobe/core/data_dump/data_collector.py +58 -13
  12. msprobe/core/data_dump/data_processor/base.py +92 -8
  13. msprobe/core/data_dump/data_processor/factory.py +3 -0
  14. msprobe/core/data_dump/data_processor/mindspore_processor.py +17 -4
  15. msprobe/core/data_dump/data_processor/pytorch_processor.py +58 -7
  16. msprobe/core/data_dump/json_writer.py +26 -8
  17. msprobe/docs/01.installation.md +25 -0
  18. msprobe/docs/02.config_introduction.md +14 -12
  19. msprobe/docs/03.config_examples.md +24 -0
  20. msprobe/docs/05.data_dump_PyTorch.md +34 -15
  21. msprobe/docs/06.data_dump_MindSpore.md +45 -22
  22. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -2
  23. msprobe/docs/19.monitor.md +257 -260
  24. msprobe/docs/21.visualization_PyTorch.md +10 -0
  25. msprobe/docs/22.visualization_MindSpore.md +11 -0
  26. msprobe/docs/27.dump_json_instruction.md +24 -20
  27. msprobe/docs/28.debugger_save_instruction.md +94 -0
  28. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  29. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  30. msprobe/mindspore/__init__.py +1 -0
  31. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +26 -6
  32. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  33. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  34. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  35. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  36. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  37. msprobe/mindspore/common/utils.py +20 -2
  38. msprobe/mindspore/debugger/debugger_config.py +25 -2
  39. msprobe/mindspore/debugger/precision_debugger.py +25 -6
  40. msprobe/mindspore/dump/hook_cell/api_registry.py +2 -0
  41. msprobe/mindspore/dump/jit_dump.py +7 -6
  42. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  43. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  44. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  45. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  46. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  47. msprobe/mindspore/monitor/features.py +63 -0
  48. msprobe/mindspore/monitor/module_hook.py +821 -0
  49. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  50. msprobe/mindspore/monitor/utils.py +267 -0
  51. msprobe/mindspore/ms_config.py +8 -2
  52. msprobe/mindspore/service.py +95 -21
  53. msprobe/pytorch/__init__.py +0 -1
  54. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  55. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  56. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  57. msprobe/pytorch/bench_functions/mish.py +21 -0
  58. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  59. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  60. msprobe/pytorch/common/utils.py +71 -0
  61. msprobe/pytorch/debugger/debugger_config.py +19 -9
  62. msprobe/pytorch/debugger/precision_debugger.py +14 -0
  63. msprobe/pytorch/dump/module_dump/module_processer.py +10 -30
  64. msprobe/pytorch/function_factory.py +7 -1
  65. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  66. msprobe/pytorch/hook_module/wrap_distributed.py +4 -0
  67. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  68. msprobe/pytorch/monitor/csv2tb.py +10 -12
  69. msprobe/pytorch/monitor/module_hook.py +123 -104
  70. msprobe/pytorch/monitor/module_metric.py +6 -6
  71. msprobe/pytorch/monitor/optimizer_collect.py +45 -63
  72. msprobe/pytorch/monitor/utils.py +8 -43
  73. msprobe/pytorch/pt_config.py +19 -22
  74. msprobe/pytorch/service.py +103 -24
  75. msprobe/visualization/builder/graph_builder.py +31 -5
  76. msprobe/visualization/builder/msprobe_adapter.py +7 -5
  77. msprobe/visualization/graph/base_node.py +3 -2
  78. msprobe/visualization/graph/distributed_analyzer.py +80 -3
  79. msprobe/visualization/graph/node_op.py +4 -2
  80. msprobe/visualization/graph_service.py +3 -4
  81. msprobe/visualization/utils.py +10 -2
  82. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  83. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  84. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  85. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
@@ -170,6 +170,16 @@ def gen_op_item(op_data, op_name):
     elif op_item.get('type') == 'slice':
         op_item['dtype'] = op_data.get('type')
         op_item['shape'] = str(np.shape(np.array(op_data.get('value'))))
+    elif op_item.get('type') == 'ellipsis':
+        op_item['dtype'] = op_data.get('type')
+        op_item['shape'] = '[]'
+        for i in params:
+            op_item[i] = op_data.get('value')
+    elif op_item.get('type') == 'torch.ProcessGroup':
+        op_item['dtype'] = op_data.get('type')
+        op_item['shape'] = '[]'
+        for i in params:
+            op_item[i] = str(op_data.get('group_ranks'))
     else:
         op_item['dtype'] = str(type(op_data.get('value')))
         op_item['shape'] = '[]'
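The two new branches record `ellipsis` and `torch.ProcessGroup` arguments as comparable rows. A rough sketch of the resulting entry, assuming `params` holds the statistics columns (for example "Max", "Min", "Mean", "Norm"; that list is not shown in this hunk):

    # Illustrative sketch only, not package code; `params` is an assumption.
    params = ("Max", "Min", "Mean", "Norm")
    op_data = {"type": "torch.ProcessGroup", "group_ranks": [0, 1, 2, 3]}
    op_item = {"type": op_data.get("type"), "dtype": op_data.get("type"), "shape": "[]"}
    for i in params:
        op_item[i] = str(op_data.get("group_ranks"))
    # op_item == {'type': 'torch.ProcessGroup', 'dtype': 'torch.ProcessGroup', 'shape': '[]',
    #             'Max': '[0, 1, 2, 3]', 'Min': '[0, 1, 2, 3]',
    #             'Mean': '[0, 1, 2, 3]', 'Norm': '[0, 1, 2, 3]'}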
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -40,6 +40,7 @@ class DataCollector:
         self.scope = ScopeFactory(self.config).build_scope()
         self.backward_module_names = {}
         self.optimizer_status = ""
+        self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
         atexit.register(self.write_json)

     @property
@@ -54,6 +55,17 @@ class DataCollector:
     def check_scope_and_pid(scope, name, pid):
         return (not scope or scope.check(name)) and pid == os.getpid()

+    @staticmethod
+    def set_is_recomputable(data_info, is_recompute):
+        if data_info and len(data_info) == 1 and is_recompute is not None:  # normally data_info holds exactly one entry
+            data_info[list(data_info.keys())[0]]["is_recompute"] = is_recompute
+
+    def reset_status(self):
+        self.optimizer_status = ""
+        self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
+        self.data_writer.reset_cache()
+        self.backward_module_names.clear()
+
     def if_return_forward_new_output(self):
         return self.data_processor.if_return_forward_new_output()

@@ -77,7 +89,7 @@ class DataCollector:
         logger.debug(msg)
         self.data_writer.update_data(data_info)

-    def forward_input_data_collect(self, name, module, pid, module_input_output):
+    def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         if self.config.task == Const.FREE_BENCHMARK:
             backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
             if self.check_scope_and_pid(self.scope, backward_name, pid):
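The `is_recompute` flag added to these signatures is applied by `set_is_recomputable` from the previous hunk. A minimal standalone sketch of its effect (the key name is illustrative):

    # Standalone sketch mirroring set_is_recomputable, not package code.
    data_info = {"Module.layer1.forward.0": {"input_args": [], "output": []}}
    is_recompute = True
    if data_info and len(data_info) == 1 and is_recompute is not None:
        data_info[list(data_info.keys())[0]]["is_recompute"] = is_recompute
    # data_info["Module.layer1.forward.0"]["is_recompute"] is now True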
@@ -87,37 +99,48 @@ class DataCollector:
         if not self.check_scope_and_pid(self.scope, name, pid):
             return

-        data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
+            self.set_is_recomputable(data_info, is_recompute)
         if self.config.level == Const.LEVEL_L2:
             return
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

-    def forward_output_data_collect(self, name, module, pid, module_input_output):
+    def forward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return

-        data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
+            self.set_is_recomputable(data_info, is_recompute)
         if self.config.level == Const.LEVEL_L2:
             return
         self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

-    def forward_data_collect(self, name, module, pid, module_input_output):
+    def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return

-        data_info = self.data_processor.analyze_forward(name, module, module_input_output)
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_forward(name, module, module_input_output)
+            self.set_is_recomputable(data_info, is_recompute)
         self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

-    def backward_data_collect(self, name, module, pid, module_input_output):
+    def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return

-        data_info = self.data_processor.analyze_backward(name, module, module_input_output)
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_backward(name, module, module_input_output)
         if self.config.level == Const.LEVEL_L2:
             return
         # get the name of the module that executes backward
@@ -127,25 +150,34 @@ class DataCollector:
             self.backward_module_names[module_name] = True
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

-    def backward_input_data_collect(self, name, module, pid, module_input_output):
+    def backward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return

-        data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
+            self.set_is_recomputable(data_info, is_recompute)
         self.handle_data(name, data_info)

-    def backward_output_data_collect(self, name, module, pid, module_input_output):
+    def backward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return

-        data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
+            self.set_is_recomputable(data_info, is_recompute)
         self.handle_data(name, data_info)

     def update_construct(self, name):
         if self.config.level not in DataCollector.level_without_construct:
             if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
+                if self.optimizer_status_first_start[self.optimizer_status]:
+                    self.data_writer.update_construct({self.optimizer_status: None})
+                    self.optimizer_status_first_start[self.optimizer_status] = False
                 self.data_writer.update_construct({name: self.optimizer_status})
             else:
                 self.data_writer.update_construct({name: self.module_processor.api_parent_node})
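With the first-start flag, the construct cache records one top-level entry per optimizer phase before the usual name-to-status entries. A rough sketch of the resulting mapping, assuming Const.OPTIMIZER is the string "optimizer" and using illustrative node names:

    # Illustrative only; the "optimizer" constant and node names are assumptions.
    cache_construct = {}
    cache_construct.update({"optimizer": None})                        # written once, on first start
    cache_construct.update({"Optimizer.step.forward.0": "optimizer"})
    cache_construct.update({"Optimizer.step.forward.1": "optimizer"})
    # {'optimizer': None, 'Optimizer.step.forward.0': 'optimizer', 'Optimizer.step.forward.1': 'optimizer'}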
@@ -183,3 +215,16 @@ class DataCollector:

     def fill_stack_tensor_data(self):
         self.data_writer.fill_stack_tensor_data()
+
+    def debug_data_collect_forward(self, variable, name_with_count):
+
+        data_info = self.data_processor.analyze_debug_forward(variable, name_with_count)
+        self.data_writer.update_debug({name_with_count: data_info})
+
+    def debug_data_collect_backward(self, variable, grad_name_with_count):
+        # prepare all None nested data structure
+        all_none_data_info = self.data_processor.analyze_element_to_all_none(variable)
+        self.data_writer.update_debug({grad_name_with_count: all_none_data_info})
+
+        # register tensor backward hook
+        self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug['data'])
@@ -17,6 +17,9 @@ import inspect
 import os
 from dataclasses import dataclass, is_dataclass
 from typing import Tuple, Dict, Optional, Any
+from functools import partial
+import copy
+from typing import Union

 import numpy as np

@@ -87,7 +90,7 @@ class TensorStatInfo:
 class BaseDataProcessor:
     _recursive_key_stack = []
     special_type = (
-        np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
+        np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray,
         bool, int, float, str, slice,
         type(Ellipsis)
    )
@@ -143,6 +146,37 @@ class BaseDataProcessor:
         else:
             return data

+    @staticmethod
+    def set_value_into_nested_structure(data_structure, indexes, value):
+        '''
+        Args:
+            data_structure: nested data structure
+            indexes: List
+            value: value to be set
+        '''
+        if not indexes:
+            raise ValueError("set_value_into_nested_structure failed: "
+                             "indexes need to be non empty when set value to nested data structure")
+        current_level = data_structure
+        for i, index in enumerate(indexes):
+            valid_for_list = isinstance(current_level, list) and isinstance(index, int) and len(current_level) > index
+            valid_for_dict = isinstance(current_level, dict) and index in current_level
+            is_last = i == len(indexes) - 1
+            if valid_for_dict or valid_for_list:
+                if is_last:
+                    try:
+                        current_level[index] = value
+                    except Exception as e:
+                        raise IndexError("set_value_into_nested_structure failed: passed indexes wrong") from e
+                else:
+                    try:
+                        current_level = current_level[index]
+                    except Exception as e:
+                        raise IndexError("set_value_into_nested_structure failed: passed indexes wrong") from e
+            else:
+                raise ValueError("set_value_into_nested_structure failed: "
+                                 "invalid data_structure type or invalid index")
+
     @staticmethod
     def _convert_numpy_to_builtin(arg):
         type_mapping = {
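A small usage sketch of the new helper (the import path matches the one added in the factory diff later in this changeset; the nested structure and indexes are illustrative):

    from msprobe.core.data_dump.data_processor.base import BaseDataProcessor

    nested = {"output": [None, {"grad": None}]}
    BaseDataProcessor.set_value_into_nested_structure(nested, ["output", 1, "grad"], 3.14)
    # nested == {'output': [None, {'grad': 3.14}]}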
@@ -183,8 +217,22 @@ class BaseDataProcessor:
         return single_arg

     @staticmethod
-    def _analyze_numpy(value, numpy_type):
-        return {"type": numpy_type, "value": value}
+    def _analyze_numpy(ndarray, numpy_type):
+        ndarray_json = {}
+        ndarray_json.update({'type': 'numpy.ndarray'})
+        ndarray_json.update({'dtype': str(ndarray.dtype)})
+        ndarray_json.update({'shape': ndarray.shape})
+        if ndarray.size > 0:
+            ndarray_json.update({"Max": np.max(ndarray).item()})
+            ndarray_json.update({"Min": np.min(ndarray).item()})
+            ndarray_json.update({"Mean": np.mean(ndarray).item()})
+            ndarray_json.update({"Norm": np.linalg.norm(ndarray).item()})
+        else:
+            ndarray_json.update({"Max": None})
+            ndarray_json.update({"Min": None})
+            ndarray_json.update({"Mean": None})
+            ndarray_json.update({"Norm": None})
+        return ndarray_json

     @staticmethod
     def _get_allowed_data_mode(data_mode):
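`_analyze_numpy` now records statistics instead of a raw value. A quick check of what it produces for a small array (calling the static method directly, purely for illustration):

    import numpy as np
    from msprobe.core.data_dump.data_processor.base import BaseDataProcessor

    arr = np.array([[1.0, 2.0], [3.0, 4.0]])
    print(BaseDataProcessor._analyze_numpy(arr, "numpy.ndarray"))
    # {'type': 'numpy.ndarray', 'dtype': 'float64', 'shape': (2, 2),
    #  'Max': 4.0, 'Min': 1.0, 'Mean': 2.5, 'Norm': 5.477225575051661}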
@@ -203,7 +251,7 @@ class BaseDataProcessor:
         return cls.special_type

     @classmethod
-    def recursive_apply_transform(cls, args, transform, depth=0):
+    def recursive_apply_transform(cls, args, transform, depth=0) -> Union[dict, list, None]:
         if depth > Const.MAX_DEPTH:
             logger.error(f"The maximum depth of recursive transform, {Const.MAX_DEPTH} is reached.")
             raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
@@ -220,7 +268,7 @@ class BaseDataProcessor:
             return cls.apply_transform_dict(args_dict, transform, depth)
         elif isinstance(args, (list, tuple)):
             result_list = cls.apply_transform_list(args, transform, depth)
-            return type(args)(result_list)
+            return result_list
         elif isinstance(args, dict):
             return cls.apply_transform_dict(args, transform, depth)
         elif args is not None:
@@ -228,12 +276,12 @@ class BaseDataProcessor:
             return None
         else:
             return None
-
+
     @classmethod
     def apply_transform_dict(cls, args, transform, depth):
         result_dict = {}
         for k, arg in args.items():
-            cls._recursive_key_stack.append(str(k))
+            cls._recursive_key_stack.append(k)
             result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
             cls._recursive_key_stack.pop()
         return result_dict
@@ -242,11 +290,21 @@ class BaseDataProcessor:
     def apply_transform_list(cls, args, transform, depth):
         result_list = []
         for i, arg in enumerate(args):
-            cls._recursive_key_stack.append(str(i))
+            cls._recursive_key_stack.append(i)
             result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
             cls._recursive_key_stack.pop()
         return result_list

+    @classmethod
+    def register_hook_single_element(cls, element, suffix_stack, hook_fn):
+        if cls.is_hookable_element(element):
+            indexes = copy.deepcopy(suffix_stack)
+            wrap_hook_fn = partial(hook_fn, indexes=indexes)
+
+            def real_hook_fn(grad):
+                return wrap_hook_fn(grad)
+            element.register_hook(real_hook_fn)
+
     def if_return_forward_new_output(self):
         return self._return_forward_new_output

@@ -383,3 +441,29 @@ class BaseDataProcessor:
                           suffix + file_format)
         file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
         return dump_data_name, file_path
+
+    def analyze_element_to_all_none(self, element):
+        return self.recursive_apply_transform(element, lambda element, stack: None)
+
+    def analyze_debug_forward(self, variable, name_with_count):
+        self.current_api_or_module_name = name_with_count
+        self.api_data_category = Const.TENSOR
+        # these two attributes are used to construct tensor file name {name_with_count}.tensor.{indexes}.npy/pt
+        data_info = self.analyze_element(variable)
+        return data_info
+
+    def analyze_debug_backward(self, variable, grad_name_with_count, nested_data_structure):
+        def hook_fn(grad, indexes):
+            suffix = Const.SEP.join([str(index) for index in indexes])
+            self.save_name = grad_name_with_count + Const.SEP + Const.TENSOR + Const.SEP + suffix
+            grad_data_info = self.analyze_element(grad)
+            self.save_name = None
+            full_index = [grad_name_with_count] + indexes
+            try:
+                self.set_value_into_nested_structure(nested_data_structure, full_index, grad_data_info)
+            except (ValueError, IndexError) as e:
+                logger.warning(f"error occured while recording statistics of {grad_name_with_count} variable, "
+                               f"skip current recording, detailed infomation: {e}")
+            return grad
+        wrap_register_hook_single_element = partial(self.register_hook_single_element, hook_fn=hook_fn)
+        self.recursive_apply_transform(variable, wrap_register_hook_single_element)
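`register_hook_single_element` and `analyze_debug_backward` pair functools.partial with a captured index path so each leaf tensor's hook knows where to write its gradient statistics. A standalone sketch of that pattern (assumes PyTorch is available; the names and the simple Max/Min statistics are illustrative):

    import copy
    from functools import partial

    import torch

    nested = {"w": torch.ones(2, requires_grad=True)}
    stats = {"grad_x.0": {"w": None}}  # all-None mirror, filled in by the hook

    def hook_fn(grad, indexes):
        target = stats
        for idx in indexes[:-1]:
            target = target[idx]
        target[indexes[-1]] = {"Max": grad.max().item(), "Min": grad.min().item()}
        return grad

    wrap_hook_fn = partial(hook_fn, indexes=copy.deepcopy(["grad_x.0", "w"]))
    nested["w"].register_hook(lambda grad: wrap_hook_fn(grad))

    (nested["w"] * 3).sum().backward()
    # stats == {'grad_x.0': {'w': {'Max': 3.0, 'Min': 3.0}}}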
@@ -14,6 +14,7 @@
 # limitations under the License.

 from msprobe.core.common.const import Const
+from msprobe.core.data_dump.data_processor.base import BaseDataProcessor


 class DataProcessorFactory:
@@ -62,6 +63,7 @@ class DataProcessorFactory:
             cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
             cls.register_processor(Const.PT_FRAMEWORK, Const.FREE_BENCHMARK, PytorchFreeBenchmarkDataProcessor)
             cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor)
+            cls.register_processor(Const.PT_FRAMEWORK, Const.STRUCTURE, BaseDataProcessor)
             cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
         elif framework == Const.MS_FRAMEWORK:
             from msprobe.core.data_dump.data_processor.mindspore_processor import (
@@ -75,4 +77,5 @@ class DataProcessorFactory:
             cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
             cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
             cls.register_processor(Const.MS_FRAMEWORK, Const.KERNEL_DUMP, MindsporeKernelDumpDataProcessor)
+            cls.register_processor(Const.MS_FRAMEWORK, Const.STRUCTURE, BaseDataProcessor)
             cls.register_module_processor(Const.MS_FRAMEWORK, CellProcessor)
@@ -23,7 +23,7 @@ import numpy as np
 from msprobe.core.common.const import Const
 from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo,
                                                         ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs)
-from msprobe.core.common.file_utils import path_len_exceeds_limit
+from msprobe.core.common.file_utils import path_len_exceeds_limit, save_npy
 from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy
 from msprobe.mindspore.common.log import logger
 from msprobe.mindspore.dump.hook_cell.api_registry import api_register
@@ -116,6 +116,10 @@ class MindsporeDataProcessor(BaseDataProcessor):
         api_register.norm_inner_op_set_hook_func()
         return tensor_stat

+    @staticmethod
+    def is_hookable_element(element):
+        return hasattr(element, "register_hook") and callable(element.register_hook)
+
     @classmethod
     def get_special_types(cls):
         return super().get_special_types() + cls.mindspore_special_type
@@ -136,11 +140,13 @@ class MindsporeDataProcessor(BaseDataProcessor):

         converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
         if converted_numpy is not element:
-            return self._analyze_numpy(converted_numpy, numpy_type)
+            return {"type": numpy_type, "value": converted_numpy}
         if isinstance(element, Number):
             return self.analyze_dtype_in_kwargs(element)
         if isinstance(element, ms.Tensor):
-            return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
+            return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
+        if isinstance(element, np.ndarray):
+            return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
         if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
             return self._analyze_builtin(element)
         return {}
@@ -185,6 +191,13 @@ class TensorDataProcessor(MindsporeDataProcessor):
         else:
             save_tensor_as_npy(tensor, file_path)
         return single_arg
+
+    def _analyze_numpy(self, ndarray, suffix):
+        dump_data_name, file_path = self.get_save_file_path(suffix)
+        save_npy(ndarray, file_path)
+        ndarray_json = super()._analyze_numpy(ndarray, suffix)
+        ndarray_json.update({"data_name": dump_data_name})
+        return ndarray_json


 class OverflowCheckDataProcessor(MindsporeDataProcessor):
@@ -231,7 +244,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
         api_info_struct = super().analyze_backward(name, module, module_input_output)
         self.maybe_save_overflow_data()
         return api_info_struct if self.has_overflow else None
-
+
     def analyze_params(self, name, param_name, grad):
         self.has_overflow = False
         api_info_struct = super().analyze_params(name, param_name, grad)
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,6 +21,7 @@ from typing import List
 import numpy as np
 import torch
 from torch import distributed as dist
+from torch.distributed.distributed_c10d import _get_default_group

 from msprobe.core.common.const import Const
 from msprobe.core.common.file_utils import path_len_exceeds_limit
@@ -40,7 +41,16 @@ except ImportError:


 class PytorchDataProcessor(BaseDataProcessor):
-    pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor, torch.memory_format, dist.ProcessGroup)
+    pytorch_special_type = (
+        torch.device,
+        torch.dtype,
+        torch.Size,
+        torch.Tensor,
+        torch.memory_format,
+        dist.ProcessGroup,
+        dist.P2POp,
+        dist.ReduceOp
+    )
     memory_format = {
         torch.contiguous_format: "contiguous_format",
         torch.channels_last: "channels_last",
@@ -168,6 +178,11 @@ class PytorchDataProcessor(BaseDataProcessor):
     def is_distributed_op(module):
         return getattr(module, "op_is_distributed", False)

+    @staticmethod
+    def is_hookable_element(element):
+        return (hasattr(element, "register_hook") and callable(element.register_hook)) and \
+               (hasattr(element, "requires_grad") and element.requires_grad)
+
     @staticmethod
     def _analyze_torch_size(arg):
         return {"type": "torch.Size", "value": list(arg)}
@@ -176,7 +191,6 @@ class PytorchDataProcessor(BaseDataProcessor):
     def _analyze_memory_format(arg):
         # get the memory format
         format_type = PytorchDataProcessor.memory_format.get(arg)
-
         return {"type": "torch.memory_format", "format": format_type}

     @staticmethod
@@ -188,9 +202,18 @@ class PytorchDataProcessor(BaseDataProcessor):
             group_id = PytorchDataProcessor.process_group_hash(arg)
             group_info.update({"group_id": group_id})
         except Exception as e:
-            logger.warning(f"Failed to get process group(id: {group_id}) ranks info with error info: {e}.")
+            logger.warning(f"Failed to get process group ranks info with error info: {e}.")
         return group_info

+    @staticmethod
+    def _analyze_reduce_op(arg):
+        op_type = None
+        try:
+            op_type = str(arg)
+        except Exception as e:
+            logger.warning(f"Failed to get value of torch.distributed.ReduceOp with error info: {e}.")
+        return {"type": "torch.distributed.ReduceOp", "value": op_type}
+
     @classmethod
     def get_special_types(cls):
         return super().get_special_types() + cls.pytorch_special_type
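The new `_analyze_reduce_op` helper records `str(arg)` for a `torch.distributed.ReduceOp`; the exact string depends on the installed torch version. A minimal sketch of the entry it produces:

    import torch.distributed as dist

    # Sketch only: the recorded value is whatever str() yields for the op.
    entry = {"type": "torch.distributed.ReduceOp", "value": str(dist.ReduceOp.SUM)}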
@@ -204,11 +227,17 @@ class PytorchDataProcessor(BaseDataProcessor):
             return self._analyze_memory_format(element)
         if isinstance(element, dist.ProcessGroup):
             return self._analyze_process_group(element)
+        if isinstance(element, dist.P2POp):
+            return self._analyze_p2pop(element)
+        if isinstance(element, dist.ReduceOp):
+            return self._analyze_reduce_op(element)
         converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
         if converted_numpy is not element:
-            return self._analyze_numpy(converted_numpy, numpy_type)
+            return {"type": numpy_type, "value": converted_numpy}
         if isinstance(element, torch.Tensor):
-            return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
+            return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
+        if isinstance(element, np.ndarray):
+            return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
         if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
             return self._analyze_builtin(element)
         return {}
@@ -218,6 +247,21 @@ class PytorchDataProcessor(BaseDataProcessor):
         module_input_output.update_output_with_args_and_kwargs()
         return super().analyze_forward_output(name, module, module_input_output)

+    def _analyze_p2pop(self, arg):
+        p2pop_info = {"class_type": "torch.distributed.P2POp"}
+        try:
+            tensor_info = self._analyze_tensor(arg.tensor, [])
+            p2pop_info.update({"tensor": tensor_info})
+            p2pop_info.update({"op": arg.op.__name__})
+            p2pop_info.update({"peer": arg.peer})
+            p2pop_info.update({"tag": arg.tag})
+            group_id = PytorchDataProcessor.process_group_hash(
+                arg.group) if arg.group else PytorchDataProcessor.process_group_hash(_get_default_group())
+            p2pop_info.update({"group_id": group_id})
+        except Exception as e:
+            logger.warning(f"Failed to parse the P2POp content with error info: {e}.")
+        return p2pop_info
+
     def _analyze_tensor(self, tensor, suffix):
         tensor_stat = self.get_stat_info(tensor, self.config.async_dump)
         tensor_json = {}
@@ -267,6 +311,13 @@ class TensorDataProcessor(PytorchDataProcessor):
             saved_tensor = tensor.clone().contiguous().detach()
             save_pt(saved_tensor, file_path)
         return single_arg
+
+    def _analyze_numpy(self, ndarray, suffix):
+        dump_data_name, file_path = self.get_save_file_path(suffix)
+        save_pt(torch.tensor(ndarray), file_path)
+        ndarray_json = super()._analyze_numpy(ndarray, suffix)
+        ndarray_json.update({"data_name": dump_data_name})
+        return ndarray_json


 class OverflowCheckDataProcessor(PytorchDataProcessor):
@@ -319,7 +370,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
         api_info_struct = super().analyze_backward(name, module, module_input_output)
         self.handle_overflow()
         return api_info_struct if self.has_overflow else None
-
+
     def analyze_params(self, name, param_name, grad):
         self.has_overflow = False
         self._is_support_inf_nan()
@@ -15,6 +15,7 @@

 import csv
 import os
+import copy
 import numpy as np

 from msprobe.core.common.const import Const, FileCheckConst
@@ -31,10 +32,12 @@ class DataWriter:
         self.construct_file_path = None
         self.free_benchmark_file_path = None
         self.dump_tensor_data_dir = None
+        self.debug_file_path = None
         self.flush_size = 1000
         self.cache_data = {}
         self.cache_stack = {}
         self.cache_construct = {}
+        self.cache_debug = {}

     @staticmethod
     def write_data_to_csv(result: list, result_header: tuple, file_path: str):
@@ -57,6 +60,13 @@ class DataWriter:
         self.cache_construct = {}

     def initialize_json_file(self, **kwargs):
+        if self.debug_file_path and not self.cache_debug:
+            # debug level case only create debug.json
+            debug_dict = copy.deepcopy(kwargs)
+            debug_dict.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
+            self.cache_debug = debug_dict
+            save_json(self.debug_file_path, self.cache_debug, indent=1)
+            return
         if not self.cache_data:
             kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
             self.cache_data = kwargs
@@ -66,13 +76,13 @@ class DataWriter:
         if not self.cache_construct:
             save_json(self.construct_file_path, self.cache_construct, indent=1)

-    def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir,
-                          free_benchmark_file_path):
-        self.dump_file_path = dump_file_path
-        self.stack_file_path = stack_file_path
-        self.construct_file_path = construct_file_path
-        self.dump_tensor_data_dir = dump_data_dir
-        self.free_benchmark_file_path = free_benchmark_file_path
+    def update_dump_paths(self, dump_path_aggregation):
+        self.dump_file_path = dump_path_aggregation.dump_file_path
+        self.stack_file_path = dump_path_aggregation.stack_file_path
+        self.construct_file_path = dump_path_aggregation.construct_file_path
+        self.dump_tensor_data_dir = dump_path_aggregation.dump_tensor_data_dir
+        self.free_benchmark_file_path = dump_path_aggregation.free_benchmark_file_path
+        self.debug_file_path = dump_path_aggregation.debug_file_path

     def flush_data_periodically(self):
         dump_data = self.cache_data.get(Const.DATA)
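`update_dump_paths` now takes a single aggregation object instead of five positional paths. The concrete class behind `dump_path_aggregation` is not part of this diff; a minimal stand-in with the attributes read above could look like this:

    # Hypothetical stand-in for illustration; the real aggregation class lives elsewhere in msprobe.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class DumpPathAggregation:
        dump_file_path: Optional[str] = None
        stack_file_path: Optional[str] = None
        construct_file_path: Optional[str] = None
        dump_tensor_data_dir: Optional[str] = None
        free_benchmark_file_path: Optional[str] = None
        debug_file_path: Optional[str] = None

    # writer.update_dump_paths(DumpPathAggregation(dump_file_path="./dump.json",
    #                                              debug_file_path="./debug.json"))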
@@ -100,6 +110,9 @@ class DataWriter:
     def update_construct(self, new_data):
         self.cache_construct.update(new_data)

+    def update_debug(self, new_data):
+        self.cache_debug['data'].update(new_data)
+
     def write_data_json(self, file_path):
         logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
         save_json(file_path, self.cache_data, indent=1)
@@ -110,6 +123,9 @@ class DataWriter:
     def write_construct_info_json(self, file_path):
         save_json(file_path, self.cache_construct, indent=1)

+    def write_debug_info_json(self, file_path):
+        save_json(file_path, self.cache_debug, indent=1)
+
     def write_json(self):
         if self.cache_data:
             self.write_data_json(self.dump_file_path)
@@ -117,6 +133,8 @@ class DataWriter:
             self.write_stack_info_json(self.stack_file_path)
         if self.cache_construct:
             self.write_construct_info_json(self.construct_file_path)
+        if self.cache_debug:
+            self.write_debug_info_json(self.debug_file_path)

     def fill_stack_tensor_data(self):
         self.process_stat_data_recursive(self.cache_data)
@@ -135,7 +153,7 @@ class DataWriter:
                 if hasattr(tensor_stat_data, "device") and tensor_stat_data.device != Const.CPU_LOWERCASE:
                     tensor_stat_data = tensor_stat_data.cpu()
                 for index, stat in zip(tensor_stat_index, tensor_stat_data):
-                    data.update({index, stat.item()})
+                    data.update({index: stat.item()})
                 del data["tensor_stat"]
             else:
                 for key in data.keys():
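The one-character change above fixes a real bug: `{index, stat.item()}` builds a set, and `dict.update()` with a set of two scalars raises instead of assigning the statistic. For example:

    data = {}
    index, stat = "Max", 3.0
    # data.update({index, stat})   # set literal: raises, nothing is assigned
    data.update({index: stat})     # dict literal: data == {'Max': 3.0}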