mindstudio-probe 8.3.0__py3-none-any.whl → 8.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/RECORD +37 -47
  3. msprobe/README.md +8 -5
  4. msprobe/core/common/const.py +17 -3
  5. msprobe/core/common/file_utils.py +64 -13
  6. msprobe/core/common/framework_adapter.py +10 -1
  7. msprobe/core/common/utils.py +17 -0
  8. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
  9. msprobe/core/hook_manager.py +2 -16
  10. msprobe/core/service.py +5 -16
  11. msprobe/docs/01.installation.md +2 -0
  12. msprobe/docs/02.config_introduction.md +0 -13
  13. msprobe/docs/14.data_parse_PyTorch.md +2 -0
  14. msprobe/docs/21.visualization_PyTorch.md +1 -1
  15. msprobe/docs/25.tool_function_introduction.md +0 -1
  16. msprobe/docs/32.ckpt_compare.md +5 -5
  17. msprobe/mindspore/monitor/module_hook.py +17 -20
  18. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  19. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  20. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  21. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  22. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
  23. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  24. msprobe/pytorch/common/utils.py +0 -70
  25. msprobe/pytorch/debugger/debugger_config.py +0 -10
  26. msprobe/pytorch/dump/module_dump/module_processer.py +18 -3
  27. msprobe/pytorch/hook_module/api_register.py +5 -1
  28. msprobe/pytorch/monitor/module_hook.py +16 -34
  29. msprobe/pytorch/pt_config.py +2 -51
  30. msprobe/pytorch/pytorch_service.py +2 -11
  31. msprobe/visualization/builder/graph_builder.py +2 -2
  32. msprobe/visualization/builder/graph_merger.py +13 -0
  33. msprobe/visualization/graph/graph.py +13 -9
  34. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  35. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  36. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  37. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  38. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  39. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  40. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  41. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  42. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  43. msprobe/pytorch/attl_manager.py +0 -65
  44. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/LICENSE +0 -0
  45. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/WHEEL +0 -0
  46. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/entry_points.txt +0 -0
  47. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/top_level.txt +0 -0
@@ -51,8 +51,6 @@ from msprobe.pytorch.pt_config import parse_json_config
51
51
  from msprobe.core.common.const import Const, FileCheckConst, CompareConst
52
52
  from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
53
53
  from msprobe.pytorch.common.utils import seed_all
54
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
55
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
56
54
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params, \
57
55
  ExecParams
58
56
 
@@ -90,27 +88,22 @@ seed_all()
90
88
 
91
89
  def run_ut(config):
92
90
  logger.info("start UT test")
93
- if config.online_config.is_online:
94
- logger.info(f"UT task result will be saved in {config.result_csv_path}".replace(".csv", "_rank*.csv"))
95
- logger.info(f"UT task details will be saved in {config.details_csv_path}".replace(".csv", "_rank*.csv"))
96
- else:
97
- logger.info(f"UT task result will be saved in {config.result_csv_path}")
98
- logger.info(f"UT task details will be saved in {config.details_csv_path}")
91
+
92
+ logger.info(f"UT task result will be saved in {config.result_csv_path}")
93
+ logger.info(f"UT task details will be saved in {config.details_csv_path}")
99
94
 
100
95
  if config.save_error_data:
101
96
  logger.info(f"UT task error_data will be saved in {config.error_data_path}")
102
97
  compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
103
98
 
104
- if config.online_config.is_online:
105
- run_api_online(config, compare)
106
- else:
107
- csv_df = read_csv(config.result_csv_path)
108
- try:
109
- api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
110
- except IndexError:
111
- logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
112
- api_name_set = set()
113
- run_api_offline(config, compare, api_name_set)
99
+
100
+ csv_df = read_csv(config.result_csv_path)
101
+ try:
102
+ api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
103
+ except IndexError:
104
+ logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
105
+ api_name_set = set()
106
+ run_api_offline(config, compare, api_name_set)
114
107
  for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
115
108
  change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
116
109
  change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -164,60 +157,6 @@ def run_api_offline(config, compare, api_name_set):
164
157
  gc.collect()
165
158
 
166
159
 
