mindstudio-probe 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/RECORD +85 -66
  3. msprobe/README.md +2 -2
  4. msprobe/core/common/const.py +34 -9
  5. msprobe/core/common/inplace_ops.yaml +1 -0
  6. msprobe/core/common/utils.py +14 -0
  7. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  8. msprobe/core/compare/merge_result/merge_result.py +8 -7
  9. msprobe/core/compare/merge_result/utils.py +81 -0
  10. msprobe/core/compare/utils.py +10 -0
  11. msprobe/core/data_dump/data_collector.py +58 -13
  12. msprobe/core/data_dump/data_processor/base.py +92 -8
  13. msprobe/core/data_dump/data_processor/factory.py +3 -0
  14. msprobe/core/data_dump/data_processor/mindspore_processor.py +17 -4
  15. msprobe/core/data_dump/data_processor/pytorch_processor.py +58 -7
  16. msprobe/core/data_dump/json_writer.py +26 -8
  17. msprobe/docs/01.installation.md +25 -0
  18. msprobe/docs/02.config_introduction.md +14 -12
  19. msprobe/docs/03.config_examples.md +24 -0
  20. msprobe/docs/05.data_dump_PyTorch.md +34 -15
  21. msprobe/docs/06.data_dump_MindSpore.md +45 -22
  22. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -2
  23. msprobe/docs/19.monitor.md +257 -260
  24. msprobe/docs/21.visualization_PyTorch.md +10 -0
  25. msprobe/docs/22.visualization_MindSpore.md +11 -0
  26. msprobe/docs/27.dump_json_instruction.md +24 -20
  27. msprobe/docs/28.debugger_save_instruction.md +94 -0
  28. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  29. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  30. msprobe/mindspore/__init__.py +1 -0
  31. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +26 -6
  32. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  33. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  34. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  35. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  36. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  37. msprobe/mindspore/common/utils.py +20 -2
  38. msprobe/mindspore/debugger/debugger_config.py +25 -2
  39. msprobe/mindspore/debugger/precision_debugger.py +25 -6
  40. msprobe/mindspore/dump/hook_cell/api_registry.py +2 -0
  41. msprobe/mindspore/dump/jit_dump.py +7 -6
  42. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  43. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  44. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  45. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  46. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  47. msprobe/mindspore/monitor/features.py +63 -0
  48. msprobe/mindspore/monitor/module_hook.py +821 -0
  49. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  50. msprobe/mindspore/monitor/utils.py +267 -0
  51. msprobe/mindspore/ms_config.py +8 -2
  52. msprobe/mindspore/service.py +95 -21
  53. msprobe/pytorch/__init__.py +0 -1
  54. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  55. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  56. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  57. msprobe/pytorch/bench_functions/mish.py +21 -0
  58. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  59. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  60. msprobe/pytorch/common/utils.py +71 -0
  61. msprobe/pytorch/debugger/debugger_config.py +19 -9
  62. msprobe/pytorch/debugger/precision_debugger.py +14 -0
  63. msprobe/pytorch/dump/module_dump/module_processer.py +10 -30
  64. msprobe/pytorch/function_factory.py +7 -1
  65. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  66. msprobe/pytorch/hook_module/wrap_distributed.py +4 -0
  67. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  68. msprobe/pytorch/monitor/csv2tb.py +10 -12
  69. msprobe/pytorch/monitor/module_hook.py +123 -104
  70. msprobe/pytorch/monitor/module_metric.py +6 -6
  71. msprobe/pytorch/monitor/optimizer_collect.py +45 -63
  72. msprobe/pytorch/monitor/utils.py +8 -43
  73. msprobe/pytorch/pt_config.py +19 -22
  74. msprobe/pytorch/service.py +103 -24
  75. msprobe/visualization/builder/graph_builder.py +31 -5
  76. msprobe/visualization/builder/msprobe_adapter.py +7 -5
  77. msprobe/visualization/graph/base_node.py +3 -2
  78. msprobe/visualization/graph/distributed_analyzer.py +80 -3
  79. msprobe/visualization/graph/node_op.py +4 -2
  80. msprobe/visualization/graph_service.py +3 -4
  81. msprobe/visualization/utils.py +10 -2
  82. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  83. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  84. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  85. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
msprobe/pytorch/monitor/module_metric.py
@@ -17,7 +17,7 @@ import re
 import torch
 
 from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
-from msprobe.pytorch.monitor.utils import NAN_TENSOR_ON_DEVICE
+from msprobe.pytorch.monitor.utils import get_nan_tensor
 
 
 def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
@@ -147,13 +147,13 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
     """
     :param ops: ["op1", "op2"]
     :param tag2tensor: {
-        '0:fc_0/input': torch.randn([3, 4]),
-        '0:fc_0/output': torch.randn([3, 3])
+        '0:fc.input:0/actv': torch.randn([3, 4]),
+        '0:fc.output:0/actv': torch.randn([3, 3])
     }
     :param eps: float 1e-8
     :param out_dict:{
-        '0:fc_0/input': {"op1": op1(torch.randn([3, 4])), "op2": op2(torch.randn([3, 4]))}
-        '0:fc_0/output': {"op1": op1(torch.randn([3, 3])), "op2": op2(torch.randn([3, 3]))}
+        '0:fc.input:0/actv': {"op1": op1(torch.randn([3, 4])), "op2": op2(torch.randn([3, 4]))}
+        '0:fc.output:0/actv': {"op1": op1(torch.randn([3, 3])), "op2": op2(torch.randn([3, 3]))}
     }
     :return: out_dict
     """
@@ -164,7 +164,7 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
             out_dict[tag] = {}
         if not torch.is_tensor(tensor):
             # Non-tensor in/output filled with nan.
-            out_dict[tag].update({metric_name: NAN_TENSOR_ON_DEVICE for metric_name in ops})
+            out_dict[tag].update({metric_name: get_nan_tensor() for metric_name in ops})
             continue
         for metric_name in ops:
             fun_metric = config_metric_registry.get(metric_name)
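
The three hunks above rename activation tags (now `<rank>:<module>.<input|output>:<index>/actv`) and replace the module-level NaN constant with a lazily created one. A minimal standalone sketch of the resulting behaviour; `compute_metrics` below is an illustrative stand-in, not the real msprobe function:

```python
import torch

def compute_metrics(ops, tag2tensor):
    # Standalone sketch of what get_metrics() produces: one dict of metrics per tag.
    out = {}
    for tag, tensor in tag2tensor.items():
        if not torch.is_tensor(tensor):
            # Non-tensor in/output: every metric becomes NaN
            # (the real code routes this through get_nan_tensor()).
            out[tag] = {op: torch.tensor(float("nan")) for op in ops}
            continue
        out[tag] = {op: getattr(torch, op)(tensor) for op in ops}
    return out

# Tags now follow the "<rank>:<module>.<input|output>:<arg index>/actv" pattern.
metrics = compute_metrics(["max", "min"], {
    "0:fc.input:0/actv": torch.randn(3, 4),
    "0:fc.output:0/actv": "not a tensor",
})
```
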
msprobe/pytorch/monitor/optimizer_collect.py
@@ -23,16 +23,10 @@ from msprobe.pytorch.monitor.utils import MVResult, MVGradResult
 
 
 class OptimizerMon(object):
-    wrapped_optimizer = None
-
     def __init__(self) -> None:
         self.fp16_to_fp32_param = {}
         self.is_stage3 = False
 
-    @classmethod
-    def set_wrapped_optimizer(cls, wrapped_optimizer):
-        cls.wrapped_optimizer = wrapped_optimizer
-
     def fetch_mv(self, monitor, torch_opt, params2name):
         pass
 
@@ -82,7 +76,6 @@ class OptimizerMon(object):
         ratio_dict = defaultdict()
         param2name = defaultdict()
         fp32_partitioned_groups_flat_grad = defaultdict()
-        mix_prec_opt = OptimizerMon.wrapped_optimizer
         partition_id = dist.get_rank()
 
         def get_flatten_grad(self, optimizer, group_idx):
@@ -101,7 +94,7 @@ class OptimizerMon(object):
                 return fp32_partitioned_groups_flat[group_idx].grad
 
         for group_idx in range(len(fp32_partitioned_groups_flat)):
-            fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, mix_prec_opt, group_idx)
+            fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, torch_opt, group_idx)
 
         for name in params2name.values():
             start_idx, end_idx, group_idx, group_with_rank = name2indices[name]
@@ -110,9 +103,9 @@ class OptimizerMon(object):
             fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx]
             fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx]
             param2name[fp32_param] = name
-            if not mix_prec_opt.state:
+            if not torch_opt.state:
                 continue
-            state_param = list(mix_prec_opt.state.values())[group_idx]
+            state_param = list(torch_opt.state.values())[group_idx]
             exp_avg = state_param.get("exp_avg", None)
             exp_avg_sq = state_param.get("exp_avg_sq", None)
             if exp_avg is None or exp_avg_sq is None:
@@ -150,36 +143,33 @@ class MixPrecisionOptimizerMon(OptimizerMon):
    混合精度训练通过适当降低某些计算的精度来加速训练过程并减少内存消耗。
    """
 
-    def map_fp16_tp_fp32_param(self, mix_prec_opt):
-        for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups):
+    def map_fp16_tp_fp32_param(self, torch_opt):
+        for fp16_group, fp32_group in zip(torch_opt.float16_groups, torch_opt.fp32_from_float16_groups):
             for fp16_param, fp32_param in zip(fp16_group, fp32_group):
                 self.fp16_to_fp32_param[fp16_param] = fp32_param
 
     def fetch_mv(self, monitor, torch_opt, params2name):
-        mix_prec_opt = self.wrapped_optimizer
-
-        if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-            self.map_fp16_tp_fp32_param(mix_prec_opt)
+        if not self.fp16_to_fp32_param and torch_opt is not None:
+            self.map_fp16_tp_fp32_param(torch_opt)
 
         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
 
 class MegatronDistributedOptimizerMon(OptimizerMon):
-    def map_fp16_tp_fp32_param(self, mix_prec_opt):
-        if not (hasattr(mix_prec_opt, "model_float16_groups") and
-                hasattr(mix_prec_opt, "shard_fp32_from_float16_groups")):
+    def map_fp16_tp_fp32_param(self, torch_opt):
+        if not (hasattr(torch_opt, "model_float16_groups") and
+                hasattr(torch_opt, "shard_fp32_from_float16_groups")):
             raise Exception(
                 "megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, "
                 "if not, please check megatron-lm version")
-        for fp16_group, shard_fp32_group in zip(mix_prec_opt.model_float16_groups,
-                                                mix_prec_opt.shard_fp32_from_float16_groups):
+        for fp16_group, shard_fp32_group in zip(torch_opt.model_float16_groups,
+                                                torch_opt.shard_fp32_from_float16_groups):
             for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
                 self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
 
     def fetch_mv(self, monitor, torch_opt, params2name):
-        mix_prec_opt = self.wrapped_optimizer
-        if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-            self.map_fp16_tp_fp32_param(mix_prec_opt)
+        if not self.fp16_to_fp32_param and torch_opt is not None:
+            self.map_fp16_tp_fp32_param(torch_opt)
 
         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
@@ -191,30 +181,26 @@ class MegatronFP32OptimizerMon(OptimizerMon):
 
 class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
     def fetch_mv(self, monitor, torch_opt, params2name):
-        mix_prec_opt = self.wrapped_optimizer
-
-        if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-            for opt in mix_prec_opt.chained_optimizers:
+        if not self.fp16_to_fp32_param and torch_opt is not None:
+            for opt in torch_opt.chained_optimizers:
                 self.map_fp16_tp_fp32_param(opt)
 
         if not isinstance(torch_opt, torch.optim.Optimizer):
             torch_opt.state = {}
-            for opt in mix_prec_opt.chained_optimizers:
+            for opt in torch_opt.chained_optimizers:
                 torch_opt.state.update(opt.optimizer.state)
         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
 
 class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
     def fetch_mv(self, monitor, torch_opt, params2name):
-        mix_prec_opt = self.wrapped_optimizer
-
-        if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-            for opt in mix_prec_opt.chained_optimizers:
+        if not self.fp16_to_fp32_param and torch_opt is not None:
+            for opt in torch_opt.chained_optimizers:
                 self.map_fp16_tp_fp32_param(opt)
 
         if not isinstance(torch_opt, torch.optim.Optimizer):
             torch_opt.state = {}
-            for opt in mix_prec_opt.chained_optimizers:
+            for opt in torch_opt.chained_optimizers:
                 torch_opt.state.update(opt.optimizer.state)
         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
@@ -225,9 +211,8 @@ class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon):
 
 
 class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
-    def get_param_index(self, params2name, name2index):
-        mix_prec_opt = OptimizerMon.wrapped_optimizer
-        fp16_groups = mix_prec_opt.fp16_partitioned_groups
+    def get_param_index(self, params2name, name2index, torch_opt):
+        fp16_groups = torch_opt.fp16_partitioned_groups
         name2indices = defaultdict()
         index_length = defaultdict()
         index = 0
@@ -246,13 +231,11 @@ class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
 
     def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
         self.is_stage3 = True
-        mix_prec_opt = OptimizerMon.wrapped_optimizer
-        fp32_partitioned_groups_flat = mix_prec_opt.fp32_partitioned_groups_flat
+        fp32_partitioned_groups_flat = torch_opt.fp32_partitioned_groups_flat
         return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
 
 
 class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
-
     @staticmethod
     def get_group_index(fp32_length, world_size, index):
         for i in range(len(fp32_length) - 1):
@@ -265,12 +248,11 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
                 return sub_interval_start, min(sub_index, world_size - 1)
         return fp32_length[-1], 0
 
-    def get_param_index(self, params2name, name2index):
-        mix_prec_opt = OptimizerMon.wrapped_optimizer
-        padding = mix_prec_opt.groups_padding
+    def get_param_index(self, params2name, name2index, torch_opt):
+        padding = torch_opt.groups_padding
         world_size = dist.get_world_size()
         fp32_length = [0]
-        for fp32_group_index, single_partition_of_fp32_group in enumerate(mix_prec_opt.single_partition_of_fp32_groups):
+        for fp32_group_index, single_partition_of_fp32_group in enumerate(torch_opt.single_partition_of_fp32_groups):
             fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index])
 
         bf16_groups = []
@@ -278,7 +260,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
         index_length = defaultdict()
         index = 0
         idx = 0
-        for group_idx, bf16_group in enumerate(mix_prec_opt.bit16_groups):
+        for group_idx, bf16_group in enumerate(torch_opt.bit16_groups):
             bf16_groups.extend(bf16_group)
             for param in bf16_group:
                 param_length = len(param.flatten())
@@ -286,7 +268,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
                 index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank)
                 index += param_length
                 idx += 1
-        group_length = len(bf16_groups) / len(mix_prec_opt.bit16_groups)
+        group_length = len(bf16_groups) / len(torch_opt.bit16_groups)
         for _, name in params2name.items():
             name_index = name2index[name]
             start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index]
@@ -300,8 +282,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
         return name2indices
 
     def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
-        mix_prec_opt = OptimizerMon.wrapped_optimizer
-        fp32_partitioned_groups_flat = mix_prec_opt.single_partition_of_fp32_groups
+        fp32_partitioned_groups_flat = torch_opt.single_partition_of_fp32_groups
         return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
 
@@ -312,22 +293,23 @@ class DummyOptimizerMon(OptimizerMon):
 
 class OptimizerMonFactory:
     _optimizer_mon_map = {
-        "Megatron_Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
-        "Megatron_DistributedOptimizer": MegatronDistributedOptimizerMon,
-        "Megatron_ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
-        "Megatron_ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
-        "Megatron_FP32Optimizer": MegatronFP32OptimizerMon,
-        "DeepSpeedZeroOptimizer_Stage0": DeepSpeedZeroOptimizerStage0Mon,
-        "DeepSpeedZeroOptimizer_Stage1_or_2": DeepSpeedZeroOptimizerStage1or2Mon,
+        "FP32Optimizer": MegatronFP32OptimizerMon,
+        "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
+        "DistributedOptimizer": MegatronDistributedOptimizerMon,
+        "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+        "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
+        "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon,
+        "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon,
         "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
-        "unknown": DummyOptimizerMon
+        "Adam": DummyOptimizerMon
     }
 
     @staticmethod
-    def create_optimizer_mon(opt_ty: str):
-        if not opt_ty:
-            return DummyOptimizerMon()
-        optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(opt_ty)
-        if not optimizer_mon_class:
-            raise Exception("opt_ty should be one of: " + ", ".join(OptimizerMonFactory._optimizer_mon_map.keys()))
-        return optimizer_mon_class()
+    def create_optimizer_mon(optimizer):
+        # auto replace opt_ty
+        optimizer_class = optimizer.__class__.__name__
+        if optimizer_class == "ChainedOptimizer":
+            optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__
+
+        optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, DummyOptimizerMon)
+        return optimizer_mon_class(), optimizer_class
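
With `set_wrapped_optimizer` gone, the factory no longer takes a user-supplied `opt_ty` string: it dispatches on the optimizer object's class name, unwraps Megatron's `ChainedOptimizer`, and falls back to `DummyOptimizerMon` instead of raising. A minimal sketch of that dispatch rule, using stand-in classes rather than the real monitor classes:

```python
# Stand-ins for illustration only; the real map is OptimizerMonFactory._optimizer_mon_map above.
class DistributedOptimizerMon:
    pass

class DummyMon:
    pass

MON_MAP = {"DistributedOptimizer": DistributedOptimizerMon}

def create_optimizer_mon(optimizer):
    cls_name = optimizer.__class__.__name__
    if cls_name == "ChainedOptimizer":
        # Megatron's ChainedOptimizer wraps several optimizers; dispatch on
        # the first wrapped optimizer's class name instead.
        cls_name = "Chained" + optimizer.chained_optimizers[0].__class__.__name__
    return MON_MAP.get(cls_name, DummyMon)(), cls_name

class DistributedOptimizer:  # stand-in for the Megatron class of the same name
    pass

mon, name = create_optimizer_mon(DistributedOptimizer())
assert isinstance(mon, DistributedOptimizerMon) and name == "DistributedOptimizer"
```
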
msprobe/pytorch/monitor/utils.py
@@ -36,7 +36,7 @@ except ImportError:
 if torch.cuda.is_available():
     device = "cuda"
 
-NAN_TENSOR_ON_DEVICE = torch.tensor(torch.nan, device=device)
+NAN_TENSOR_ON_DEVICE = None
 FILE_MAX_SIZE = 10 * 1024 * 1024 * 1024
 FILE_NAME_MAX_LENGTH = 255
 DIRECTORY_MAX_LENGTH = 4096
@@ -57,6 +57,13 @@ def get_output_base_dir():
     return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
 
 
+def get_nan_tensor():
+    global NAN_TENSOR_ON_DEVICE
+    if not NAN_TENSOR_ON_DEVICE:
+        NAN_TENSOR_ON_DEVICE = torch.tensor(torch.nan, device=device)
+    return NAN_TENSOR_ON_DEVICE
+
+
 def filter_special_chars(func):
     @wraps(func)
     def func_level(msg):
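
`NAN_TENSOR_ON_DEVICE` used to be materialised at import time, which allocates on (and initialises) the accelerator before the training framework may be ready; it is now created on first use. The same lazy-singleton idea can also be expressed with `functools.lru_cache` — a sketch of the pattern, not the msprobe implementation (which keeps the module-level global shown above):

```python
import functools
import torch

@functools.lru_cache(maxsize=None)
def nan_on_device(device: str = "cpu") -> torch.Tensor:
    # Built on first call instead of at import time, so merely importing the
    # module never touches the device.
    return torch.tensor(float("nan"), device=device)

x = nan_on_device()          # first call allocates
assert x is nan_on_device()  # subsequent calls reuse the cached tensor
```
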
@@ -82,48 +89,6 @@ def get_param_struct(param):
     return res
 
 
-def is_recomputation():
-    """Check if the current operation is in the re-computation phase.
-
-    This function inspects the current call stack to indicate whether the current operation is in the
-    re-computation phase. We use a blacklist mechanism, now supported megatron and mindspeed framework.
-    megatron: The 'backward' function is called by the 'torch/autograd/function.py' file.
-    mindspeed: The 'checkpoint_function_backward' function is called by the 'torch/autograd/function.py'
-    file or the custom module(use CheckpointWithoutOutput) with the 'backward' function is executed within the
-    'torch/_tensor.py' file.
-
-    Returns:
-        bool: True if in the re-computation phase, False otherwise.
-    """
-    backward_function_indices = []
-    call_stack = inspect.stack()
-
-    # Identify the function 'backward' is being executed within the 'torch/_tensor.py' file.
-    for frame_info in call_stack:
-        if frame_info.function == Const.BACKWARD and frame_info.filename.endswith('torch/_tensor.py'):
-            del call_stack
-            return True
-
-    # Identify indices in the call stack where the specific function is being executed
-    for idx, frame_info in enumerate(call_stack):
-        if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward':
-            backward_function_indices.append(idx)
-
-    # Check if the execution is within 'torch/autograd/function.py' file
-    for idx in backward_function_indices:
-        # The Megatron and MindSpeed L0&L1 scenes
-        if idx + 1 < len(call_stack) and call_stack[idx + 1].filename.endswith('torch/autograd/function.py'):
-            del call_stack
-            return True
-        # The latest MindSpeed L2 and ModelLink scenes
-        if idx + 2 < len(call_stack) and call_stack[idx + 2].filename.endswith('torch/autograd/function.py'):
-            del call_stack
-            return True
-
-    del call_stack
-    return False
-
-
 def validate_ops(ops):
     if not isinstance(ops, list):
         raise TypeError("ops should be a list")
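
`is_recomputation` is removed here; judging by the changed import in service.py further down (`from msprobe.pytorch.common.utils import ..., is_recomputation`) and the new lines in msprobe/pytorch/common/utils.py, it has moved there so the dump service can mark data captured during activation recomputation. The core of the check is plain call-stack inspection; a toy version of that idea, assuming nothing about the msprobe constants:

```python
import inspect

def called_from(function_name: str, filename_suffix: str) -> bool:
    # Walk the current call stack and report whether `function_name` is
    # executing in a file whose path ends with `filename_suffix`.
    return any(
        frame.function == function_name and frame.filename.endswith(filename_suffix)
        for frame in inspect.stack()
    )

def backward():
    # With an empty suffix this matches any file, so the check succeeds here.
    return called_from("backward", "")

print(backward())  # True
```
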
msprobe/pytorch/pt_config.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -303,28 +303,25 @@ class GradToolConfig(BaseConfig):
         check_bounds(self.bounds)
 
 
+class StructureConfig(BaseConfig):
+    def __init__(self, json_config):
+        super().__init__(json_config)
+
+
+TaskDict = {
+    Const.TENSOR: TensorConfig,
+    Const.STATISTICS: StatisticsConfig,
+    Const.OVERFLOW_CHECK: OverflowCheckConfig,
+    Const.FREE_BENCHMARK: FreeBenchmarkCheckConfig,
+    Const.RUN_UT: RunUTConfig,
+    Const.GRAD_PROBE: GradToolConfig,
+    Const.STRUCTURE: StructureConfig
+}
+
+
 def parse_task_config(task, json_config):
-    default_dic = {}
-    if task == Const.TENSOR:
-        config_dic = json_config.get(Const.TENSOR, default_dic)
-        return TensorConfig(config_dic)
-    elif task == Const.STATISTICS:
-        config_dic = json_config.get(Const.STATISTICS, default_dic)
-        return StatisticsConfig(config_dic)
-    elif task == Const.OVERFLOW_CHECK:
-        config_dic = json_config.get(Const.OVERFLOW_CHECK, default_dic)
-        return OverflowCheckConfig(config_dic)
-    elif task == Const.FREE_BENCHMARK:
-        config_dic = json_config.get(Const.FREE_BENCHMARK, default_dic)
-        return FreeBenchmarkCheckConfig(config_dic)
-    elif task == Const.RUN_UT:
-        config_dic = json_config.get(Const.RUN_UT, default_dic)
-        return RunUTConfig(config_dic)
-    elif task == Const.GRAD_PROBE:
-        config_dic = json_config.get(Const.GRAD_PROBE, default_dic)
-        return GradToolConfig(config_dic)
-    else:
-        return StatisticsConfig(default_dic)
+    task_map = json_config.get(task, dict())
+    return TaskDict.get(task)(task_map)
 
 
 def parse_json_config(json_file_path, task):
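
The per-task `if/elif` ladder becomes a single table lookup, and the new structure task only needs one extra entry in `TaskDict`. A self-contained miniature of the same pattern (illustrative names, not the msprobe classes):

```python
# Table-driven parsing in miniature.
class StatisticsConfig(dict):
    pass

class TensorConfig(dict):
    pass

TASK_MAP = {"statistics": StatisticsConfig, "tensor": TensorConfig}

def parse_task_config(task, json_config):
    # One lookup replaces the if/elif chain. Note that an unrecognized task
    # now fails loudly (TASK_MAP.get returns None) instead of silently
    # falling back to a default config, matching the refactor above.
    return TASK_MAP.get(task)(json_config.get(task, {}))

cfg = parse_task_config("tensor", {"tensor": {"file_format": "npy"}})
```
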
msprobe/pytorch/service.py
@@ -15,19 +15,19 @@
 
 import functools
 import os
-from collections import namedtuple
+from collections import namedtuple, defaultdict
 
 import torch
 from msprobe.core.common.const import Const
 from msprobe.core.common.exceptions import DistributedNotInitializedError
 from msprobe.core.common.file_utils import create_directory
-from msprobe.core.common.utils import print_tools_ends_info
+from msprobe.core.common.utils import print_tools_ends_info, DumpPathAggregation
 from msprobe.core.data_dump.data_collector import build_data_collector
 from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
 from msprobe.core.data_dump.scope import BaseScope
 from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import get_rank_if_initialized
+from msprobe.pytorch.common.utils import get_rank_if_initialized, is_recomputation
 from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
 from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
 from msprobe.pytorch.hook_module.api_registry import api_register
@@ -56,13 +56,16 @@ class Service:
         self.should_stop_service = False
         self.attl = None
         self.params_grad_info = {}
+        self.hook_handle_dict = {}
         # 提前注册,确保注册尽可能多的API hook
         self.register_api_hook()
+        self.init_for_debug_level()
 
     def build_hook(self, module_type, name):
         def pre_hook(api_or_module_name, module, args, kwargs):
             if not self.should_execute_hook(module_type, module, True):
                 return args, kwargs
+            is_recompute = is_recomputation()
 
             self.inner_switch = True
             if module_type == BaseScope.Module_Type_Module:
@@ -77,7 +80,13 @@ class Service:
                 return None, None
             if self.data_collector:
                 module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
-                self.data_collector.forward_input_data_collect(api_or_module_name, module, pid, module_input_output)
+                self.data_collector.forward_input_data_collect(
+                    api_or_module_name,
+                    module,
+                    pid,
+                    module_input_output,
+                    is_recompute
+                )
 
             self.inner_switch = False
             return args, kwargs
@@ -101,7 +110,12 @@ class Service:
            if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
                for param_name, param in params_dict.items():
                    if param.requires_grad:
-                        param.register_hook(grad_hook(module, ori_name, param_name))
+                        name = ori_name + Const.SEP + param_name
+                        old_handle = self.hook_handle_dict.get(name)
+                        if old_handle and hasattr(old_handle, "remove"):
+                            old_handle.remove()
+                        handle = param.register_hook(grad_hook(module, ori_name, param_name))
+                        self.hook_handle_dict[name] = handle
 
         def init_params_grad_info(module, params_dict):
            '''
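
`Service` now remembers the handle returned by `param.register_hook` and removes the previous one before registering again, so repeated hook registration on the same parameter does not accumulate duplicate gradient hooks. A small standalone illustration of the pattern (the `register_grad_hook` helper is hypothetical):

```python
import torch

hook_handles = {}

def register_grad_hook(name, param, fn):
    # Remove any handle registered under this name earlier, then register a
    # fresh hook, so repeated registration never stacks duplicate hooks.
    old = hook_handles.get(name)
    if old is not None:
        old.remove()
    hook_handles[name] = param.register_hook(fn)

p = torch.ones(2, requires_grad=True)
register_grad_hook("fc.weight", p, lambda grad: print("grad hook fired"))
register_grad_hook("fc.weight", p, lambda grad: print("grad hook fired"))
(p * 3).sum().backward()   # prints once: the first hook was removed
```
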
@@ -125,6 +139,7 @@ class Service:
         def forward_hook(api_or_module_name, module, args, kwargs, output):
             if not self.should_execute_hook(module_type, module, True):
                 return None
+            is_recompute = is_recomputation()
 
             self.inner_switch = True
             if self.config.online_run_ut:
@@ -147,10 +162,15 @@ class Service:
             if module_type == BaseScope.Module_Type_Module:
                 api_or_module_name = module.mindstudio_reserved_name[-1]
                 self.data_collector.update_api_or_module_name(api_or_module_name)
-                params_dict = {key.split(Const.SEP)[-1]: value for key, value in module.named_parameters(recurse=False)}
-                setattr(module_input_output, Const.PARAMS, params_dict)
+                params_dict = {}
+                if self.config.task != Const.STRUCTURE:
+                    params_dict = {
+                        key.split(Const.SEP)[-1]: value
+                        for key, value in module.named_parameters(recurse=False)
+                    }
+                setattr(module_input_output, Const.PARAMS, params_dict)
                 # 判断是否需要注册参数hook
-                if not hasattr(module, 'params_grad_name') and params_dict:
+                if params_dict:
                     ori_name = api_or_module_name.rsplit(Const.SEP, 2)[0]
                     grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
                     # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook
@@ -160,7 +180,8 @@ class Service:
                         api_or_module_name,
                         module,
                         pid,
-                        module_input_output
+                        module_input_output,
+                        is_recompute
                     )
                     init_params_grad_info(module, params_dict)
             else:
@@ -169,7 +190,8 @@ class Service:
                     api_or_module_name,
                     module,
                     pid,
-                    module_input_output
+                    module_input_output,
+                    is_recompute
                 )
 
             if self.data_collector.if_return_forward_new_output():
@@ -185,6 +207,7 @@ class Service:
         def backward_hook(api_or_module_name, module, grad_input, grad_output):
             if not self.should_execute_hook(module_type, module, False):
                 return
+            is_recompute = is_recomputation()
 
             self.inner_switch = True
             if module_type == BaseScope.Module_Type_Module:
@@ -198,7 +221,13 @@ class Service:
             if self.data_collector:
                 # 此处获取到的grad_input实际为反向过程的输出数据,grad_output为反向过程的输入数据,因此传入时调换顺序
                 module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
-                self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output)
+                self.data_collector.backward_data_collect(
+                    api_or_module_name,
+                    module,
+                    pid,
+                    module_input_output,
+                    is_recompute
+                )
             self.inner_switch = False
 
         pid = os.getpid()
@@ -217,6 +246,8 @@ class Service:
         return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)
 
     def start(self, model):
+        if self.config.level == Const.LEVEL_DEBUG:
+            return
         if self.need_stop_service():
             return
 
@@ -231,6 +262,8 @@ class Service:
             if self.config.rank and self.current_rank not in self.config.rank:
                 return
             self.register_module_hook()
+            if self.config.level == Const.LEVEL_MIX:
+                register_optimizer_hook(self.data_collector)
             self.first_start = False
         if self.config.online_run_ut and torch_version_above_or_equal_2:
             run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute)
@@ -241,6 +274,8 @@ class Service:
         logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")
 
     def stop(self):
+        if self.config.level == Const.LEVEL_DEBUG:
+            return
         if self.should_stop_service:
             return
         if self.config.step and self.current_iter not in self.config.step:
@@ -255,15 +290,19 @@ class Service:
             return
         if self.config.async_dump:
             self.data_collector.fill_stack_tensor_data()
-            self.data_collector.data_processor.dump_async_data()
+            if self.config.task == Const.TENSOR:
+                self.data_collector.data_processor.dump_async_data()
         self.data_collector.write_json()
 
     def step(self):
+        if self.config.level == Const.LEVEL_DEBUG:
+            return
         if self.should_stop_service:
             return
         if self.config.async_dump:
             self.data_collector.fill_stack_tensor_data()
-            self.data_collector.data_processor.dump_async_data()
+            if self.config.task == Const.TENSOR:
+                self.data_collector.data_processor.dump_async_data()
         self.data_collector.write_json()
         self.current_iter += 1
         self.data_collector.update_iter(self.current_iter)
@@ -319,13 +358,13 @@ class Service:
         else:
             dump_data_dir = None
 
-        dump_file_path = os.path.join(dump_dir, "dump.json")
-        stack_file_path = os.path.join(dump_dir, "stack.json")
-        construct_file_path = os.path.join(dump_dir, "construct.json")
-        free_benchmark_file_path = os.path.join(self.config.dump_path, "free_benchmark.csv")
-        self.data_collector.update_dump_paths(
-            dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path
-        )
+        dump_path_aggregation = DumpPathAggregation()
+        dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
+        dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
+        dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
+        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
+        dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv")
+        self.data_collector.update_dump_paths(dump_path_aggregation)
         self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)
 
     def register_api_hook(self):
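
`update_dump_paths` used to take five positional path arguments; it now receives a single `DumpPathAggregation` object, which also lets the debug level further down set only the fields it needs. The class itself lives in msprobe.core.common.utils and is not shown in this diff, so the sketch below only assumes a simple optional-field container with the attribute names used above:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class DumpPathAggregation:
    # Assumed shape: every field optional, filled only where a level needs it.
    dump_file_path: Optional[str] = None
    stack_file_path: Optional[str] = None
    construct_file_path: Optional[str] = None
    dump_tensor_data_dir: Optional[str] = None
    free_benchmark_file_path: Optional[str] = None
    debug_file_path: Optional[str] = None

# The regular dump fills most fields; the debug level only needs two of them.
paths = DumpPathAggregation(debug_file_path="rank0/debug.json",
                            dump_tensor_data_dir="rank0/dump_tensor_data")
```
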
@@ -337,9 +376,6 @@ class Service:
             )
             api_register.api_modularity()
 
-        if self.config.level == Const.LEVEL_MIX:
-            register_optimizer_hook(self.data_collector)
-
     def register_module_hook(self):
         if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]:
             logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.")
@@ -379,7 +415,7 @@ class Service:
     def reset_status(self):
         ModuleProcesser.reset_module_stats()
         HOOKModule.reset_module_stats()
-        self.data_collector.data_writer.reset_cache()
+        self.data_collector.reset_status()
         self.params_grad_info.clear()
 
         if self.config.level == Const.LEVEL_L2:
@@ -389,3 +425,46 @@ class Service:
             return
         if self.config.rank and self.current_rank not in self.config.rank:
             return
+
+    def init_for_debug_level(self):
+        if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]):
+            return
+        try:
+            self.current_rank = get_rank_if_initialized()
+        except DistributedNotInitializedError:
+            self.current_rank = None
+
+        # dir: dump_path -- rank{} -- debug.json
+        self.dump_iter_dir = self.config.dump_path
+        cur_rank = self.current_rank if self.current_rank is not None else ''
+        dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
+        create_directory(dump_dir)
+        if self.config.task in self.data_collector.tasks_need_tensor_data:
+            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
+            create_directory(dump_data_dir)
+        else:
+            dump_data_dir = None
+
+        dump_path_aggregation = DumpPathAggregation()
+        dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
+        dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
+        self.data_collector.update_dump_paths(dump_path_aggregation)
+        self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)
+
+        self.debug_variable_counter = defaultdict(int)
+
+    def save(self, variable, name, save_backward):
+        if self.config.level != Const.LEVEL_DEBUG:
+            return
+        count = self.debug_variable_counter[name]
+        self.debug_variable_counter[name] += 1
+
+        name_with_count = f"{name}.{count}"
+        grad_name_with_count = f"{name}_grad.{count}"
+
+        # forward save
+        self.data_collector.debug_data_collect_forward(variable, name_with_count)
+
+        # backward save
+        if save_backward:
+            self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)
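
At the new debug level the service skips the usual start/stop/step flow and exposes `save()` instead: each call records the variable under `<name>.<count>`, with the count increasing per name, and optionally the gradient under `<name>_grad.<count>` (see the new msprobe/docs/28.debugger_save_instruction.md in the file list above). A toy reproduction of just the naming scheme:

```python
from collections import defaultdict

_counter = defaultdict(int)

def debug_names(name, save_backward=True):
    # Mirrors the naming in Service.save(): repeated saves of the same name
    # get an incrementing suffix, and gradients get a "<name>_grad" prefix.
    count = _counter[name]
    _counter[name] += 1
    names = [f"{name}.{count}"]
    if save_backward:
        names.append(f"{name}_grad.{count}")
    return names

print(debug_names("logits"))  # ['logits.0', 'logits_grad.0']
print(debug_names("logits"))  # ['logits.1', 'logits_grad.1']
```
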