mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
  3. msprobe/README.md +25 -20
  4. msprobe/core/common/const.py +110 -66
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/utils.py +30 -34
  9. msprobe/core/compare/acc_compare.py +43 -74
  10. msprobe/core/compare/check.py +2 -6
  11. msprobe/core/compare/highlight.py +2 -0
  12. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  13. msprobe/core/compare/merge_result/merge_result.py +8 -2
  14. msprobe/core/compare/multiprocessing_compute.py +19 -12
  15. msprobe/core/compare/npy_compare.py +30 -12
  16. msprobe/core/compare/utils.py +20 -10
  17. msprobe/core/data_dump/api_registry.py +176 -0
  18. msprobe/core/data_dump/data_processor/base.py +2 -2
  19. msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
  20. msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
  21. msprobe/core/data_dump/json_writer.py +38 -35
  22. msprobe/core/grad_probe/constant.py +1 -0
  23. msprobe/core/grad_probe/grad_compare.py +1 -1
  24. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  25. msprobe/docs/01.installation.md +2 -1
  26. msprobe/docs/02.config_introduction.md +17 -15
  27. msprobe/docs/05.data_dump_PyTorch.md +70 -2
  28. msprobe/docs/06.data_dump_MindSpore.md +33 -12
  29. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  30. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  31. msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
  32. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  33. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  34. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  35. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  36. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  37. msprobe/docs/18.online_dispatch.md +1 -1
  38. msprobe/docs/19.monitor.md +124 -62
  39. msprobe/docs/21.visualization_PyTorch.md +32 -13
  40. msprobe/docs/22.visualization_MindSpore.md +32 -13
  41. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  42. msprobe/docs/27.dump_json_instruction.md +278 -8
  43. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  44. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  45. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  46. msprobe/docs/FAQ.md +3 -11
  47. msprobe/docs/img/compare_result.png +0 -0
  48. msprobe/docs/img/merge_result.png +0 -0
  49. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  50. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  51. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  52. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  53. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  54. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  55. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  56. msprobe/mindspore/__init__.py +4 -3
  57. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
  58. msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
  59. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  60. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  61. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  62. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  63. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  64. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  65. msprobe/mindspore/common/const.py +61 -0
  66. msprobe/mindspore/common/utils.py +31 -19
  67. msprobe/mindspore/compare/ms_compare.py +27 -19
  68. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  69. msprobe/mindspore/debugger/debugger_config.py +6 -4
  70. msprobe/mindspore/debugger/precision_debugger.py +22 -10
  71. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  72. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  73. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  74. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  75. msprobe/mindspore/dump/jit_dump.py +14 -9
  76. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  77. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  78. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  79. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  80. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  81. msprobe/mindspore/grad_probe/global_context.py +2 -0
  82. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  83. msprobe/mindspore/grad_probe/hook.py +2 -4
  84. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  85. msprobe/mindspore/monitor/module_hook.py +354 -302
  86. msprobe/mindspore/monitor/utils.py +46 -4
  87. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  88. msprobe/mindspore/service.py +23 -17
  89. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  90. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
  91. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  92. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  93. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  94. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  95. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  96. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  97. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  98. msprobe/pytorch/common/utils.py +29 -7
  99. msprobe/pytorch/debugger/precision_debugger.py +10 -1
  100. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  101. msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
  102. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  103. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  104. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  105. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  106. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  107. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  108. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  109. msprobe/pytorch/function_factory.py +1 -1
  110. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  111. msprobe/pytorch/hook_module/api_register.py +131 -0
  112. msprobe/pytorch/hook_module/hook_module.py +19 -14
  113. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  114. msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
  115. msprobe/pytorch/monitor/csv2tb.py +8 -2
  116. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  117. msprobe/pytorch/monitor/module_hook.py +131 -105
  118. msprobe/pytorch/monitor/module_metric.py +3 -0
  119. msprobe/pytorch/monitor/optimizer_collect.py +55 -4
  120. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  121. msprobe/pytorch/monitor/utils.py +68 -1
  122. msprobe/pytorch/online_dispatch/compare.py +0 -2
  123. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  124. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  125. msprobe/pytorch/online_dispatch/utils.py +3 -0
  126. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  127. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  128. msprobe/pytorch/pt_config.py +11 -7
  129. msprobe/pytorch/service.py +11 -8
  130. msprobe/visualization/builder/graph_builder.py +44 -5
  131. msprobe/visualization/builder/msprobe_adapter.py +0 -1
  132. msprobe/visualization/compare/graph_comparator.py +42 -38
  133. msprobe/visualization/compare/mode_adapter.py +0 -19
  134. msprobe/visualization/graph/base_node.py +8 -1
  135. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  136. msprobe/visualization/graph/graph.py +0 -11
  137. msprobe/visualization/graph/node_op.py +1 -2
  138. msprobe/visualization/graph_service.py +1 -1
  139. msprobe/visualization/utils.py +2 -33
  140. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  141. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  142. msprobe/pytorch/hook_module/api_registry.py +0 -166
  143. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  144. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  145. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  146. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  147. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  148. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  149. msprobe/pytorch/parse.py +0 -19
  150. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  151. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  152. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  153. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -21,16 +21,15 @@ from datetime import datetime
 
  import pytz
  import mindspore as ms
- import mindspore.common.dtype as mstype
- from mindspore import Tensor, ops, mint
+ from mindspore import Tensor, mint
  from mindspore import nn, _no_grad
  from mindspore.communication import get_rank
 
  from msprobe.core.common.log import logger
  from msprobe.core.common.const import MonitorConst
- from msprobe.core.common.file_utils import load_json
+ from msprobe.core.common.file_utils import load_json, save_json
  from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \
- is_skip_step, get_metrics, get_single_metrics
+ is_skip_step, get_metrics, get_single_metrics, get_target_output_dir
  from msprobe.mindspore.monitor.module_spec_verifier import validate_config_spec
  from msprobe.mindspore.monitor.anomaly_detect import AnomalyScanner, AnomalyDataFactory, \
  CSVWriterWithAD, BaseWriterWithAD, WriterInput
@@ -108,6 +107,10 @@ class ModuleHookContext:
  elif key_name in ['input', 'input_grad']:
  self.ignore_in = True
 
+ def reset(self):
+ self.actv.clear()
+ self.actvgrad.clear()
+
 
  start_step = 0
 
@@ -116,7 +119,6 @@ start_step = 0
  class OptimizerContext:
  def __init__(self) -> None:
  self.step = start_step
- self.param_effective_rank = defaultdict(float)
  self.param_mg_direction = defaultdict(float)
  self.param_adam_update = defaultdict()
  self.param_adam_ratio = defaultdict()
@@ -131,6 +133,7 @@ class OptimizerContext:
  def reset(self) -> None:
  self.param_mg_direction.clear()
  self.param_adam_update.clear()
+ self.param_adam_ratio.clear()
  self.param_weight_grad.clear()
  self.param_exp_avg.clear()
  self.exp_avg_metric.clear()
@@ -179,50 +182,100 @@ class CommunicationContext:
 
  class TrainerMon:
  def __init__(self, config_file_path, process_group=None, params_have_main_grad=True) -> None:
+ # TYPE1: variables initialized only here; they are not reset when the config changes mid-training
+ self.config_file_path = config_file_path
+ self.process_group = process_group
+ self.params_have_main_grad = params_have_main_grad
+ self.config_timestamp = 0 # the timestamp is checked later; on the first run there is no need to touch the config file just to refresh its timestamp, monitoring can be started directly via the dynamic_on switch
+ self.config = load_json(config_file_path)
+ validate_config(self.config)
+
+ local_tz = pytz.timezone("Asia/Shanghai") # adjust to the target time zone as needed
+ cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
+ self.unique_id = str(uuid.uuid4())[:8]
+ self.output_base_dir = get_output_base_dir()
+ time_tags = self.config.get("append_output", [])
+ try:
+ self.rank = get_rank()
+ if time_tags:
+ output_append_dirs = get_target_output_dir(self.output_base_dir, time_tags[0], time_tags[1])
+ if str(self.rank) in output_append_dirs:
+ self.tensorboard_dir = output_append_dirs[str(self.rank)]
+ logger.info(f"Append rank({self.rank}) result to {self.tensorboard_dir}")
+ else:
+ self.tensorboard_dir = os.path.join(self.output_base_dir,
+ f"{cur_time}-rank{self.rank}-{self.unique_id}")
+ except Exception as e:
+ self.rank = 0
+ self.tensorboard_dir = os.path.join(self.output_base_dir, f"{cur_time}-rank{self.rank}-{self.unique_id}")
+
+ self.pp_stage = 0
+ self.group_mates = [0]
+
+ # TYPE2: variables assigned only from the set_monitor() entry point
+ self.model = None
+ self.vpp = False
+ self.dp_group = None
+ self.tp_group = None
+ self.micro_batch_number = 1
+
+ # TYPE3: variables reset when the config is updated mid-training or the monitoring state changes
  self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
  self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
  self.optimizer_context = defaultdict(OptimizerContext)
  self.cc_context = defaultdict(CommunicationContext)
  self.grad_context = GradContext()
- self.params_have_main_grad = params_have_main_grad
  self.handles = defaultdict(list)
- self.config = load_json(config_file_path)
- validate_config(self.config)
+ self.param2name = defaultdict(str)
+ self.name2index = defaultdict()
+ self.name2indices = defaultdict()
+ self.name2param = {}
+ self.duplicate_param = {}
+ self.name2tag = {}
+ self.param_name_call_id = {}
+ self.call_id = 0
+ self.module_struct = defaultdict(dict)
+ self.grad_accs = []
+ self.weight_hooked = False
+ self.optimizer_hooked = False
+ self.param_registered = False
+ self.struct_printed = False
+
+ # distinguish dynamic from static monitoring
+ self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
+ if self.dynamic_enable:
+ logger.warning(f"DYNAMIC_MONITOR is set, "
+ f"please make sure you have 'dynamic_on' and 'collect_times' in {self.config_file_path}")
+ self.monitoring = False
+ else:
+ self.set_config()
+ # in static mode with collect_times > 0, self.monitoring can already be True at step 0; dynamic mode starts at the next step by default
+ if self.collect_times > 0:
+ self.monitoring = True
 
+ def set_config(self):
  self.start_step = self.config.get("start_step", 0)
  self.collect_times = self.config.get("collect_times", 100000000) # defaults to a large value so collection never stops
  self.step_interval = self.config.get("step_interval", 1)
- self.has_collect_times = 0
-
- # monitor target in module, such as layer, weight, grad
+ self.has_collect_times = 0 # reset the collection counter
+ self.print_struct = self.config.get("print_struct", False)
  self.targets = self.config.get("targets", None)
  self.is_select = self.config.get("is_select", False)
  self.module_rank_list = self.config.get("module_ranks", [])
- # only csv supported in mindspore
- self.format = self.config.get('format', MonitorConst.CSV)
+ self.format = self.config.get('format', MonitorConst.CSV) # only csv supported in mindspore
  self.eps = self.config.get('eps', 1e-8)
- # monitor mean/max/norm/min/nan...
- self.ops = self.config.get('ops', [])
+ self.ops = self.config.get('ops', []) # monitor mean/max/norm/min/nan...
  self.ndigits = self.config.get('ndigits', 6)
  self.all_xy = self.config.get('all_xy', False)
- # module input/output input_grad/output_grad
  self.xy_distribution = self.config.get('xy_distribution', False)
- # activation forward
  self.forward_only = self.config.get('forward_only', False)
- # activation backward
  self.backward_only = self.config.get('backward_only', False)
- # update vector and ratio vector of adam
- self.ur_distribution = self.config.get('ur_distribution', False)
- # m/v of adam
- self.mv_distribution = self.config.get("mv_distribution", False)
- # weight grad
+ self.ur_distribution = self.config.get('ur_distribution', False) # vector and ratio vector of adam
+ self.mv_distribution = self.config.get("mv_distribution", False) # m/v of adam
  self.wg_distribution = self.config.get("wg_distribution", False)
- # optimizer param
  self.param_distribution = self.config.get("param_distribution", False)
- # main grad direction
- self.mg_direction = self.config.get('mg_direction', False)
- # communication ops
- self.cc_distribution = self.config.get("cc_distribution", {})
+ self.mg_direction = self.config.get('mg_direction', False) # main grad direction
+ self.cc_distribution = self.config.get("cc_distribution", {}) # communication ops
  if not self.cc_distribution.get('enable', False):
  self.cc_log_only = False
  else:
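For orientation, the sketch below writes a minimal config.json covering the keys that set_config() reads in this diff; the key names come from the code above, while the values and the example target are illustrative assumptions rather than msprobe defaults.

    # Hypothetical sketch: build a minimal monitor config with keys read by set_config().
    import json

    monitor_config = {
        "start_step": 0,                # first step eligible for collection
        "collect_times": 10,            # stop after this many collections (the in-code default is a very large value)
        "step_interval": 1,             # collect every step
        "targets": {"layers.0": {}},    # hypothetical module name; real names depend on the model
        "format": "csv",                # only csv is supported in mindspore
        "ops": ["min", "max", "norm"],  # statistics to monitor
        "xy_distribution": True,        # module input/output and their grads
        "wg_distribution": True,        # weight grads
        "alert": {"rules": []},
        "dynamic_on": False             # only consulted when DYNAMIC_MONITOR is enabled
    }

    with open("monitor_config.json", "w") as f:
        json.dump(monitor_config, f, indent=2)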
@@ -230,140 +283,173 @@ class TrainerMon:
  self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
  self.cc_logged_stack = defaultdict(set)
  self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
- self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
- api_register.redirect_api()
  self.common_info()
 
+ # initialize the AnomalyData factory
  alert_setting = self.config.get('alert', {"rules": []})
  self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])
-
- local_tz = pytz.timezone("Asia/Shanghai") # adjust to the target time zone as needed
-
- cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
- unique_id = str(uuid.uuid4())[:8]
- output_base_dir = get_output_base_dir()
-
- time_tags = self.config.get("append_output", [])
- if time_tags:
- output_append_dirs = get_target_output_dir(output_base_dir, time_tags[0], time_tags[1])
- try:
- rank = get_rank()
- except Exception as e:
- rank = 0
- tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}")
- logger.error(f"Failed to get rank, setting tensorboard_dir to {tensorboard_dir}")
- pp_stage = 0
- group_mates = [0]
- else:
- if time_tags and str(rank) in output_append_dirs:
- tensorboard_dir = output_append_dirs[str(rank)]
- logger.info(f"Append rank({rank}) result to {tensorboard_dir}")
- else:
- tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}")
- pp_stage = 0
- group_mates = [0]
-
- self.rank = rank
-
- # initialize the AnomalyData factory
  self.anomaly_data_factory = None
  if alert_setting.get('dump', False):
- self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates)
+ self.anomaly_data_factory = AnomalyDataFactory(self.rank, self.pp_stage, self.group_mates)
 
+ # initialize the writer and create the output directory
  if self.format not in FORMAT_MAPPING:
  logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
  self.format = MonitorConst.CSV
  writer = FORMAT_MAPPING[self.format]
  self.step_count_per_record = self.config.get('step_count_per_record', 1)
-
  self.summary_writer = writer(
  WriterInput(
- tensorboard_dir,
+ self.tensorboard_dir,
  self.alert_rules,
- unique_id,
+ self.unique_id,
  self.anomaly_data_factory,
  self.ndigits,
  self.step_count_per_record
  )
  )
 
- self.micro_batch_number = 1
-
- self.model = None
- self.weight_hooked = False
- self.optimizer_hooked = False
- self.param_registered = False
- self.vpp = False
- self.dp_group = None
- self.tp_group = None
- self.enable_megatron = False
-
- self.param2name = defaultdict(str)
- self.name2index = defaultdict()
- self.name2indices = defaultdict()
- self.name2param = {}
- self.param_name_call_id = {}
- self.duplicate_param = {}
- self.name2tag = {}
- self.call_id = 0
- self.grad_accs = []
- self.handles = defaultdict(list)
-
- self.print_struct = self.config.get("print_struct", False)
- self.struct_printed = False
- self.module_struct = defaultdict(dict)
+ def common_info(self):
+ if not self.xy_distribution:
+ logger.info("> module input/output input_grad/output_grad is not monitored. ")
+ if self.forward_only:
+ logger.info("> only module forward is monitored. ")
+ if not self.ur_distribution:
+ logger.info("> update vector and ratio vector of adam is not monitored. ")
+ if not self.mv_distribution:
+ logger.info("> momentum and variance of adam is not monitored. ")
+ if not self.wg_distribution:
+ logger.info("> weight grad of specified module is not monitored. ")
+ if not self.mg_direction:
+ logger.info('> grad and momentum direction will not be compared.')
+ if not self.cc_distribution.get('enable', False):
+ logger.info("> cc operator is not monitored.")
 
- # Start
  def set_monitor(
  self,
  model,
+ optimizer,
  grad_acc_steps=1,
- optimizer=None,
  tp_group=None,
  dp_group=None,
- start_iteration=0):
+ start_iteration=0
+ ):
  global start_step
  start_step = start_iteration
- logger.info(f'grad acc steps {grad_acc_steps}')
- self.hook_optimizer(optimizer)
  self.micro_batch_number = grad_acc_steps
  self.dp_group = dp_group
  self.tp_group = tp_group
+ self.hook_step_final(optimizer)
+ if not isinstance(model, list):
+ model = [model]
+ self.model = model
+ if len(model) > 1:
+ self.vpp = True
+ logger.info('vpp enabled')
+ if not self.dynamic_enable:
+ self.register_hooks(optimizer)
+
+ def hook_step_final(self, optimizer):
+ def step_final_hook(optimizer, *args, **kwargs):
+ context = self.optimizer_context[optimizer]
+ # static mode can already save at step 0; dynamic mode cannot, because it is designed to start at the step after a reset, so self.monitoring is still False at step 0
+ if self.monitoring:
+ module_rank_valid = self.is_target_rank()
+ step_condition = (context.step >= self.start_step and (
+ context.step - self.start_step) % self.step_interval == 0)
+ if module_rank_valid and step_condition:
+ self.has_collect_times += 1
+ self.write_xy_tb(context.step)
+ self.write_grad_tb(context.step)
+ self.write_mv_tb(context)
+ self.write_param_tb(context)
+
+ if context.metric_dict:
+ self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
+ context.metric_dict.clear()
+
+ self.summary_writer.clear_anomalies()
+ self.call_id = 0
+ self.param_name_call_id.clear()
+
+ if self.has_collect_times >= self.collect_times:
+ self._remove_all_hooks_final(optimizer)
 
- self.hook_modules(model, grad_acc_steps)
- self._patch_grad_sync()
+ context.step += 1
+ self.dynamic_monitor(optimizer)
 
- """
- Start
- """
- def hook_optimizer(self, optimizer):
- rank_id = str(get_rank())
- if self.optimizer_hooked:
+ optimizer.register_forward_hook(step_final_hook)
+ return
+
+ def dynamic_monitor(self, optimizer):
+ """
+ If dynamic monitor enabled and config.json updated,
+ remove hooks and register new hooks according to new configuration.
+ """
+ context = self.optimizer_context[optimizer]
+ if not self.dynamic_enable:
  return
+ try:
+ # if the file timestamp has not changed, skip reading to save time
+ config_timestamp = os.path.getmtime(self.config_file_path)
+ if config_timestamp == self.config_timestamp:
+ return
+ # record the latest modification timestamp of the config file
+ self.config_timestamp = config_timestamp
+ config = load_json(self.config_file_path)
+ except Exception as e:
+ logger.error(f"get config.json wrong because {e}, not updated, please check!!!")
+ return
+
+ if config.get("dynamic_on", False):
+ try:
+ validate_config(config)
+ self.config = config
+ self.set_config()
+ self.start_step = context.step # dynamic start/stop ignores the original start_step and always starts from the next step
+ logger.warning(f"config is updated at step{context.step - 1}, "
+ f"will start new hook at step{context.step}.")
+ except Exception as e:
+ logger.error(f"set config wrong because {e}, not updated, please check!!!")
+ return
+
+ self._remove_all_hooks()
+ self.register_hooks(optimizer)
+
+ def register_hooks(self, optimizer):
+ self._register_param_name()
+ self.hook_modules()
+ self.hook_optimizer(optimizer)
+ self._patch_grad_sync()
+ if self.cc_distribution.get('enable', False):
+ self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
+ api_register.redirect_api()
+ self.monitoring = True
 
+ def hook_modules(self):
  if not self.is_target_rank():
  return
+ module_in_all_stage = [key for key in self.targets.keys() if MonitorConst.NAME_SEP not in key]
 
- m_list = []
- v_list = []
- param_list = []
- grad_names = []
- for param in optimizer.get_parameters():
- if MonitorConst.EXP_AVG_SQ in param.name:
- v_list.append(param)
- elif MonitorConst.EXP_AVG in param.name:
- m_list.append(param)
- else:
- param_list.append(param)
- grad_names.append(param.name)
+ for key in module_in_all_stage:
+ struct = self.targets.pop(key)
+ self.targets.update(
+ {f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})
 
- """
- grad reduced
- m/v
- """
+ hooked_count = 0
+ for vpp_stage, model_chunk in enumerate(self.model):
+ if not isinstance(model_chunk, nn.Cell):
+ logger.info("Target Model is not Cell")
+ continue
+ vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
+ targets = [x for x, _ in model_chunk.cells_and_names()] if self.print_struct else self.targets.keys()
+ hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
+ logger.info(f"> {hooked_count} modules are monitored.")
+
+ def hook_optimizer(self, optimizer):
  def optimizer_pre_hook_function(opt, grad_names, gradients):
  context = self.optimizer_context[opt]
- if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, \
+ if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
  self.collect_times):
  return
  gradient_list = gradients[0] if isinstance(gradients, tuple) else gradients
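The dynamic path introduced above is gated by the DYNAMIC_MONITOR environment variable and the dynamic_on key: dynamic_monitor() re-reads config.json whenever its modification time changes and re-registers hooks once dynamic_on is true. A rough sketch of how a user might toggle it mid-training follows; the file name and keys continue the hypothetical config from the earlier example.

    # Hypothetical sketch: turn monitoring on mid-training when DYNAMIC_MONITOR is enabled.
    import json
    import os

    os.environ["DYNAMIC_MONITOR"] = "true"   # must be set before TrainerMon is constructed

    def enable_monitoring(config_path="monitor_config.json", steps=5):
        # Rewriting the file bumps its mtime, which dynamic_monitor() checks every step.
        with open(config_path) as f:
            config = json.load(f)
        config["dynamic_on"] = True        # hooks are registered starting from the next step
        config["collect_times"] = steps    # after this many collections the monitor resets dynamic_on to False
        with open(config_path, "w") as f:
            json.dump(config, f, indent=2)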
@@ -402,46 +488,64 @@ class TrainerMon:
  context.metric_dict = metric_dict
  return
 
- def optimizer_post_hook_function(opt, args, gradients, outputs):
- context = self.optimizer_context[opt]
- step_skip = is_skip_step(context.step, self.start_step, self.step_interval, \
- self.has_collect_times, self.collect_times)
- if step_skip:
- context.step += 1
- return
- self.write_xy_tb(context.step)
- self.write_grad_tb(context.step)
- self.write_mv_tb(context)
- self.write_param_tb(context)
-
- if context.metric_dict:
- self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
- context.metric_dict.clear()
- self.has_collect_times += 1
- context.step += 1
- if self.anomaly_data_factory:
- self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
- self.summary_writer.clear_anomalies()
- self.call_id = 0
- self.param_name_call_id.clear()
- return
-
  def optimizer_pre_hook_wrapper(func, grad_names):
  def wrapper(opt, gradients):
  return func(opt, grad_names, gradients)
  return wrapper
 
- def optimizer_post_hook_wrapper(func, args=None):
- def wrapper(opt, gradients, outputs):
- return func(opt, args, gradients, outputs)
- return wrapper
+ if self.optimizer_hooked or not self.is_target_rank():
+ return
 
- optimizer.register_forward_pre_hook(optimizer_pre_hook_wrapper(optimizer_pre_hook_function, grad_names))
- optimizer.register_forward_hook(optimizer_post_hook_wrapper(optimizer_post_hook_function))
+ m_list = []
+ v_list = []
+ param_list = []
+ grad_names = []
+ for param in optimizer.get_parameters():
+ if MonitorConst.EXP_AVG_SQ in param.name:
+ v_list.append(param)
+ elif MonitorConst.EXP_AVG in param.name:
+ m_list.append(param)
+ elif param.name in ['global_step', 'learning_rate']:
+ pass
+ else:
+ param_list.append(param)
+ grad_names.append(param.name)
 
+ handle = optimizer.register_forward_pre_hook(
+ optimizer_pre_hook_wrapper(optimizer_pre_hook_function, grad_names))
+ self.handles['optimizer'].append(handle)
  self.optimizer_hooked = True
  return
 
+ def generate_wgrad_metrics(self):
+ if not self.wg_distribution:
+ return {}, {}
+
+ if self.weight_hooked:
+ try:
+ get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
+ except Exception as e:
+ logger.warning(f"An error occurred while generating wgrad pre metrics")
+ return {}, {}
+
+ grad_dict = {}
+ for param, name in self.param2name.items():
+ if self.duplicate_param.get(name, False):
+ continue
+ grad = param.main_grad if self.params_have_main_grad else param.grad
+ if grad is None:
+ logger.warning(f"grad is None: {name}, maybe something wrong happened.")
+ continue
+ tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+ self._register_param_call_id("hook_optimizer", tag)
+ grad_dict[tag] = grad
+ try:
+ get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
+ except Exception as e:
+ logger.warning(f"An error occurred while generating wgrad post metrics")
+ return {}, {}
+ return self.grad_context.post, self.grad_context.pre
+
  def write_xy_tb(self, step):
  if not self.xy_distribution:
  return
@@ -468,121 +572,27 @@ class TrainerMon:
  if not self.wg_distribution:
  return
 
- if self.enable_megatron:
- self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
- else:
- self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
+ self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
  self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
 
- def common_info(self):
- if not self.xy_distribution:
- logger.info("> module input/output input_grad/output_grad is not monitored. ")
- if self.forward_only:
- logger.info("> only module forward is monitored. ")
- if not self.ur_distribution:
- logger.info("> update vector and ratio vector of adam is not monitored. ")
- if not self.mv_distribution:
- logger.info("> momentum and variance of adam is not monitored. ")
- if not self.wg_distribution:
- logger.info("> weight grad of specified module is not monitored. ")
- if not self.mg_direction:
- logger.info('> grad and momentum direction will not be compared.')
- if not self.cc_distribution.get('enable', False):
- logger.info("> cc operator is not monitored.")
-
  def is_target_rank(self):
- rank_id = str(get_rank())
- if self.module_rank_list and (rank_id not in self.module_rank_list):
+ if self.module_rank_list and (self.rank not in self.module_rank_list):
  return False
  return True
 
- def hook_modules(self, model, grad_acc_steps):
- if not self.is_target_rank():
- return
- if not isinstance(model, list):
- model = [model]
- self.model = model # list
- self._register_param_name(model)
- self.micro_batch_number = grad_acc_steps
- module_in_all_stage = [key for key in self.targets.keys() if MonitorConst.NAME_SEP not in key]
-
- for key in module_in_all_stage:
- struct = self.targets.pop(key)
- self.targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(model))})
-
- hooked_count = 0
- for vpp_stage, model_chunk in enumerate(model):
- if not isinstance(model_chunk, nn.Cell):
- logger.info("Target Model is not Cell")
- continue
- vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
- targets = [x for x, _ in model_chunk.cells_and_names()] if self.print_struct else self.targets.keys()
- hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
- logger.info(f"> {hooked_count} modules are monitored.")
-
  def build_tbtag_tensor_map(self, module_name, tag, tensor):
- rank_id = str(get_rank())
  metrics = {}
- key = get_summary_writer_tag_name(module_name, tag, rank_id)
+ key = get_summary_writer_tag_name(module_name, tag, str(self.rank))
  if isinstance(tensor, Tensor):
  self._register_param_call_id("_hook_module", key)
  metrics[key] = tensor
  return metrics
 
- def generate_wgrad_metrics(self):
- if not self.wg_distribution:
- return {}, {}
-
- if self.weight_hooked:
- try:
- get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
- except Exception as e:
- logger.warning(f"An error occurred while generating wgrad pre metrics")
- return {}, {}
-
- grad_dict = {}
- for param, name in self.param2name.items():
- if self.duplicate_param.get(name, False):
- continue
- grad = param.main_grad if self.params_have_main_grad else param.grad
- if grad is None:
- logger.warning(f"grad is None: {name}, maybe something wrong happened.")
- continue
- tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
- self._register_param_call_id("hook_optimizer", tag)
- grad_dict[tag] = grad
- try:
- get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
- except Exception as e:
- logger.warning(f"An error occurred while generating wgrad post metrics")
- return {}, {}
- return self.grad_context.post, self.grad_context.pre
-
- def _register_param_name(self, model):
- if self.param_registered:
- return
-
- if len(model) > 1:
- self.vpp = True
- logger.info('vpp enabled')
-
- for vpp_stage, model_chunk in enumerate(model):
+ def _register_param_name(self):
+ for vpp_stage, model_chunk in enumerate(self.model):
  prefix = f'{vpp_stage}{MonitorConst.NAME_SEP}'
  self._register_chunk(model_chunk, prefix)
 
- self.param_registered = True
-
- def _is_target_param(self, param_name, param, prefix):
- if not self.targets:
- return True
- squash_name = prefix + squash_param_name(param_name)
- name = prefix + param_name
- for target in self.targets.keys():
- if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
- setattr(param, "zero_out_wgrad", True)
- return True
- return False
-
  def _register_chunk(self, model_chunk, prefix):
  index = 0
  for param in model_chunk.get_parameters():
@@ -607,17 +617,6 @@ class TrainerMon:
  }
  index += 1
 
- def _is_target_module(self, module_name, targets, vpp_stage):
- if self.all_xy or self.print_struct:
- return vpp_stage + squash_param_name(module_name)
- for pattern in [
- vpp_stage + squash_param_name(module_name),
- vpp_stage + module_name,
- ]:
- if pattern in targets:
- return pattern
- return ""
-
  def _hook_module(self, target_names, module, vpp_stage=''):
  if not isinstance(module, nn.Cell):
  # nothing to hook
@@ -637,7 +636,7 @@ class TrainerMon:
  return
  if not module.training:
  return
- if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, \
+ if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
  self.collect_times):
  step_accumulates_one(context, self.micro_batch_number)
  return
@@ -685,7 +684,7 @@ class TrainerMon:
  self.module_struct[context.module_name].update(context.struct)
  return
 
- if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, \
+ if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
  self.collect_times):
  step_accumulates_one(context, self.micro_batch_number)
  return
@@ -752,50 +751,10 @@ class TrainerMon:
  hooked_count += 1
  return hooked_count
 
- def _register_param_call_id(self, hook_name: str, key: str):
- """
- :param hook_name:
- :param key: str, '0:relu_0/output_grad'
- :return:
- """
- logger.debug(f"{hook_name} {key}: {self.call_id}")
- self.param_name_call_id[key] = self.call_id
- self.call_id += 1
-
  def _patch_grad_sync(self):
- # megatron is not used with mindspore for now
- def patch_sync(sync_grad_func):
- def wrapper(bucket):
- grad_dict = {}
- for param, name in self.param2name.items():
- if param not in bucket.params_list:
- continue
- grad = param.main_grad if self.params_have_main_grad else param.grad
- if grad is None:
- logger.warning(f"grad is None: {name}, maybe something wrong happened.")
- continue
- tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
- if tag is None:
- continue
- grad_dict[tag] = grad
- try:
- get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
- except Exception as e:
- logger.warning(f"An error occurred while generating weight grad metrics")
- out = sync_grad_func(bucket)
- return out
-
- return wrapper
-
- self.enable_megatron = False
-
  if not self.wg_distribution:
  return
-
- if self.enable_megatron:
- Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync) # differ in different megatron version
- else:
- self._hook_weights()
+ self._hook_weights()
 
  def _hook_weights(self):
  context = self.grad_context
@@ -819,3 +778,96 @@ class TrainerMon:
  handle = param.register_hook(param_hook_wrapper(param_hook, context_dict=context.acc, param=param, key=key))
  self.handles['wgrads'].append(handle)
  self.weight_hooked = True
+
+ def _is_target_param(self, param_name, param, prefix):
+ if not self.targets:
+ return True
+ squash_name = prefix + squash_param_name(param_name)
+ name = prefix + param_name
+ for target in self.targets.keys():
+ if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
+ setattr(param, "zero_out_wgrad", True)
+ return True
+ return False
+
+ def _is_target_module(self, module_name, targets, vpp_stage):
+ if self.all_xy or self.print_struct:
+ return vpp_stage + squash_param_name(module_name)
+ for pattern in [
+ vpp_stage + squash_param_name(module_name),
+ vpp_stage + module_name,
+ ]:
+ if pattern in targets:
+ return pattern
+ return ""
+
+ def _register_param_call_id(self, hook_name: str, key: str):
+ """
+ :param hook_name:
+ :param key: str, '0:relu_0/output_grad'
+ :return:
+ """
+ logger.debug(f"{hook_name} {key}: {self.call_id}")
+ self.param_name_call_id[key] = self.call_id
+ self.call_id += 1
+
+ def _remove_all_hooks(self):
+ # clear hook handles
+ for handle in self.handles['xy']:
+ handle.remove()
+ self.handles['xy'].clear()
+ # clear the corresponding context caches
+ for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+ fwd_context.reset()
+ for _, bwd_context in self.module_bwd_hook_context_by_module.items():
+ bwd_context.reset()
+ self.grad_context.reset() # both weight grads and activation grads live here
+
+ for handle in self.handles['wgrads']:
+ handle.remove()
+ self.handles['wgrads'].clear()
+ self.weight_hooked = False
+
+ if self.optimizer_hooked:
+ for handle in self.handles['optimizer']:
+ handle.remove()
+ self.handles['optimizer'].clear()
+ for _, context in self.optimizer_context.items():
+ context.reset()
+ self.optimizer_hooked = False
+
+ for handle in self.handles['cc']:
+ handle.remove()
+ self.handles['cc'].clear()
+ api_register.restore_api()
+ for _, context in self.cc_context.items():
+ context.reset()
+
+ # clear node caches
+ self.param2name.clear()
+ self.name2index.clear()
+ self.name2indices.clear()
+ self.name2param.clear()
+ self.duplicate_param.clear()
+ self.name2tag.clear()
+ self.module_struct.clear()
+ self.grad_accs.clear()
+
+ # turn off the collecting state
+ self.monitoring = False
+
+ def _remove_all_hooks_final(self, optimizer):
+ if self.dynamic_enable:
+ # when finished, automatically reset dynamic_on to False and wait for the user to re-enable it
+ try:
+ config = load_json(self.config_file_path)
+ config['dynamic_on'] = False
+ save_json(self.config_file_path, config, indent=2)
+ config_timestamp = os.path.getmtime(self.config_file_path)
+ self.config_timestamp = config_timestamp
+ logger.info(
+ "Finish monitor, set config'dynamic_on=False, will restart by set it to True and update config")
+ except Exception as e:
+ logger.warning(f"Finish monitor, set config'dynamic_on=False fail because {e}, please check!!!")
+ logger.info("Finish monitor")
+ self._remove_all_hooks()
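Taken together, the refactor makes set_monitor(model, optimizer, ...) the single entry point (optimizer is now a required positional argument) and moves per-step flushing into the optimizer forward hook installed by hook_step_final(). A minimal usage sketch under those assumptions, with `net` and `opt` standing in for an existing mindspore.nn.Cell and MindSpore optimizer:

    # Hypothetical usage sketch of the refactored TrainerMon; `net` and `opt`
    # are assumed to exist (a mindspore.nn.Cell and a MindSpore optimizer).
    from msprobe.mindspore.monitor.module_hook import TrainerMon

    monitor = TrainerMon("monitor_config.json",
                         params_have_main_grad=False)  # pass True only if your optimizer keeps main_grad
    monitor.set_monitor(net, opt, grad_acc_steps=1)

    # Training then proceeds as usual; metrics are written by step_final_hook
    # at every step that satisfies start_step, step_interval and collect_times.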