167
- def run_api_online(config, compare):
168
- attl = init_attl(config.online_config)
169
- dispatcher = ConsumerDispatcher(compare=compare)
170
- dispatcher.start(handle_func=run_torch_api_online, config=config)
171
-
172
- def tcp_communication_flow():
173
- while True:
174
- api_data = attl.recv()
175
- if api_data == 'STOP_':
176
- continue
177
- if api_data == 'KILL_':
178
- time.sleep(1)
179
- logger.info("==========接收到STOP信号==========")
180
- dispatcher.stop()
181
- attl.stop_serve()
182
- time.sleep(1)
183
- break
184
- if not isinstance(api_data, ApiData):
185
- continue
186
- api_full_name = api_data.name
187
- _, api_name = extract_basic_api_segments(api_full_name)
188
- if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
189
- continue
190
- if api_data.rank in config.online_config.rank_list:
191
- dispatcher.update_consume_queue(api_data)
192
-
193
- def shared_storage_communication_flow():
194
- flag_num = -1
195
- while True:
196
- api_data = attl.download()
197
- if api_data == "start":
198
- if flag_num == -1:
199
- flag_num += 1
200
- flag_num += 1
201
- if api_data == "end":
202
- flag_num -= 1
203
- if flag_num == 0:
204
- dispatcher.stop()
205
- break
206
- if not isinstance(api_data, ApiData):
207
- continue
208
- api_full_name = api_data.name
209
- _, api_name = extract_basic_api_segments(api_full_name)
210
- if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
211
- continue
212
- if api_data.rank in config.online_config.rank_list:
213
- dispatcher.update_consume_queue(api_data)
214
-
215
- if config.online_config.nfs_path:
216
- shared_storage_communication_flow()
217
- else:
218
- tcp_communication_flow()
219
-
220
-
221
160
  def blacklist_and_whitelist_filter(api_name, black_list, white_list):
