mindstudio-probe 8.3.0__py3-none-any.whl → 8.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/RECORD +44 -54
  3. msprobe/README.md +8 -5
  4. msprobe/core/common/const.py +17 -3
  5. msprobe/core/common/file_utils.py +64 -13
  6. msprobe/core/common/framework_adapter.py +10 -1
  7. msprobe/core/common/utils.py +17 -0
  8. msprobe/core/compare/utils.py +26 -6
  9. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
  10. msprobe/core/hook_manager.py +2 -16
  11. msprobe/core/service.py +5 -16
  12. msprobe/docs/01.installation.md +2 -0
  13. msprobe/docs/02.config_introduction.md +0 -13
  14. msprobe/docs/05.data_dump_PyTorch.md +1 -1
  15. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -13
  16. msprobe/docs/10.accuracy_compare_PyTorch.md +6 -6
  17. msprobe/docs/14.data_parse_PyTorch.md +2 -0
  18. msprobe/docs/19.monitor.md +4 -4
  19. msprobe/docs/21.visualization_PyTorch.md +1 -1
  20. msprobe/docs/25.tool_function_introduction.md +0 -1
  21. msprobe/docs/32.ckpt_compare.md +5 -5
  22. msprobe/mindspore/monitor/module_hook.py +17 -20
  23. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  24. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  25. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  26. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  27. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
  28. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  29. msprobe/pytorch/common/utils.py +0 -70
  30. msprobe/pytorch/debugger/debugger_config.py +0 -10
  31. msprobe/pytorch/dump/module_dump/module_processer.py +18 -3
  32. msprobe/pytorch/hook_module/api_register.py +14 -3
  33. msprobe/pytorch/monitor/module_hook.py +16 -34
  34. msprobe/pytorch/pt_config.py +2 -51
  35. msprobe/pytorch/pytorch_service.py +10 -14
  36. msprobe/visualization/builder/graph_builder.py +2 -2
  37. msprobe/visualization/builder/graph_merger.py +13 -0
  38. msprobe/visualization/db_utils.py +42 -18
  39. msprobe/visualization/graph/graph.py +13 -9
  40. msprobe/visualization/graph_service.py +20 -10
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  42. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  43. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  44. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  45. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  46. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  47. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  48. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  49. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  50. msprobe/pytorch/attl_manager.py +0 -65
  51. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/LICENSE +0 -0
  52. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/WHEEL +0 -0
  53. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/entry_points.txt +0 -0
  54. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/top_level.txt +0 -0
msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py

@@ -39,7 +39,12 @@ from msprobe.core.common.const import FileCheckConst, Const
  from msprobe.core.common.utils import CompareException


- def split_json_file(input_file, num_splits, filter_api):
+ def split_json_file(input_file, num_splits, filter_api, device_id):
+     max_processes = len(device_id) * 8
+     if num_splits > max_processes:
+         logger.warning(f"A device supports a maximum of 8 processes. "
+                        f"The total number of processes exceeds the limit, and it is set to {max_processes}.")
+         num_splits = max_processes
      forward_data, backward_data, real_data_path = parse_json_info_forward_backward(input_file)
      input_dir = os.path.dirname(os.path.abspath(input_file))
      if filter_api:
@@ -88,7 +93,7 @@ def split_json_file(input_file, num_splits, filter_api):
                  logger.error(f"File not found or could not be deleted: {file}")
          msg = 'ERROR: Split json file failed, please check the input file and try again.'
          raise CompareException(CompareException.PARSE_FILE_ERROR, msg) from e
-     return split_files, total_items
+     return split_files, total_items, num_splits


  def signal_handler(signum, frame):
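Taken together, the two hunks above change the contract of split_json_file: the caller now passes the device list, the requested split count is capped at eight processes per device, and the (possibly reduced) count is returned so prepare_config can propagate it back to args.num_splits. A minimal standalone sketch of just the capping rule (the helper name cap_num_splits is hypothetical; device_id is a list of device indices, as in the diff):

    def cap_num_splits(num_splits, device_id):
        # each device supports at most 8 concurrent run_ut processes
        max_processes = len(device_id) * 8
        return min(num_splits, max_processes)

    assert cap_num_splits(20, [0, 1]) == 16   # 2 devices -> capped at 16
    assert cap_num_splits(10, [0, 1]) == 10   # under the limit -> unchanged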
@@ -127,7 +132,8 @@ def run_parallel_ut(config):
      def read_process_output(process):
          try:
              while True:
-                 if process.poll() is not None:
+                 # The child's stdout stream is separate from the process state itself, so check both: a non-None poll() means the child has exited, and a stdout of None means the stream is gone.
+                 if process.poll() is not None or process.stdout is None:
                      break
                  output = process.stdout.readline()
                  if output == '':
@@ -175,12 +181,17 @@ def run_parallel_ut(config):

      try:
          for process in processes:
-             process.communicate(timeout=None)
+             process.wait()  # unlike the original communicate(), wait() only blocks; it does not also capture stdout and stderr
      except KeyboardInterrupt:
          logger.warning("Interrupted by user, terminating processes and cleaning up...")
      except Exception as e:
          logger.error(f"An unexpected error occurred: {e}")
      finally:
+         # Refresh the progress bar one last time, so progress is not left stale when child processes exit before buffered writes land.
+         if wait_for_file_write_complete(config.result_csv_path):
+             result_file = read_csv(config.result_csv_path)
+             completed_items = len(result_file)
+             progress_bar.update(completed_items - progress_bar.n)
          if progress_bar.n < config.total_items:
              logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to " \
                             "the result CSV file will be utilized to resume the UT task.")
@@ -195,6 +206,22 @@ def run_parallel_ut(config):
          logger.error(f"An unexpected error occurred: {e}")


+ def wait_for_file_write_complete(file_path, timeout=3600):
+     last_size = 0
+     start_time = time.time()  # record the start time
+     while True:
+         current_size = os.path.getsize(file_path)
+         # check whether the file size has stopped changing
+         if current_size == last_size:
+             return True  # the file write is complete, return True
+         last_size = current_size
+         # check for timeout
+         if time.time() - start_time > timeout:
+             logger.error("Writing the result csv file timed out.")
+             return False  # timed out, return False
+         time.sleep(0.1)  # brief delay between polls
+
+
  def prepare_config(args):
      api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
                                          ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
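Note the settle-check semantics of the new helper: it returns True on the first poll whose size equals the previous poll's (and since last_size starts at 0, a still-empty file passes immediately), so it detects that writes have paused for ~0.1 s rather than proving completion. A hedged usage sketch matching the call site above:

    # settle-check the result CSV before the final progress-bar refresh
    if wait_for_file_write_complete(config.result_csv_path, timeout=60):
        completed_items = len(read_csv(config.result_csv_path))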
@@ -203,7 +230,9 @@ def prepare_config(args):
      create_directory(out_path)
      out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
      out_path = out_path_checker.common_check()
-     split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
+     split_files, total_items, modified_num_splits = split_json_file(api_info, args.num_splits,
+                                                                     args.filter_api, args.device_id)
+     args.num_splits = modified_num_splits
      config_path = args.config_path if args.config_path else None
      if config_path:
          config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py

@@ -51,8 +51,6 @@ from msprobe.pytorch.pt_config import parse_json_config
  from msprobe.core.common.const import Const, FileCheckConst, CompareConst
  from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
  from msprobe.pytorch.common.utils import seed_all
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params, \
      ExecParams

@@ -90,27 +88,22 @@ seed_all()

  def run_ut(config):
      logger.info("start UT test")
-     if config.online_config.is_online:
-         logger.info(f"UT task result will be saved in {config.result_csv_path}".replace(".csv", "_rank*.csv"))
-         logger.info(f"UT task details will be saved in {config.details_csv_path}".replace(".csv", "_rank*.csv"))
-     else:
-         logger.info(f"UT task result will be saved in {config.result_csv_path}")
-         logger.info(f"UT task details will be saved in {config.details_csv_path}")
+
+     logger.info(f"UT task result will be saved in {config.result_csv_path}")
+     logger.info(f"UT task details will be saved in {config.details_csv_path}")

      if config.save_error_data:
          logger.info(f"UT task error_data will be saved in {config.error_data_path}")
      compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)

-     if config.online_config.is_online:
-         run_api_online(config, compare)
-     else:
-         csv_df = read_csv(config.result_csv_path)
-         try:
-             api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
-         except IndexError:
-             logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
-             api_name_set = set()
-         run_api_offline(config, compare, api_name_set)
+
+     csv_df = read_csv(config.result_csv_path)
+     try:
+         api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
+     except IndexError:
+         logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
+         api_name_set = set()
+     run_api_offline(config, compare, api_name_set)
      for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
          change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
          change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
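With the online branch gone, run_ut now always resumes from the result CSV: the first column of each existing row becomes the set of API names to skip. A hedged sketch of that resume-set construction using pandas (a plausible backend for the read_csv helper; the path is hypothetical):

    import pandas as pd

    csv_df = pd.read_csv("accuracy_checking_result.csv")
    api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
    # run_api_offline(config, compare, api_name_set) then skips these names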
@@ -164,60 +157,6 @@ def run_api_offline(config, compare, api_name_set):
          gc.collect()


- def run_api_online(config, compare):
-     attl = init_attl(config.online_config)
-     dispatcher = ConsumerDispatcher(compare=compare)
-     dispatcher.start(handle_func=run_torch_api_online, config=config)
-
-     def tcp_communication_flow():
-         while True:
-             api_data = attl.recv()
-             if api_data == 'STOP_':
-                 continue
-             if api_data == 'KILL_':
-                 time.sleep(1)
-                 logger.info("========== STOP signal received ==========")
-                 dispatcher.stop()
-                 attl.stop_serve()
-                 time.sleep(1)
-                 break
-             if not isinstance(api_data, ApiData):
-                 continue
-             api_full_name = api_data.name
-             _, api_name = extract_basic_api_segments(api_full_name)
-             if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
-                 continue
-             if api_data.rank in config.online_config.rank_list:
-                 dispatcher.update_consume_queue(api_data)
-
-     def shared_storage_communication_flow():
-         flag_num = -1
-         while True:
-             api_data = attl.download()
-             if api_data == "start":
-                 if flag_num == -1:
-                     flag_num += 1
-                 flag_num += 1
-             if api_data == "end":
-                 flag_num -= 1
-             if flag_num == 0:
-                 dispatcher.stop()
-                 break
-             if not isinstance(api_data, ApiData):
-                 continue
-             api_full_name = api_data.name
-             _, api_name = extract_basic_api_segments(api_full_name)
-             if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
-                 continue
-             if api_data.rank in config.online_config.rank_list:
-                 dispatcher.update_consume_queue(api_data)
-
-     if config.online_config.nfs_path:
-         shared_storage_communication_flow()
-     else:
-         tcp_communication_flow()
-
-
  def blacklist_and_whitelist_filter(api_name, black_list, white_list):
      """
      run api(api_name) if api_name not in black_list and in white_list.
@@ -315,21 +254,6 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict):
      return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)


- def run_torch_api_online(api_full_name, api_data, backward_content):
-     in_fwd_data_list = []
-     api_type, api_name = extract_basic_api_segments(api_full_name)
-     args, kwargs, out = api_data.args, api_data.kwargs, api_data.result
-     in_fwd_data_list.append(args)
-     in_fwd_data_list.append(kwargs)
-     if kwargs.get("device"):
-         del kwargs["device"]
-
-     device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None)
-     device_out = exec_api(device_exec_params)
-     device_out = move2device_exec(device_out, "cpu")
-     return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
-
-
  def check_need_grad(api_info_dict):
      need_grad = True
      if api_info_dict.get(Const.INPUT_KWARGS) and "out" in api_info_dict.get(Const.INPUT_KWARGS):
@@ -389,16 +313,6 @@ def initialize_save_error_data(error_data_path):
      return error_data_path


- def init_attl(config):
-     """config: OnlineConfig"""
-     attl = ATTL('gpu', ATTLConfig(is_benchmark_device=True,
-                                   connect_ip=config.host,
-                                   connect_port=config.port,
-                                   nfs_path=config.nfs_path,
-                                   tls_path=config.tls_path))
-     return attl
-
-
  def _run_ut_parser(parser):
      parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
                          help="<Optional> The api param tool result file: generate from api param tool, "
@@ -481,38 +395,6 @@ def _run_ut(parser=None):
      _run_ut_parser(parser)
      args = parser.parse_args(sys.argv[1:])
      run_ut_command(args)
-
-
- def checked_online_config(online_config):
-     if not online_config.is_online:
-         return
-     if not isinstance(online_config.is_online, bool):
-         raise ValueError("is_online must be bool type")
-     # rank_list
-     if not isinstance(online_config.rank_list, list):
-         raise ValueError("rank_list must be a list")
-     if online_config.rank_list and not all(isinstance(rank, int) for rank in online_config.rank_list):
-         raise ValueError("All elements in rank_list must be integers")
-
-     # nfs_path
-     if online_config.nfs_path:
-         check_file_or_directory_path(online_config.nfs_path, isdir=True)
-         return
-     # tls_path
-     if online_config.tls_path:
-         check_file_or_directory_path(online_config.tls_path, isdir=True)
-         check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
-         check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
-         check_file_or_directory_path(os.path.join(online_config.tls_path, "ca.crt"))
-         crl_path = os.path.join(online_config.tls_path, "crl.pem")
-         if os.path.exists(crl_path):
-             check_file_or_directory_path(crl_path)
-
-     # host and port
-     if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
-         raise Exception(f"host: {online_config.host} is invalid.")
-     if not isinstance(online_config.port, int) or not (0 < online_config.port <= 65535):
-         raise Exception(f"port: {online_config.port} is invalid, port range 0-65535.")


  def run_ut_command(args):
@@ -525,7 +407,7 @@ def run_ut_command(args):
      else:
          checker_config = CheckerConfig()

-     if not checker_config.is_online and not args.api_info_file:
+     if not args.api_info_file:
          logger.error("Please provide api_info_file for offline run ut.")
          raise Exception("Please provide api_info_file for offline run ut.")

@@ -588,8 +470,6 @@ def run_ut_command(args):
      global UT_ERROR_DATA_DIR
      UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
      error_data_path = initialize_save_error_data(error_data_path)
-     online_config = checker_config.get_online_config()
-     checked_online_config(online_config)
      config_params = {
          'forward_content': forward_content,
          'backward_content': backward_content,
msprobe/pytorch/common/utils.py

@@ -337,56 +337,6 @@ def save_pt(tensor, filepath):
      change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)


- class TypeCheckingUnpickler(pickle.Unpickler):
-     """
-     This class is a subclass of pickle.Unpickler, which is used to unpickle pickled objects.
-     It overrides the find_class method to add type checking functionality.
-     """
-     allowed_types = [
-         "str",
-         "ApiData",
-         "OrderedDict",
-         "_rebuild_tensor_v2",  # from torch.utils
-         "_load_from_bytes"  # from torch.storage
-     ]
-
-     def find_class(self, module, name):
-         """
-         Method to find the class of the object to be unpickled.
-         Throws pickle.UnpicklingError if the object type is not in the allowed types list.
-         """
-         if name in self.allowed_types:
-             return super().find_class(module, name)
-         raise pickle.UnpicklingError("Unsupported object type: {}.{}".format(module, name))
-
-
- def save_pkl(tensor, filepath):
-     """Save an ApiData or str object with pickle"""
-     check_path_before_create(filepath)
-     filepath = os.path.realpath(filepath)
-     try:
-         with FileOpen(filepath, 'wb') as f:
-             pickle.dump(tensor, f)
-     except Exception as e:
-         logger.error("Save pt file failed, please check the possible error causes: "
-                      "1. out of disk space or disk error, "
-                      "2. no permission to write files, etc.")
-         raise RuntimeError(f"save pt file {filepath} failed") from e
-     change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
-
-
- def load_pkl(pt_path):
-     """Load an ApiData or str object with pickle for accuracy_checker_online"""
-     check_file_or_directory_path(pt_path)
-     pt_path = os.path.realpath(pt_path)
-     try:
-         with FileOpen(pt_path, 'rb') as f:
-             pt = TypeCheckingUnpickler(f).load()
-     except Exception as e:
-         raise RuntimeError(f"load pt file {pt_path} failed: {e}") from e
-     return pt
-
-
  def is_recomputation():
      """Check if the current operation is in the re-computation phase.

@@ -471,23 +421,3 @@ def register_forward_hook(module, forward_hook):
          module.register_forward_hook(forward_hook, with_kwargs=True)
      else:
          module.register_forward_hook(forward_hook)
-
-
- def save_api_data(api_data):
-     """Save data to an io stream"""
-     try:
-         io_buff = io.BytesIO()
-         torch.save(api_data, io_buff)
-     except Exception as e:
-         raise RuntimeError("save api_data to io_buff failed") from e
-     return io_buff
-
-
- def load_api_data(api_data_bytes):
-     """Load data from a bytes stream"""
-     try:
-         buffer = io.BytesIO(api_data_bytes)
-         buffer = torch.load(buffer, map_location="cpu")
-     except Exception as e:
-         raise RuntimeError("load api_data from bytes failed") from e
-     return buffer
msprobe/pytorch/debugger/debugger_config.py

@@ -48,16 +48,6 @@ class DebuggerConfig:
              "max_sample": task_config.max_sample
          }

-         self.online_run_ut = False
-         if self.task == Const.TENSOR:
-             # dump api tensor and collaborate with online run_ut
-             self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
-             self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
-             self.tls_path = task_config.tls_path if task_config.tls_path else ""
-             self.host = task_config.host if task_config.host else ""
-             self.port = task_config.port if task_config.port else -1
-             self.online_run_ut_recompute = task_config.online_run_ut_recompute \
-                 if isinstance(task_config.online_run_ut_recompute, bool) else False

          self.check()
          self._check_statistics_config(task_config)
msprobe/pytorch/dump/module_dump/module_processer.py

@@ -63,9 +63,11 @@ def wrap_forward_with_hook_safety(module):
          except _StopRecomputationError as e:
              exception_output = None
              if len(module._forward_hooks.values()) > 0:
-                 # msprobe's forward_hook appears first, so execute only msprobe's forward_hook
-                 hook_fn = list(module._forward_hooks.values())[0]
-                 hook_fn(module, args, kwargs, exception_output)
+                 # execute only msprobe's forward_hook; its name always contains 'ModuleProcesser.'
+                 for hook_fn in module._forward_hooks.values():
+                     if 'ModuleProcesser' in str(hook_fn):
+                         hook_fn(module, args, kwargs, exception_output)
+                         break
              raise e

  if torch_version_above_or_equal_21:
@@ -152,7 +154,13 @@ class ModuleProcesser:
          modules_and_names_with_index = self.get_modules_and_names(models, recursive, module_names)
          for index, modules_and_names in modules_and_names_with_index.items():
              model = models if index == "-1" else models[int(index)]
+
+             model_list = []
              for name, module in modules_and_names:
+                 model_list.append((name, module))
+
+             is_verl = "verl" in sys.modules
+             for idx, (name, module) in enumerate(model_list):
                  if recursive and module == model:
                      continue
                  if not is_torch_nn_module(module):
@@ -163,6 +171,13 @@ class ModuleProcesser:
                      continue
                  if module.__class__.__name__ == "FullyShardedDataParallel":
                      continue
+
+                 # skip the first and last layers in the verl scenario
+                 if is_verl and (idx == 1 or idx == len(model_list) - 1):
+                     logger.warning(f"The module {name} is the first or last layer in verl scenario, "
+                                    f"the data dump for this module will be skipped.")
+                     continue
+
                  setattr(module, 'msprobe_hook', True)
                  module_index = (index + Const.SEP) if index != "-1" else ""
                  prefix_name = f'{BaseScope.Module_Type_Module}{Const.SEP}{module_index}{name}{Const.SEP}' + \
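Two details worth noting in the hunk above: module index 0 is normally the root model itself (already skipped by the earlier `module == model` check), which is presumably why the guard tests `idx == 1` rather than `idx == 0`; and the verl scenario is detected purely from import state, checked once per registration pass. A standalone sketch of that detection (the function name is hypothetical):

    import sys

    def in_verl_scenario() -> bool:
        # True if the verl package has been imported anywhere in this process
        return "verl" in sys.modules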
msprobe/pytorch/hook_module/api_register.py

@@ -22,6 +22,7 @@ import torch.distributed as dist

  from msprobe.core.common.const import Const
  from msprobe.core.common.file_utils import load_yaml
+ from msprobe.core.common.runtime import Runtime
  from msprobe.core.data_dump.api_registry import ApiRegistry
  from msprobe.pytorch.common.log import logger
  from msprobe.pytorch.common.utils import (
@@ -91,6 +92,12 @@ _inner_used_api = {
  }


+ def reset_dist_collect_func():
+     global dist_data_collect_func, dist_batch_data_collect_func
+     dist_data_collect_func.clear()
+     dist_batch_data_collect_func.clear()
+
+
  @parameter_adapter
  def tensor_module_forward(module, *args, **kwargs):
      return module.api_func(*args, **kwargs)
@@ -114,9 +121,9 @@ def dist_module_forward(module, *args, **kwargs):

          return store_data

-     if use_async_op_flag or module.api_name in ['isend', 'irecv']:
+     if Runtime.is_running and (use_async_op_flag or module.api_name in ['isend', 'irecv']):
          dist_data_collect_func[handle] = create_async_callback_func(module.distributed_forward_hook)
-     if module.api_name == 'batch_isend_irecv':
+     if Runtime.is_running and module.api_name == 'batch_isend_irecv':
          dist_batch_data_collect_func.append([handle, create_async_callback_func(module.distributed_forward_hook)])
      return handle

@@ -135,13 +142,17 @@ def redirect_wait():
              store_func = dist_data_collect_func.pop(args[0])
              store_func()
              return
+         remove_value = None
          for value in dist_batch_data_collect_func:
              if args[0] in value[0]:
                  value[0].remove(args[0])
                  if len(value[0]) == 0:
                      store_func = value[1]
                      store_func()
-                     return
+                     remove_value = value
+                     break
+         if remove_value:
+             dist_batch_data_collect_func.remove(remove_value)

      return wrapped_wait

msprobe/pytorch/monitor/module_hook.py

@@ -48,12 +48,10 @@ from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name
  from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
  from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer

-
  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
  if not torch_version_above_or_equal_2:
      raise ValueError("monitor require torch>=2.0")

-
  FORMAT_MAPPING = {
      MonitorConst.TENSORBOARD: SummaryWriterWithAD,
      MonitorConst.CSV: CSVWriterWithAD,
@@ -150,15 +148,11 @@ class GradContext:
      def __init__(self) -> None:
          self.pre = {}
          self.post = {}
-         self.acc_metric = {}
-         self.acc = {}
          self.actv = {}

      def reset(self):
          self.pre.clear()
          self.post.clear()
-         self.acc_metric.clear()
-         self.acc.clear()
          self.actv.clear()

@@ -510,18 +504,8 @@ class TrainerMon:
          if not self.wg_distribution:
              return {}, {}

-         if self.weight_hooked:
-             get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
-
          get_metrics(self.ops, post_grad_dict, self.eps, self.grad_context.post)
-         reduced_grad = self.grad_context.post
-
-         if self.weight_hooked:
-             unreduced_grad = self.grad_context.acc_metric
-         else:
-             unreduced_grad = self.grad_context.pre
-
-         return reduced_grad, unreduced_grad
+         return self.grad_context.post, self.grad_context.pre

      def generate_xy_metrics(self):
          actv = {}
@@ -529,7 +513,6 @@ class TrainerMon:
          actv.update(fwd_context.actv)

          actv_grad = self.grad_context.actv
-
          return actv, actv_grad

      def reload_xy(self, xy_distribution=False):
@@ -607,11 +590,8 @@ class TrainerMon:
          if not self.wg_distribution:
              return

-         if self.weight_hooked:
-             self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced',
-                                               use_micro_step=self.monitor_mbs_grad)
-         else:
-             self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
+         self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced',
+                                           use_micro_step=self.monitor_mbs_grad)
          self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')

      def hook_optimizer(self, optimizer):
@@ -732,9 +712,9 @@ class TrainerMon:
          # static monitoring can save at step 0; dynamic cannot, because dynamic mode by design starts on the step after a reset, so self.monitoring is still False at step 0
          if self.monitoring:
              module_rank_valid = not self.module_rank_list or (
-                 dist.is_initialized() and dist.get_rank() in self.module_rank_list)
+                     dist.is_initialized() and dist.get_rank() in self.module_rank_list)
              step_condition = (context.step >= self.start_step and (
-                 context.step - self.start_step) % self.step_interval == 0)
+                     context.step - self.start_step) % self.step_interval == 0)
              if module_rank_valid and step_condition:
                  self.has_collect_times += 1

@@ -791,6 +771,7 @@ class TrainerMon:
                  hook(optimizer, args, kwargs)
              step_final_hook(optimizer, args, kwargs)
              return out
+
          return wrapper

      optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
@@ -1013,11 +994,11 @@ class TrainerMon:
                  vpp_stage + module_name,
              ]:
                  if pattern in l2_targets:
-                         return pattern
+                     return pattern
          elif hook_name in ["linear_hook"]:
              return vpp_stage + squash_param_name(module_name, self.squash_name)
          return ""
-
+
      def _hook_module(self, target_names, l2_target_names, module: torch.nn.Module, vpp_stage=''):
          if '_modules' not in module.__dict__:
              # nothing to hook
@@ -1151,7 +1132,7 @@ class TrainerMon:
                  context.micro_step = 0
              context.step += 1
              return
-
+
          def stack_hook(module, args, kwargs, module_output, name):
              if module not in self.module_fwd_hook_context_by_module:
                  self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
@@ -1221,7 +1202,7 @@ class TrainerMon:
          if self.monitor_mbs_grad:
              self._hook_weights()
              return
-
+
          self.optimizer_mon.patch_grad_sync(self)

          if self.enable_megatron or self.enable_deepspeed:
@@ -1281,6 +1262,7 @@ class TrainerMon:
              get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
              out = foreach_reduce(fsdp_params, unsharded_grads, *unused)
              return out
+
          return wrapper

      logger.info("Patch fsdp2 foreach_reduce, collect pre_grad metrics.")
@@ -1294,10 +1276,9 @@ class TrainerMon:
          """
          Walk each parameter's gradient-accumulation function (grad_acc) and attach a hook, so that after all of the parameter's gradients are computed, the pre-aggregation gradient data is collected.
          """
-         context = self.grad_context

          @torch.no_grad
-         def param_hook(*args, context_dict, param, name):
+         def param_hook(*args, param, name):
              key = name
              if self.monitor_mbs_grad:
                  key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
@@ -1305,14 +1286,15 @@ class TrainerMon:
              key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
              self.register_param_call_id("param_hook", key)
              param.micro_step += 1
-
+             grad_dict = {}
              if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
                  if self.params_have_main_grad:
                      grad = param.main_grad
                  else:
                      grad = param.grad
-                 context_dict[key] = grad.clone()
+                 grad_dict[key] = grad.clone()

+             get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
              if param.micro_step == self.micro_batch_number:
                  param.micro_step = 0

@@ -1322,7 +1304,7 @@ class TrainerMon:
              param_tmp = param.expand_as(param)
              grad_acc = param_tmp.grad_fn.next_functions[0][0]
              handle = grad_acc.register_hook(
-                 partial(param_hook, context_dict=context.acc, param=param, name=name))
+                 partial(param_hook, param=param, name=name))
              self.grad_accs.append(grad_acc)
              self.handles['wgrads'].append(handle)
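For context, param_hook is attached to each parameter's gradient-accumulation node, a common trick for observing a grad after it has been accumulated but before any communication, as used in the diff. A minimal standalone sketch of that registration (hook name hypothetical; relies on the same autograd internals as the code above):

    import torch

    p = torch.nn.Parameter(torch.ones(3))

    # expand_as creates a non-leaf view whose grad_fn points back at the
    # AccumulateGrad node of the original parameter
    grad_acc = p.expand_as(p).grad_fn.next_functions[0][0]

    def on_grad_ready(*unused):
        print("accumulated grad:", p.grad)

    handle = grad_acc.register_hook(on_grad_ready)
    (p * 2).sum().backward()   # fires on_grad_ready after accumulation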