mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
  3. msprobe/README.md +27 -22
  4. msprobe/core/common/const.py +129 -60
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/inplace_ops.yaml +1 -0
  9. msprobe/core/common/utils.py +43 -33
  10. msprobe/core/compare/acc_compare.py +43 -74
  11. msprobe/core/compare/check.py +2 -6
  12. msprobe/core/compare/highlight.py +2 -0
  13. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  14. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  15. msprobe/core/compare/merge_result/merge_result.py +16 -9
  16. msprobe/core/compare/merge_result/utils.py +81 -0
  17. msprobe/core/compare/multiprocessing_compute.py +19 -12
  18. msprobe/core/compare/npy_compare.py +30 -12
  19. msprobe/core/compare/utils.py +30 -10
  20. msprobe/core/data_dump/api_registry.py +176 -0
  21. msprobe/core/data_dump/data_collector.py +58 -13
  22. msprobe/core/data_dump/data_processor/base.py +94 -10
  23. msprobe/core/data_dump/data_processor/factory.py +3 -0
  24. msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
  25. msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
  26. msprobe/core/data_dump/json_writer.py +61 -40
  27. msprobe/core/grad_probe/constant.py +1 -0
  28. msprobe/core/grad_probe/grad_compare.py +1 -1
  29. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  30. msprobe/docs/01.installation.md +27 -1
  31. msprobe/docs/02.config_introduction.md +27 -23
  32. msprobe/docs/03.config_examples.md +24 -0
  33. msprobe/docs/05.data_dump_PyTorch.md +103 -16
  34. msprobe/docs/06.data_dump_MindSpore.md +76 -32
  35. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  36. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  37. msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
  38. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  39. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  40. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  41. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  42. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  43. msprobe/docs/18.online_dispatch.md +1 -1
  44. msprobe/docs/19.monitor.md +332 -273
  45. msprobe/docs/21.visualization_PyTorch.md +42 -13
  46. msprobe/docs/22.visualization_MindSpore.md +43 -13
  47. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  48. msprobe/docs/27.dump_json_instruction.md +301 -27
  49. msprobe/docs/28.debugger_save_instruction.md +94 -0
  50. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  51. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  52. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  53. msprobe/docs/FAQ.md +3 -11
  54. msprobe/docs/img/compare_result.png +0 -0
  55. msprobe/docs/img/merge_result.png +0 -0
  56. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  57. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  58. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  59. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  60. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  61. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  63. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  64. msprobe/mindspore/__init__.py +4 -2
  65. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
  66. msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
  67. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  68. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  69. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  70. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  71. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  72. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  73. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
  74. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  75. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  76. msprobe/mindspore/common/const.py +61 -0
  77. msprobe/mindspore/common/utils.py +48 -18
  78. msprobe/mindspore/compare/ms_compare.py +27 -19
  79. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  80. msprobe/mindspore/debugger/debugger_config.py +31 -6
  81. msprobe/mindspore/debugger/precision_debugger.py +45 -14
  82. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  83. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  86. msprobe/mindspore/dump/jit_dump.py +21 -15
  87. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  88. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  89. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  90. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  91. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  92. msprobe/mindspore/grad_probe/global_context.py +2 -0
  93. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  94. msprobe/mindspore/grad_probe/hook.py +2 -4
  95. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  96. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  97. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  98. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  99. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  100. msprobe/mindspore/monitor/features.py +63 -0
  101. msprobe/mindspore/monitor/module_hook.py +873 -0
  102. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  103. msprobe/mindspore/monitor/utils.py +309 -0
  104. msprobe/mindspore/ms_config.py +8 -2
  105. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  106. msprobe/mindspore/service.py +114 -34
  107. msprobe/pytorch/__init__.py +0 -1
  108. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  109. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
  110. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  111. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  112. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  116. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  117. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  118. msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
  119. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
  120. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  121. msprobe/pytorch/common/utils.py +97 -4
  122. msprobe/pytorch/debugger/debugger_config.py +19 -9
  123. msprobe/pytorch/debugger/precision_debugger.py +24 -1
  124. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  125. msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
  126. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  127. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  132. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  133. msprobe/pytorch/function_factory.py +8 -2
  134. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  135. msprobe/pytorch/hook_module/api_register.py +131 -0
  136. msprobe/pytorch/hook_module/hook_module.py +19 -14
  137. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  138. msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
  139. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  140. msprobe/pytorch/monitor/csv2tb.py +18 -14
  141. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  142. msprobe/pytorch/monitor/module_hook.py +238 -193
  143. msprobe/pytorch/monitor/module_metric.py +9 -6
  144. msprobe/pytorch/monitor/optimizer_collect.py +100 -67
  145. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  146. msprobe/pytorch/monitor/utils.py +76 -44
  147. msprobe/pytorch/online_dispatch/compare.py +0 -2
  148. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  149. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  150. msprobe/pytorch/online_dispatch/utils.py +3 -0
  151. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  152. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  153. msprobe/pytorch/pt_config.py +30 -29
  154. msprobe/pytorch/service.py +114 -32
  155. msprobe/visualization/builder/graph_builder.py +75 -10
  156. msprobe/visualization/builder/msprobe_adapter.py +7 -6
  157. msprobe/visualization/compare/graph_comparator.py +42 -38
  158. msprobe/visualization/compare/mode_adapter.py +0 -19
  159. msprobe/visualization/graph/base_node.py +11 -3
  160. msprobe/visualization/graph/distributed_analyzer.py +71 -3
  161. msprobe/visualization/graph/graph.py +0 -11
  162. msprobe/visualization/graph/node_op.py +4 -3
  163. msprobe/visualization/graph_service.py +4 -5
  164. msprobe/visualization/utils.py +12 -35
  165. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
  166. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  167. msprobe/pytorch/hook_module/api_registry.py +0 -166
  168. msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
  169. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  171. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  172. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  173. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  174. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  175. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  176. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  177. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -16,8 +16,9 @@ import re
  
  import torch
  
+ from msprobe.pytorch.common.utils import is_float8_tensor
  from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
- from msprobe.pytorch.monitor.utils import NAN_TENSOR_ON_DEVICE
+ from msprobe.pytorch.monitor.utils import get_nan_tensor
  
  
  def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
@@ -147,13 +148,13 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
      """
      :param ops: ["op1", "op2"]
      :param tag2tensor: {
-         '0:fc_0/input': torch.randn([3, 4]),
-         '0:fc_0/output': torch.randn([3, 3])
+         '0:fc.input:0/actv': torch.randn([3, 4]),
+         '0:fc.output:0/actv': torch.randn([3, 3])
      }
      :param eps: float 1e-8
      :param out_dict:{
-         '0:fc_0/input': {"op1": op1(torch.randn([3, 4])), "op2": op2(torch.randn([3, 4]))}
-         '0:fc_0/output': {"op1": op1(torch.randn([3, 3])), "op2": op2(torch.randn([3, 3]))}
+         '0:fc.input:0/actv': {"op1": op1(torch.randn([3, 4])), "op2": op2(torch.randn([3, 4]))}
+         '0:fc.output:0/actv': {"op1": op1(torch.randn([3, 3])), "op2": op2(torch.randn([3, 3]))}
      }
      :return: out_dict
      """
@@ -164,8 +165,10 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
          out_dict[tag] = {}
          if not torch.is_tensor(tensor):
              # Non-tensor in/output filled with nan.
-             out_dict[tag].update({metric_name: NAN_TENSOR_ON_DEVICE for metric_name in ops})
+             out_dict[tag].update({metric_name: get_nan_tensor() for metric_name in ops})
              continue
+         if is_float8_tensor(tensor):
+             tensor = tensor.float()
          for metric_name in ops:
              fun_metric = config_metric_registry.get(metric_name)
              out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
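
Note: the three hunks above (apparently msprobe/pytorch/monitor/module_metric.py, +9 -6) switch the non-tensor placeholder to the lazily created get_nan_tensor(), upcast float8 tensors before metrics are computed, and document the new activation tag format. A hedged sketch of how the reworked get_metrics is driven; the registry op names ("min", "max", "norm") are an assumption, not confirmed by this diff:

    import torch
    from msprobe.pytorch.monitor.module_metric import get_metrics

    # Tags follow the "<rank>:<module>.<input|output>:<index>/actv" convention shown in the docstring.
    tag2tensor = {
        '0:fc.input:0/actv': torch.randn(3, 4),
        '0:fc.output:0/actv': torch.randn(3, 3),
    }
    # Float8 inputs would be upcast to float32 by the new is_float8_tensor branch.
    out = get_metrics(["min", "max", "norm"], tag2tensor, eps=1e-8)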
@@ -23,16 +23,10 @@ from msprobe.pytorch.monitor.utils import MVResult, MVGradResult
  
  
  class OptimizerMon(object):
-     wrapped_optimizer = None
- 
      def __init__(self) -> None:
          self.fp16_to_fp32_param = {}
          self.is_stage3 = False
  
-     @classmethod
-     def set_wrapped_optimizer(cls, wrapped_optimizer):
-         cls.wrapped_optimizer = wrapped_optimizer
- 
      def fetch_mv(self, monitor, torch_opt, params2name):
          pass
  
@@ -82,7 +76,6 @@ class OptimizerMon(object):
          ratio_dict = defaultdict()
          param2name = defaultdict()
          fp32_partitioned_groups_flat_grad = defaultdict()
-         mix_prec_opt = OptimizerMon.wrapped_optimizer
          partition_id = dist.get_rank()
  
          def get_flatten_grad(self, optimizer, group_idx):
@@ -101,7 +94,7 @@ class OptimizerMon(object):
              return fp32_partitioned_groups_flat[group_idx].grad
  
          for group_idx in range(len(fp32_partitioned_groups_flat)):
-             fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, mix_prec_opt, group_idx)
+             fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, torch_opt, group_idx)
  
          for name in params2name.values():
              start_idx, end_idx, group_idx, group_with_rank = name2indices[name]
@@ -110,9 +103,9 @@ class OptimizerMon(object):
              fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx]
              fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx]
              param2name[fp32_param] = name
-             if not mix_prec_opt.state:
+             if not torch_opt.state:
                  continue
-             state_param = list(mix_prec_opt.state.values())[group_idx]
+             state_param = list(torch_opt.state.values())[group_idx]
              exp_avg = state_param.get("exp_avg", None)
              exp_avg_sq = state_param.get("exp_avg_sq", None)
              if exp_avg is None or exp_avg_sq is None:
@@ -150,36 +143,33 @@ class MixPrecisionOptimizerMon(OptimizerMon):
      混合精度训练通过适当降低某些计算的精度来加速训练过程并减少内存消耗。
      """
  
-     def map_fp16_tp_fp32_param(self, mix_prec_opt):
-         for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups):
+     def map_fp16_tp_fp32_param(self, torch_opt):
+         for fp16_group, fp32_group in zip(torch_opt.float16_groups, torch_opt.fp32_from_float16_groups):
              for fp16_param, fp32_param in zip(fp16_group, fp32_group):
                  self.fp16_to_fp32_param[fp16_param] = fp32_param
  
      def fetch_mv(self, monitor, torch_opt, params2name):
-         mix_prec_opt = self.wrapped_optimizer
- 
-         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-             self.map_fp16_tp_fp32_param(mix_prec_opt)
+         if not self.fp16_to_fp32_param and torch_opt is not None:
+             self.map_fp16_tp_fp32_param(torch_opt)
  
          return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
  
  
  class MegatronDistributedOptimizerMon(OptimizerMon):
-     def map_fp16_tp_fp32_param(self, mix_prec_opt):
-         if not (hasattr(mix_prec_opt, "model_float16_groups") and
-                 hasattr(mix_prec_opt, "shard_fp32_from_float16_groups")):
+     def map_fp16_tp_fp32_param(self, torch_opt):
+         if not (hasattr(torch_opt, "model_float16_groups") and
+                 hasattr(torch_opt, "shard_fp32_from_float16_groups")):
              raise Exception(
                  "megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, "
                  "if not, please check megatron-lm version")
-         for fp16_group, shard_fp32_group in zip(mix_prec_opt.model_float16_groups,
-                                                 mix_prec_opt.shard_fp32_from_float16_groups):
+         for fp16_group, shard_fp32_group in zip(torch_opt.model_float16_groups,
+                                                 torch_opt.shard_fp32_from_float16_groups):
              for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
                  self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
  
      def fetch_mv(self, monitor, torch_opt, params2name):
-         mix_prec_opt = self.wrapped_optimizer
-         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-             self.map_fp16_tp_fp32_param(mix_prec_opt)
+         if not self.fp16_to_fp32_param and torch_opt is not None:
+             self.map_fp16_tp_fp32_param(torch_opt)
  
          return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
  
@@ -191,43 +181,89 @@ class MegatronFP32OptimizerMon(OptimizerMon):
  
  class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
      def fetch_mv(self, monitor, torch_opt, params2name):
-         mix_prec_opt = self.wrapped_optimizer
- 
-         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-             for opt in mix_prec_opt.chained_optimizers:
+         if not self.fp16_to_fp32_param and torch_opt is not None:
+             for opt in torch_opt.chained_optimizers:
                  self.map_fp16_tp_fp32_param(opt)
  
-         if not isinstance(torch_opt, torch.optim.Optimizer):
+         if not isinstance(torch_opt, torch.optim.Optimizer) and not hasattr(torch_opt, 'state'):
              torch_opt.state = {}
-             for opt in mix_prec_opt.chained_optimizers:
+             for opt in torch_opt.chained_optimizers:
                  torch_opt.state.update(opt.optimizer.state)
          return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
  
  
  class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
      def fetch_mv(self, monitor, torch_opt, params2name):
-         mix_prec_opt = self.wrapped_optimizer
- 
-         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
-             for opt in mix_prec_opt.chained_optimizers:
+         if not self.fp16_to_fp32_param and torch_opt is not None:
+             for opt in torch_opt.chained_optimizers:
                  self.map_fp16_tp_fp32_param(opt)
  
-         if not isinstance(torch_opt, torch.optim.Optimizer):
+         if not isinstance(torch_opt, torch.optim.Optimizer) and not hasattr(torch_opt, 'state'):
              torch_opt.state = {}
-             for opt in mix_prec_opt.chained_optimizers:
+             for opt in torch_opt.chained_optimizers:
                  torch_opt.state.update(opt.optimizer.state)
          return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
  
  
  class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon):
-     def fetch_mv(self, monitor, torch_opt, params2name):
-         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+     def get_group_index(self, torch_opt):
+         bit16_groups = torch_opt.bf16_groups
+         param2group = defaultdict()
+         for group_idx, bit16_group in enumerate(bit16_groups):
+             for param in bit16_group:
+                 param2group[param] = group_idx
+         return param2group
+ 
+     def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
+         param2group = self.get_group_index(torch_opt)
+         exp_avg_dict = defaultdict(float)
+         exp_avg_sq_dict = defaultdict(float)
+         update_dict = defaultdict()
+         ratio_dict = defaultdict()
+ 
+         param_slice_mappings = torch_opt.state_dict()['param_slice_mappings']
+         for param, name in params2name.items():
+             group_idx = param2group[param]
+             state = torch_opt.optimizer.state[torch_opt.fp32_groups_flat_partition[group_idx]]
+             if state.get('exp_avg', None) is None:
+                 logger.warning(f"optimizer state is None. Something is wrong if this is not the first step")
+                 break
+             param_slice_mapping = param_slice_mappings[group_idx]
+             hp_address = param_slice_mapping.get(torch_opt.param_names[param])
+             if hp_address is None:
+                 continue
+             start = hp_address.start
+             numel = hp_address.numel
  
+             if monitor.mv_distribution:
+                 exp_avg_dict[name] = state['exp_avg'].narrow(0, start, numel)
+                 exp_avg_sq_dict[name] = state['exp_avg_sq'].narrow(0, start, numel)
+             if monitor.mg_direction:
+                 exp_avg_dict[name] = state['exp'].narrow(0, start, numel)
+             if monitor.ur_distribution:
+                 if len(torch_opt.param_groups) > 1:
+                     logger.info(f"the length of torch_opt.param_groups is {len(torch_opt.param_groups)}.")
+                 if 'step' in state:
+                     step = state['step']  # Optimizer from pytorch or FusedAdam from apex(used by megatron)
+                 elif 'step' in torch_opt.param_groups[0]:
+                     step = torch_opt.param_groups[0]['step']  # AdamW from mindspeed
+                 else:
+                     logger.warning(f"step of {name} is None, maybe something wrong happened.")
+                     continue
+                 exp_avg = state['exp_avg'].narrow(0, start, numel)
+                 exp_avg_sq = state['exp_avg_sq'].narrow(0, start, numel)
+                 exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
+                 exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
+                 update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
+                 ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
+                 monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
+                 monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
+         return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
+ 
  
  class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
-     def get_param_index(self, params2name, name2index):
-         mix_prec_opt = OptimizerMon.wrapped_optimizer
-         fp16_groups = mix_prec_opt.fp16_partitioned_groups
+     def get_param_index(self, params2name, name2index, torch_opt):
+         fp16_groups = torch_opt.fp16_partitioned_groups
          name2indices = defaultdict()
          index_length = defaultdict()
          index = 0
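
Note: the new DeepSpeedZeroOptimizerStage0Mon.fetch_mv above reconstructs per-parameter Adam statistics from BF16_Optimizer state slices using the standard bias-corrected moments. A standalone sketch of just that math (names and defaults are illustrative, not msprobe APIs):

    import torch

    def adam_update_and_ratio(exp_avg, exp_avg_sq, step, beta1=0.9, beta2=0.999, eps=1e-8):
        """Bias-corrected Adam statistics, as computed per parameter slice above."""
        exp_avg_hat = exp_avg / (1 - beta1 ** step)         # m_hat
        exp_avg_sq_hat = exp_avg_sq / (1 - beta2 ** step)   # v_hat
        update = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + eps)
        ratio = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
        return update, ratio

    u, r = adam_update_and_ratio(torch.randn(10) * 1e-3, torch.rand(10) * 1e-6, step=100)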
@@ -246,13 +282,11 @@ class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
  
      def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
          self.is_stage3 = True
-         mix_prec_opt = OptimizerMon.wrapped_optimizer
-         fp32_partitioned_groups_flat = mix_prec_opt.fp32_partitioned_groups_flat
+         fp32_partitioned_groups_flat = torch_opt.fp32_partitioned_groups_flat
          return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
  
  
  class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
- 
      @staticmethod
      def get_group_index(fp32_length, world_size, index):
          for i in range(len(fp32_length) - 1):
@@ -265,12 +299,11 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
                  return sub_interval_start, min(sub_index, world_size - 1)
          return fp32_length[-1], 0
  
-     def get_param_index(self, params2name, name2index):
-         mix_prec_opt = OptimizerMon.wrapped_optimizer
-         padding = mix_prec_opt.groups_padding
+     def get_param_index(self, params2name, name2index, torch_opt):
+         padding = torch_opt.groups_padding
          world_size = dist.get_world_size()
          fp32_length = [0]
-         for fp32_group_index, single_partition_of_fp32_group in enumerate(mix_prec_opt.single_partition_of_fp32_groups):
+         for fp32_group_index, single_partition_of_fp32_group in enumerate(torch_opt.single_partition_of_fp32_groups):
              fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index])
  
          bf16_groups = []
@@ -278,7 +311,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
          index_length = defaultdict()
          index = 0
          idx = 0
-         for group_idx, bf16_group in enumerate(mix_prec_opt.bit16_groups):
+         for group_idx, bf16_group in enumerate(torch_opt.bit16_groups):
              bf16_groups.extend(bf16_group)
              for param in bf16_group:
                  param_length = len(param.flatten())
@@ -286,7 +319,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
                  index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank)
                  index += param_length
                  idx += 1
-         group_length = len(bf16_groups) / len(mix_prec_opt.bit16_groups)
+         group_length = len(bf16_groups) / len(torch_opt.bit16_groups)
          for _, name in params2name.items():
              name_index = name2index[name]
              start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index]
@@ -300,8 +333,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
          return name2indices
  
      def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
-         mix_prec_opt = OptimizerMon.wrapped_optimizer
-         fp32_partitioned_groups_flat = mix_prec_opt.single_partition_of_fp32_groups
+         fp32_partitioned_groups_flat = torch_opt.single_partition_of_fp32_groups
          return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
  
  
@@ -312,22 +344,23 @@ class DummyOptimizerMon(OptimizerMon):
  
  class OptimizerMonFactory:
      _optimizer_mon_map = {
-         "Megatron_Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
-         "Megatron_DistributedOptimizer": MegatronDistributedOptimizerMon,
-         "Megatron_ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
-         "Megatron_ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
-         "Megatron_FP32Optimizer": MegatronFP32OptimizerMon,
-         "DeepSpeedZeroOptimizer_Stage0": DeepSpeedZeroOptimizerStage0Mon,
-         "DeepSpeedZeroOptimizer_Stage1_or_2": DeepSpeedZeroOptimizerStage1or2Mon,
+         "FP32Optimizer": MegatronFP32OptimizerMon,
+         "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
+         "DistributedOptimizer": MegatronDistributedOptimizerMon,
+         "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+         "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
+         "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon,
+         "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon,
          "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
-         "unknown": DummyOptimizerMon
+         "Adam": DummyOptimizerMon
      }
  
      @staticmethod
-     def create_optimizer_mon(opt_ty: str):
-         if not opt_ty:
-             return DummyOptimizerMon()
-         optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(opt_ty)
-         if not optimizer_mon_class:
-             raise Exception("opt_ty should be one of: " + ", ".join(OptimizerMonFactory._optimizer_mon_map.keys()))
-         return optimizer_mon_class()
+     def create_optimizer_mon(optimizer):
+         # auto replace opt_ty
+         optimizer_class = optimizer.__class__.__name__
+         if optimizer_class == "ChainedOptimizer":
+             optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__
+ 
+         optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, DummyOptimizerMon)
+         return optimizer_mon_class(), optimizer_class
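
Note: the factory in msprobe/pytorch/monitor/optimizer_collect.py now keys off the wrapped optimizer's class name instead of a user-supplied opt_ty string, and returns both the monitor instance and the resolved name. A hedged usage sketch (only the "Adam" → DummyOptimizerMon mapping is taken from the map above; the rest is illustrative):

    import torch
    from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters())

    # "Adam" maps to DummyOptimizerMon; unknown class names also fall back to DummyOptimizerMon.
    opt_mon, opt_class_name = OptimizerMonFactory.create_optimizer_mon(optimizer)
    print(opt_class_name)  # "Adam"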
@@ -92,7 +92,7 @@ def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel):
      if errors:
          logger.info(errors)
      else:
-         logger.info(f'grad mean is in consist between unreduced grad and reduced grad monitord.')
+         logger.info(f'grad mean is in consist between unreduced grad and reduced grad monitored.')
  
  
  def assert_equal(a, b):
@@ -25,7 +25,7 @@ import torch
  from msprobe.core.common.const import MonitorConst, Const
  from msprobe.pytorch.common.log import logger
  from msprobe.core.common.utils import is_int
- from msprobe.core.common.file_utils import check_file_or_directory_path
+ from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod
  
  
  device = "cpu"
@@ -36,7 +36,7 @@ except ImportError:
      if torch.cuda.is_available():
          device = "cuda"
  
- NAN_TENSOR_ON_DEVICE = torch.tensor(torch.nan, device=device)
+ NAN_TENSOR_ON_DEVICE = None
  FILE_MAX_SIZE = 10 * 1024 * 1024 * 1024
  FILE_NAME_MAX_LENGTH = 255
  DIRECTORY_MAX_LENGTH = 4096
@@ -57,6 +57,13 @@ def get_output_base_dir():
      return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
  
  
+ def get_nan_tensor():
+     global NAN_TENSOR_ON_DEVICE
+     if not NAN_TENSOR_ON_DEVICE:
+         NAN_TENSOR_ON_DEVICE = torch.tensor(torch.nan, device=device)
+     return NAN_TENSOR_ON_DEVICE
+ 
+ 
  def filter_special_chars(func):
      @wraps(func)
      def func_level(msg):
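
Note: NAN_TENSOR_ON_DEVICE is now created lazily on first use instead of at import time, so importing the module no longer touches the device. A minimal standalone sketch of the same lazy-singleton pattern (not msprobe code):

    import torch

    _NAN_TENSOR = None

    def get_nan_tensor(device="cpu"):
        """Create the NaN placeholder once and reuse it afterwards."""
        global _NAN_TENSOR
        if _NAN_TENSOR is None:
            _NAN_TENSOR = torch.tensor(torch.nan, device=device)
        return _NAN_TENSOR

Checking `is None` (rather than truthiness, as the hunk above does) sidesteps the ambiguity of evaluating a tensor in a boolean context.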
@@ -82,48 +89,6 @@ def get_param_struct(param):
      return res
  
  
- def is_recomputation():
-     """Check if the current operation is in the re-computation phase.
- 
-     This function inspects the current call stack to indicate whether the current operation is in the
-     re-computation phase. We use a blacklist mechanism, now supported megatron and mindspeed framework.
-     megatron: The 'backward' function is called by the 'torch/autograd/function.py' file.
-     mindspeed: The 'checkpoint_function_backward' function is called by the 'torch/autograd/function.py'
-     file or the custom module(use CheckpointWithoutOutput) with the 'backward' function is executed within the
-     'torch/_tensor.py' file.
- 
-     Returns:
-         bool: True if in the re-computation phase, False otherwise.
-     """
-     backward_function_indices = []
-     call_stack = inspect.stack()
- 
-     # Identify the function 'backward' is being executed within the 'torch/_tensor.py' file.
-     for frame_info in call_stack:
-         if frame_info.function == Const.BACKWARD and frame_info.filename.endswith('torch/_tensor.py'):
-             del call_stack
-             return True
- 
-     # Identify indices in the call stack where the specific function is being executed
-     for idx, frame_info in enumerate(call_stack):
-         if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward':
-             backward_function_indices.append(idx)
- 
-     # Check if the execution is within 'torch/autograd/function.py' file
-     for idx in backward_function_indices:
-         # The Megatron and MindSpeed L0&L1 scenes
-         if idx + 1 < len(call_stack) and call_stack[idx + 1].filename.endswith('torch/autograd/function.py'):
-             del call_stack
-             return True
-         # The latest MindSpeed L2 and ModelLink scenes
-         if idx + 2 < len(call_stack) and call_stack[idx + 2].filename.endswith('torch/autograd/function.py'):
-             del call_stack
-             return True
- 
-     del call_stack
-     return False
- 
- 
  def validate_ops(ops):
      if not isinstance(ops, list):
          raise TypeError("ops should be a list")
@@ -140,6 +105,15 @@ def validate_ops(ops):
      return valid_ops
  
  
+ def validate_ndigits(ndigits):
+     if not ndigits:
+         return
+     if not is_int(ndigits) or ndigits <= 0:
+         raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.")
+     if ndigits > MonitorConst.MAX_NDIGITS:
+         raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.")
+ 
+ 
  def validate_ranks(ranks):
      if not isinstance(ranks, list):
          raise TypeError("module_ranks should be a list")
@@ -241,9 +215,17 @@ def validate_step_count_per_record(step_count_per_record):
          raise ValueError("step_count_per_record must smaller than 1e6")
  
  
+ def validate_dynamic_on(dynamic_on):
+     if not isinstance(dynamic_on, bool):
+         raise TypeError('dynamic_on should be a bool')
+ 
+ 
  def validate_config(config):
      config['ops'] = validate_ops(config.get('ops', []))
  
+     ndigits = config.get('ndigits')
+     validate_ndigits(ndigits)
+ 
      eps = config.get('eps', 1e-8)
      if not isinstance(eps, float):
          raise TypeError("eps should be a float")
@@ -281,9 +263,20 @@ def validate_config(config):
      step_count_per_record = config.get('step_count_per_record', 1)
      validate_step_count_per_record(step_count_per_record)
  
+     config["start_step"] = validate_int_arg(config.get("start_step"), "start_step",
+                                             MonitorConst.DEFAULT_START_STEP, MonitorConst.DEFAULT_START_STEP)
+     config["collect_times"] = validate_int_arg(config.get("collect_times"), "collect_times",
+                                                MonitorConst.DEFAULT_MIN_COLLECT_TIMES,
+                                                MonitorConst.DEFAULT_MAX_COLLECT_TIMES)
+     config["step_interval"] = validate_int_arg(config.get("step_interval"), "step_interval",
+                                                MonitorConst.DEFAULT_STEP_INTERVAL, MonitorConst.DEFAULT_STEP_INTERVAL)
+ 
      squash_name = config.get('squash_name', True)
      validate_squash_name(squash_name)
  
+     dynamic_on = config.get('dynamic_on', False)
+     validate_dynamic_on(dynamic_on)
+ 
      if not targets:
          if xy_distribution:
              config["all_xy"] = True
@@ -292,6 +285,8 @@
  
  def time_str2time_digit(time_str):
      time_format = '%b%d_%H-%M-%S'
+     if not isinstance(time_str, str):
+         raise TypeError(f"time_str:{time_str} should be a str")
      try:
          time_digit = datetime.strptime(time_str, time_format)
      except Exception as e:
@@ -319,3 +314,40 @@ def get_target_output_dir(monitor_path, time_start, time_end):
          if start_ok and end_ok:
              result[rank] = os.path.join(monitor_path, dirname)
      return result
+ 
+ 
+ def chmod_tensorboard_dir(path):
+     """
+     format配置为tensorboard时,需要补充文件权限设置
+     """
+     try:
+         recursive_chmod(path)
+     except Exception as e:
+         logger.warning(f"chmod tensorboard dir wrong because {e}, not updated, please check!!!")
+ 
+ 
+ def validate_set_monitor(grad_acc_steps, start_iteration):
+     """
+     validate parameters of set_monitor.
+     """
+     grad_acc_steps = validate_int_arg(grad_acc_steps, "grad_acc_steps",
+                                       MonitorConst.DEFAULT_GRAD_ACC_STEPS, MonitorConst.DEFAULT_GRAD_ACC_STEPS)
+ 
+     start_iteration = validate_int_arg(start_iteration, "start_iteration",
+                                        MonitorConst.DEFAULT_START_ITERATION, MonitorConst.DEFAULT_START_ITERATION)
+     return grad_acc_steps, start_iteration
+ 
+ 
+ def validate_int_arg(value, name, minimum, default_value):
+     """Validate int args, if any exception occurs, use the default value."""
+     if value is None:
+         return default_value
+     try:
+         if not is_int(value):
+             raise TypeError(f"{name} must be int")
+         if value < minimum:
+             raise ValueError(f"{name} must greater than {minimum}")
+     except Exception as e:
+         value = default_value
+         logger.warning(f"Validate {name} failed, {e}, replaced with default value {value}.")
+     return value
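
Note: validate_int_arg deliberately never raises; invalid values are logged and replaced with the default. Illustrative behaviour, assuming these hunks belong to msprobe/pytorch/monitor/utils.py:

    from msprobe.pytorch.monitor.utils import validate_int_arg

    assert validate_int_arg(None, "start_step", 0, 0) == 0   # None falls back to the default
    assert validate_int_arg(-5, "start_step", 0, 0) == 0     # below the minimum: warning logged, default used
    assert validate_int_arg(10, "start_step", 0, 0) == 10    # valid values pass through unchanged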
@@ -125,8 +125,6 @@ class Saver:
  
      def write_summary_csv(self, test_result):
          test_rows = []
-         if self.stack_info:
-             test_rows[0].append(self.COLUMN_STACK_INFO)
  
          check_op_str_pattern_valid(test_result.api_name)
          df_row = [test_result.api_name, test_result.is_fwd_success, test_result.is_bwd_success]
@@ -16,6 +16,7 @@
  import json
  import os
  import time
+ import multiprocessing
  from multiprocessing import Pool
  
  import torch
@@ -52,6 +53,7 @@ class PtdbgDispatch(TorchDispatchMode):
              return
          if dump_path is None:
              logger.error("Please set dump_path when dump_mode is config!")
+             raise DispatchException("Please set dump_path when dump_mode is config!")
          check_file_or_directory_path(dump_path, True)
  
          self.device_id = torch_npu._C._npu_getDevice()
@@ -85,6 +87,11 @@ class PtdbgDispatch(TorchDispatchMode):
          self.get_ops(yaml_path)
  
          self.lock = None
+         max_process_num = max(int((multiprocessing.cpu_count() + 1) // Const.CPU_QUARTER), 1)
+         if process_num > max_process_num:
+             logger.error(f"process_num should be less than or equal to {max_process_num}, but got {process_num}!")
+             raise DispatchException(f'process_num should be less than or equal to {max_process_num}, '
+                                     f'but got {process_num}!')
          if process_num > 0:
              self.pool = Pool(process_num)
          if debug:
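
Note: PtdbgDispatch now rejects a process_num above roughly a quarter of the available CPUs instead of silently accepting it. The cap works out as below (Const.CPU_QUARTER is assumed to be 4; its value is not shown in this diff):

    import multiprocessing

    CPU_QUARTER = 4  # assumed value of Const.CPU_QUARTER
    max_process_num = max(int((multiprocessing.cpu_count() + 1) // CPU_QUARTER), 1)
    # e.g. 32 CPUs -> (32 + 1) // 4 = 8 worker processes at most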
@@ -115,6 +122,8 @@
              if len(json_line_data) == 0:
                  break
              msg = json.loads(json_line_data)
+             if len(msg) < 2:
+                 raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.")
              self.all_summary[msg[0]] = msg[1]
          fp_handle.close()
  
@@ -19,6 +19,8 @@ import os
  from datetime import datetime, timezone
  
  import torch
+ from msprobe.core.common.const import Const
+ from msprobe.core.common.decorator import recursion_depth_decorator
  from msprobe.core.common.file_utils import FileOpen, save_npy, save_json
  from msprobe.pytorch.common.log import logger
  
@@ -91,6 +93,7 @@ def support_basic_type(data):
      return False
  
  
+ @recursion_depth_decorator("dump_data")
  def dump_data(data, prefix, dump_path):
      if isinstance(data, (tuple, list)) and data:
          for i, item in enumerate(data):
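
Note: dump_data here, and data_to_cpu in the hunk below, walk arbitrarily nested containers, so both are now wrapped with recursion_depth_decorator from the new msprobe/core/common/decorator.py (listed above). That decorator's implementation is not part of this diff; a hedged sketch of what such a guard typically looks like, under its own name:

    import functools

    def recursion_depth_guard(name, max_depth=100):
        """Illustrative stand-in: abort when a self-recursive call nests too deeply."""
        def decorator(func):
            depth = {"value": 0}

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                depth["value"] += 1
                try:
                    if depth["value"] > max_depth:
                        raise RecursionError(f"{name} exceeded max recursion depth {max_depth}")
                    return func(*args, **kwargs)
                finally:
                    depth["value"] -= 1
            return wrapper
        return decorator

    @recursion_depth_guard("dump_data")
    def dump_data(data):
        if isinstance(data, (list, tuple)):
            for item in data:
                dump_data(item)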
@@ -27,8 +27,10 @@ else:
      pta_cpu_device = torch.device("cpu")
  
  from msprobe.core.common.const import CompareConst
+ from msprobe.core.common.decorator import recursion_depth_decorator
  from msprobe.pytorch.common.log import logger
  
+ 
  cpu_device = torch._C.device("cpu")
  COLOR_RED = '\033[31m'
  COLOR_GREEN = '\033[32m'
@@ -85,6 +87,7 @@ def get_callstack():
      return callstack
  
  
+ @recursion_depth_decorator("data_to_cpu")
  def data_to_cpu(data, deep, data_cpu):
      global cpu_device
      list_cpu = []
@@ -45,12 +45,7 @@ class InteractiveCli(cmd.Cmd):
  
      @catch_exception
      def default(self, line=""):
-         self.util.execute_command(line)
-         return False
- 
-     @catch_exception
-     def do_run(self, line=""):
-         self.util.execute_command(line)
+         self.stdout.write("Command invalid, Only support command start with cad/vc/dc/pk/cn/pt\n")
  
      @catch_exception
      def do_vc(self, line=""):
@@ -119,6 +119,7 @@ class Util:
  
      @staticmethod
      def deal_with_dir_or_file_inconsistency(output_path):
+         logger.warning(f"Trying to delete {output_path}")
          remove_path(output_path)
          raise ParseException("Inconsistent directory structure or file.")
  
@@ -264,7 +265,7 @@ class Util:
              match = re_pattern.match(name)
              if not match:
                  continue
-             if extern_pattern != '' and re_pattern.match(extern_pattern) and not re.match(extern_pattern, name):
+             if extern_pattern != '' and re_pattern.match(extern_pattern) and not name.startswith(extern_pattern):
                  continue
              file_list[name] = gen_info_func(name, match, file["root"])
          return file_list