222
161
  """
223
162
  run api(api_name) if api_name not in black_list and in white_list.
@@ -315,21 +254,6 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
315
254
  return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
316
255
 
317
256
 
318
- def run_torch_api_online(api_full_name, api_data, backward_content):
319
- in_fwd_data_list = []
320
- api_type, api_name = extract_basic_api_segments(api_full_name)
321
- args, kwargs, out = api_data.args, api_data.kwargs, api_data.result
322
- in_fwd_data_list.append(args)
323
- in_fwd_data_list.append(kwargs)
324
- if kwargs.get("device"):
325
- del kwargs["device"]
326
-
327
- device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None)
328
- device_out = exec_api(device_exec_params)
329
- device_out = move2device_exec(device_out, "cpu")
330
- return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
331
-
332
-
333
257
  def check_need_grad(api_info_dict):
334
258
  need_grad = True
335
259
  if api_info_dict.get(Const.INPUT_KWARGS) and "out" in api_info_dict.get(Const.INPUT_KWARGS):
@@ -389,16 +313,6 @@ def initialize_save_error_data(error_data_path):
389
313
  return error_data_path
390
314
 
391
315
 
392
- def init_attl(config):
393
- """config: OnlineConfig"""
394
- attl = ATTL('gpu', ATTLConfig(is_benchmark_device=True,
395
- connect_ip=config.host,
396
- connect_port=config.port,
397
- nfs_path=config.nfs_path,
398
- tls_path=config.tls_path))
399
- return attl
400
-
401
-
402
316
  def _run_ut_parser(parser):
403
317
  parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
404
318
  help="<Optional> The api param tool result file: generate from api param tool, "
@@ -481,38 +395,6 @@ def _run_ut(parser=None):
481
395
  _run_ut_parser(parser)
482
396
  args = parser.parse_args(sys.argv[1:])
483
397
  run_ut_command(args)
484
-
485
-
486
- def checked_online_config(online_config):
487
- if not online_config.is_online:
488
- return
489
- if not isinstance(online_config.is_online, bool):
490
- raise ValueError("is_online must be bool type")
491
- # rank_list
492
- if not isinstance(online_config.rank_list, list):
493
- raise ValueError("rank_list must be a list")
494
- if online_config.rank_list and not all(isinstance(rank, int) for rank in online_config.rank_list):
495
- raise ValueError("All elements in rank_list must be integers")
496
-
497
- # nfs_path
498
- if online_config.nfs_path:
499
- check_file_or_directory_path(online_config.nfs_path, isdir=True)
500
- return
501
- # tls_path
502
- if online_config.tls_path:
503
- check_file_or_directory_path(online_config.tls_path, isdir=True)
504
- check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
505
- check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
506
- check_file_or_directory_path(os.path.join(online_config.tls_path, "ca.crt"))
507
- crl_path = os.path.join(online_config.tls_path, "crl.pem")
508
- if os.path.exists(crl_path):
509
- check_file_or_directory_path(crl_path)
510
-
511
- # host and port
512
- if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
513
- raise Exception(f"host: {online_config.host} is invalid.")
514
- if not isinstance(online_config.port, int) or not (0 < online_config.port <= 65535):
515
- raise Exception(f"port: {online_config.port} is invalid, port range 0-65535.")
516
398
 
517
399
 
518
400
  def run_ut_command(args):
@@ -525,7 +407,7 @@ def run_ut_command(args):
525
407
  else:
526
408
  checker_config = CheckerConfig()
527
409
 
528
- if not checker_config.is_online and not args.api_info_file:
410
+ if not args.api_info_file:
529
411
  logger.error("Please provide api_info_file for offline run ut.")
530
412
  raise Exception("Please provide api_info_file for offline run ut.")
531
413
 
@@ -588,8 +470,6 @@ def run_ut_command(args):
588
470
  global UT_ERROR_DATA_DIR
589
471
  UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
590
472
  error_data_path = initialize_save_error_data(error_data_path)
591
- online_config = checker_config.get_online_config()
592
- checked_online_config(online_config)
593
473
  config_params = {
594
474
  'forward_content': forward_content,
595
475
  'backward_content': backward_content,
@@ -337,56 +337,6 @@ def save_pt(tensor, filepath):
337
337
  change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
338
338
 
339
339
 
340
- class TypeCheckingUnpickler(pickle.Unpickler):
341
- """
342
- This class is a subclass of pickle.Unpickler, which is used to unpickle pickled objects.
343
- It overrides the find_class method to add type checking functionality.
344
- """
345
- allowed_types = [
346
- "str",
347
- "ApiData",
348
- "OrderedDict",
349
- "_rebuild_tensor_v2", # from torch.utils
350
- "_load_from_bytes" # from torch.storage
351
- ]
352
-
353
- def find_class(self, module, name):
354
- """
355
- Method to find the class of the object to be unpickled.
356
- Throws pickle.UnpicklingError If the object type is not in the allowed types list.
357
- """
358
- if name in self.allowed_types:
359
- return super().find_class(module, name)
360
- raise pickle.UnpicklingError("Unsupported object type: {}.{}".format(module, name))
361
-
362
-
363
- def save_pkl(tensor, filepath):
364
- """Save ApiData or str objection by pickle"""
365
- check_path_before_create(filepath)
366
- filepath = os.path.realpath(filepath)
367
- try:
368
- with FileOpen(filepath, 'wb') as f:
369
- pickle.dump(tensor, f)
370
- except Exception as e:
371
- logger.error("Save pt file failed, please check according possible error causes: "
372
- "1. out of disk space or disk error, "
373
- "2. no permission to write files, etc.")
374
- raise RuntimeError(f"save pt file {filepath} failed") from e
375
- change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
376
-
377
-
378
- def load_pkl(pt_path):
379
- """Load ApiData or str objection by pickle for accuracy_checker_online"""
380
- check_file_or_directory_path(pt_path)
381
- pt_path = os.path.realpath(pt_path)
382
- try:
383
- with FileOpen(pt_path, 'rb') as f:
384
- pt = TypeCheckingUnpickler(f).load()
385
- except Exception as e:
386
- raise RuntimeError(f"load pt file {pt_path} failed: {e}") from e
387
- return pt
388
-
389
-
390
340
  def is_recomputation():
391
341
  """Check if the current operation is in the re-computation phase.
392
342
 
@@ -471,23 +421,3 @@ def register_forward_hook(module, forward_hook):
471
421
  module.register_forward_hook(forward_hook, with_kwargs=True)
472
422
  else:
473
423
  module.register_forward_hook(forward_hook)
474
-
475
-
476
- def save_api_data(api_data):
477
- """Save data to io stream"""
478
- try:
479
- io_buff = io.BytesIO()
480
- torch.save(api_data, io_buff)
481
- except Exception as e:
482
- raise RuntimeError(f"save api_data to io_buff failed") from e
483
- return io_buff
484
-
485
-
486
- def load_api_data(api_data_bytes):
487
- """Load data from bytes stream"""
488
- try:
489
- buffer = io.BytesIO(api_data_bytes)
490
- buffer = torch.load(buffer, map_location="cpu")
491
- except Exception as e:
492
- raise RuntimeError(f"load api_data from bytes failed") from e
493
- return buffer
@@ -48,16 +48,6 @@ class DebuggerConfig:
48
48
  "max_sample": task_config.max_sample
49
49
  }
50
50
 
51
- self.online_run_ut = False
52
- if self.task == Const.TENSOR:
53
- # dump api tensor and collaborate with online run_ut
54
- self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
55
- self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
56
- self.tls_path = task_config.tls_path if task_config.tls_path else ""
57
- self.host = task_config.host if task_config.host else ""
58
- self.port = task_config.port if task_config.port else -1
59
- self.online_run_ut_recompute = task_config.online_run_ut_recompute \
60
- if isinstance(task_config.online_run_ut_recompute, bool) else False
61
51
 
62
52
  self.check()
63
53
  self._check_statistics_config(task_config)
@@ -63,9 +63,11 @@ def wrap_forward_with_hook_safety(module):
63
63
  except _StopRecomputationError as e:
64
64
  exception_output = None
65
65
  if len(module._forward_hooks.values()) > 0:
66
- # msprobe的forward_hook会出现在第一个,仅执行msprobe的forward_hook
67
- hook_fn = list(module._forward_hooks.values())[0]
68
- hook_fn(module, args, kwargs, exception_output)
66
+ # 仅执行msprobe的forward_hook, hook名称必然包含'ModuleProcesser.'
67
+ for hook_fn in module._forward_hooks.values():
68
+ if 'ModuleProcesser' in str(hook_fn):
69
+ hook_fn(module, args, kwargs, exception_output)
70
+ break
69
71
  raise e
70
72
 
71
73
  if torch_version_above_or_equal_21:
@@ -152,7 +154,13 @@ class ModuleProcesser:
152
154
  modules_and_names_with_index = self.get_modules_and_names(models, recursive, module_names)
153
155
  for index, modules_and_names in modules_and_names_with_index.items():
154
156
  model = models if index == "-1" else models[int(index)]
157
+
158
+ model_list = []
155
159
  for name, module in modules_and_names:
160
+ model_list.append((name, module))
161
+
162
+ is_verl = "verl" in sys.modules
163
+ for idx, (name, module) in enumerate(model_list):
156
164
  if recursive and module == model:
157
165
  continue
158
166
  if not is_torch_nn_module(module):
@@ -163,6 +171,13 @@ class ModuleProcesser:
163
171
  continue
164
172
  if module.__class__.__name__ == "FullyShardedDataParallel":
165
173
  continue
174
+
175
+ # verl 场景下跳过第一层和最后一层
176
+ if is_verl and (idx == 1 or idx == len(model_list) - 1):
177
+ logger.warning(f"The module {name} is the first or last layer in verl scenario, "
178
+ f"the data dump for this module will be skipped.")
179
+ continue
180
+
166
181
  setattr(module, 'msprobe_hook', True)
167
182
  module_index = (index + Const.SEP) if index != "-1" else ""
168
183
  prefix_name = f'{BaseScope.Module_Type_Module}{Const.SEP}{module_index}{name}{Const.SEP}' + \
@@ -135,13 +135,17 @@ def redirect_wait():
135
135
  store_func = dist_data_collect_func.pop(args[0])
136
136
  store_func()
137
137
  return
138
+ remove_value = None
138
139
  for value in dist_batch_data_collect_func:
139
140
  if args[0] in value[0]:
140
141
  value[0].remove(args[0])
141
142
  if len(value[0]) == 0:
142
143
  store_func = value[1]
143
144
  store_func()
144
- return
145
+ remove_value = value
146
+ break
147
+ if remove_value:
148
+ dist_batch_data_collect_func.remove(remove_value)
145
149
 
146
150
  return wrapped_wait
147
151
 
@@ -48,12 +48,10 @@ from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_write
48
48
  from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
49
49
  from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
50
50
 
51
-
52
51
  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
53
52
  if not torch_version_above_or_equal_2:
54
53
  raise ValueError("monitor require torch>=2.0")
55
54
 
56
-
57
55
  FORMAT_MAPPING = {
58
56
  MonitorConst.TENSORBOARD: SummaryWriterWithAD,
59
57
  MonitorConst.CSV: CSVWriterWithAD,
@@ -150,15 +148,11 @@ class GradContext:
150
148
  def __init__(self) -> None:
151
149
  self.pre = {}
152
150
  self.post = {}
153
- self.acc_metric = {}
154
- self.acc = {}
155
151
  self.actv = {}
156
152
 
157
153
  def reset(self):
158
154
  self.pre.clear()
159
155
  self.post.clear()
160
- self.acc_metric.clear()
161
- self.acc.clear()
162
156
  self.actv.clear()
163
157
 
164
158
 
@@ -510,18 +504,8 @@ class TrainerMon:
510
504
  if not self.wg_distribution:
511
505
  return {}, {}
512
506
 
513
- if self.weight_hooked:
514
- get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
515
-
516
507
  get_metrics(self.ops, post_grad_dict, self.eps, self.grad_context.post)
517
- reduced_grad = self.grad_context.post
518
-
519
- if self.weight_hooked:
520
- unreduced_grad = self.grad_context.acc_metric
521
- else:
522
- unreduced_grad = self.grad_context.pre
523
-
524
- return reduced_grad, unreduced_grad
508
+ return self.grad_context.post, self.grad_context.pre
525
509
 
526
510
  def generate_xy_metrics(self):
527
511
  actv = {}
@@ -529,7 +513,6 @@ class TrainerMon:
529
513
  actv.update(fwd_context.actv)
530
514
 
531
515
  actv_grad = self.grad_context.actv
532
-
533
516
  return actv, actv_grad
534
517
 
535
518
  def reload_xy(self, xy_distribution=False):
@@ -607,11 +590,8 @@ class TrainerMon:
607
590
  if not self.wg_distribution:
608
591
  return
609
592
 
610
- if self.weight_hooked:
611
- self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced',
612
- use_micro_step=self.monitor_mbs_grad)
613
- else:
614
- self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
593
+ self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced',
594
+ use_micro_step=self.monitor_mbs_grad)
615
595
  self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
616
596
 
617
597
  def hook_optimizer(self, optimizer):
@@ -732,9 +712,9 @@ class TrainerMon:
732
712
  # 静态在第0步就可以保存, 动态在第0步不可以, 因为动态设计的就是重置后下一步开启, 第0步的self.monitoring还是False
733
713
  if self.monitoring:
734
714
  module_rank_valid = not self.module_rank_list or (
735
- dist.is_initialized() and dist.get_rank() in self.module_rank_list)
715
+ dist.is_initialized() and dist.get_rank() in self.module_rank_list)
736
716
  step_condition = (context.step >= self.start_step and (
737
- context.step - self.start_step) % self.step_interval == 0)
717
+ context.step - self.start_step) % self.step_interval == 0)
738
718
  if module_rank_valid and step_condition:
739
719
  self.has_collect_times += 1
740
720
 
@@ -791,6 +771,7 @@ class TrainerMon:
791
771
  hook(optimizer, args, kwargs)
792
772
  step_final_hook(optimizer, args, kwargs)
793
773
  return out
774
+
794
775
  return wrapper
795
776
 
796
777
  optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
@@ -1013,11 +994,11 @@ class TrainerMon:
1013
994
  vpp_stage + module_name,
1014
995
  ]:
1015
996
  if pattern in l2_targets:
1016
- return pattern
997
+ return pattern
1017
998
  elif hook_name in ["linear_hook"]:
1018
999
  return vpp_stage + squash_param_name(module_name, self.squash_name)
1019
1000
  return ""
1020
-
1001
+
1021
1002
  def _hook_module(self, target_names, l2_target_names, module: torch.nn.Module, vpp_stage=''):
1022
1003
  if '_modules' not in module.__dict__:
1023
1004
  # nothing to hook
@@ -1151,7 +1132,7 @@ class TrainerMon:
1151
1132
  context.micro_step = 0
1152
1133
  context.step += 1
1153
1134
  return
1154
-
1135
+
1155
1136
  def stack_hook(module, args, kwargs, module_output, name):
1156
1137
  if module not in self.module_fwd_hook_context_by_module:
1157
1138
  self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
@@ -1221,7 +1202,7 @@ class TrainerMon:
1221
1202
  if self.monitor_mbs_grad:
1222
1203
  self._hook_weights()
1223
1204
  return
1224
-
1205
+
1225
1206
  self.optimizer_mon.patch_grad_sync(self)
1226
1207
 
1227
1208
  if self.enable_megatron or self.enable_deepspeed:
@@ -1281,6 +1262,7 @@ class TrainerMon:
1281
1262
  get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
1282
1263
  out = foreach_reduce(fsdp_params, unsharded_grads, *unused)
1283
1264
  return out
1265
+
1284
1266
  return wrapper
1285
1267
 
1286
1268
  logger.info("Patch fsdp2 foreach_reduce, collect pre_grad metrics.")
@@ -1294,10 +1276,9 @@ class TrainerMon:
1294
1276
  """
1295
1277
  遍历参数的梯度生成函数(grad_acc),并挂载hook,以便在该参数所有梯度计算后,采集通信聚合前梯度数据。
1296
1278
  """
1297
- context = self.grad_context
1298
1279
 
1299
1280
  @torch.no_grad
1300
- def param_hook(*args, context_dict, param, name):
1281
+ def param_hook(*args, param, name):
1301
1282
  key = name
1302
1283
  if self.monitor_mbs_grad:
1303
1284
  key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
@@ -1305,14 +1286,15 @@ class TrainerMon:
1305
1286
  key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
1306
1287
  self.register_param_call_id("param_hook", key)
1307
1288
  param.micro_step += 1
1308
-
1289
+ grad_dict = {}
1309
1290
  if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
1310
1291
  if self.params_have_main_grad:
1311
1292
  grad = param.main_grad
1312
1293
  else:
1313
1294
  grad = param.grad
1314
- context_dict[key] = grad.clone()
1295
+ grad_dict[key] = grad.clone()
1315
1296
 
1297
+ get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
1316
1298
  if param.micro_step == self.micro_batch_number:
1317
1299
  param.micro_step = 0
1318
1300
 
@@ -1322,7 +1304,7 @@ class TrainerMon:
1322
1304
  param_tmp = param.expand_as(param)
1323
1305
  grad_acc = param_tmp.grad_fn.next_functions[0][0]
1324
1306
  handle = grad_acc.register_hook(
1325
- partial(param_hook, context_dict=context.acc, param=param, name=name))
1307
+ partial(param_hook, param=param, name=name))
1326
1308
  self.grad_accs.append(grad_acc)
1327
1309
  self.handles['wgrads'].append(handle)
1328
1310
 
@@ -35,48 +35,15 @@ from msprobe.pytorch.hook_module.utils import get_ops
35
35
  class TensorConfig(BaseConfig):
36
36
  def __init__(self, json_config):
37
37
  super().__init__(json_config)
38
- self.online_run_ut = json_config.get("online_run_ut", False)
39
- self.nfs_path = json_config.get("nfs_path", "")
40
- self.host = json_config.get("host", "")
41
- self.port = json_config.get("port", -1)
42
- self.tls_path = json_config.get("tls_path", "./")
43
- self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
44
38
  self.check_config()
45
39
  self._check_summary_mode()
46
40
  self._check_file_format()
47
- if self.online_run_ut:
48
- self._check_online_run_ut()
41
+
49
42
 
50
43
  def _check_file_format(self):
51
44
  if self.file_format is not None and self.file_format not in ["npy", "bin"]:
52
45
  raise Exception("file_format is invalid")
53
46
 
54
- def _check_online_run_ut(self):
55
- if not isinstance(self.online_run_ut, bool):
56
- raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.")
57
-
58
- if not isinstance(self.online_run_ut_recompute, bool):
59
- raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.")
60
-
61
- if self.nfs_path:
62
- check_file_or_directory_path(self.nfs_path, isdir=True)
63
- return
64
-
65
- if self.tls_path:
66
- check_file_or_directory_path(self.tls_path, isdir=True)
67
- check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
68
- check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
69
- check_file_or_directory_path(os.path.join(self.tls_path, "ca.crt"))
70
- crl_path = os.path.join(self.tls_path, "crl.pem")
71
- if os.path.exists(crl_path):
72
- check_file_or_directory_path(crl_path)
73
-
74
- if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
75
- raise Exception(f"host: {self.host} is invalid.")
76
-
77
- if not isinstance(self.port, int) or not (0 < self.port <= 65535):
78
- raise Exception(f"port: {self.port} is invalid, port range 0-65535.")
79
-
80
47
 
81
48
  class StatisticsConfig(BaseConfig):
82
49
  def __init__(self, json_config):
@@ -257,12 +224,7 @@ class RunUTConfig(BaseConfig):
257
224
  self.white_list = json_config.get("white_list", Const.DEFAULT_LIST)
258
225
  self.black_list = json_config.get("black_list", Const.DEFAULT_LIST)
259
226
  self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH)
260
- self.is_online = json_config.get("is_online", False)
261
- self.nfs_path = json_config.get("nfs_path", "")
262
- self.host = json_config.get("host", "")
263
- self.port = json_config.get("port", -1)
264
- self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
265
- self.tls_path = json_config.get("tls_path", "./")
227
+
266
228
  self.check_run_ut_config()
267
229
 
268
230
  @classmethod
@@ -280,22 +242,11 @@ class RunUTConfig(BaseConfig):
280
242
  if not os.path.exists(error_data_path):
281
243
  raise Exception("error_data_path: %s does not exist" % error_data_path)
282
244
 
283
- @classmethod
284
- def check_nfs_path_config(cls, nfs_path):
285
- if nfs_path:
286
- FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
287
-
288
- @classmethod
289
- def check_tls_path_config(cls, tls_path):
290
- if tls_path:
291
- FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
292
245
 
293
246
  def check_run_ut_config(self):
294
247
  RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
295
248
  RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list)
296
249
  RunUTConfig.check_error_data_path_config(self.error_data_path)
297
- RunUTConfig.check_nfs_path_config(self.nfs_path)
298
- RunUTConfig.check_tls_path_config(self.tls_path)
299
250
 
300
251
 
301
252
  class GradToolConfig(BaseConfig):
@@ -15,9 +15,8 @@
15
15
 
16
16
  from msprobe.core.common.utils import Const
17
17
  from msprobe.core.service import BaseService
18
- from msprobe.pytorch.attl_manager import ATTLManager
19
18
  from msprobe.pytorch.common.log import logger
20
- from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_above_or_equal_2
19
+ from msprobe.pytorch.common.utils import get_rank_if_initialized
21
20
  from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
22
21
  from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate, redirect_wait
23
22
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
@@ -25,9 +24,6 @@ from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager
25
24
  from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
26
25
  from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, preprocess_func
27
26
 
28
- if torch_version_above_or_equal_2:
29
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
30
-
31
27
 
32
28
  class PytorchService(BaseService):
33
29
  @property
@@ -45,12 +41,10 @@ class PytorchService(BaseService):
45
41
  self.logger = logger
46
42
  self.api_register = get_api_register()
47
43
  self.module_processor = ModuleProcesser(self.data_collector.scope)
48
- self.attl_manager = ATTLManager(self.config)
49
- self.hook_manager = PytorchHookManager(self.data_collector, self.config, self.attl_manager)
44
+ self.hook_manager = PytorchHookManager(self.data_collector, self.config)
50
45
  self.api_template = ApiTemplate
51
46
 
52
47
  def _register_hook(self):
53
- self.attl_manager.attl_init()
54
48
  if self._is_mix_level:
55
49
  register_optimizer_hook(self.data_collector)
56
50
 
@@ -65,9 +59,6 @@ class PytorchService(BaseService):
65
59
  self.module_processor.register_module_hook(self.model, self.build_hook)
66
60
  self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.")
67
61
 
68
- def _run_ut_dispatch(self, status):
69
- if torch_version_above_or_equal_2:
70
- run_ut_dispatch(self.attl_manager.attl, status, self.config.online_run_ut_recompute)
71
62
 
72
63
  def _reset_status(self):
73
64
  super()._reset_status()
@@ -298,8 +298,8 @@ class GraphBuilder:
298
298
  no_recompute_map = GraphBuilder._get_no_recompute_map(graph, id_prefixes)
299
299
  if not no_recompute_map:
300
300
  return
301
- # 深拷贝非重计算节点字典用于反向模式
302
- no_recompute_ids_b = copy.deepcopy(no_recompute_map)
301
+ # 拷贝非重计算节点字典用于反向模式
302
+ no_recompute_ids_b = {node_id: list(node_list) for node_id, node_list in no_recompute_map.items()}
303
303
 
304
304
  del_indexes = []
305
305
  for node_id, id_prefix in recompute_map.items():