mindstudio-probe 8.2.1__py3-none-any.whl → 8.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/RECORD +39 -40
  3. msprobe/README.md +7 -2
  4. msprobe/core/common/const.py +17 -3
  5. msprobe/core/common/file_utils.py +138 -32
  6. msprobe/core/common/framework_adapter.py +16 -6
  7. msprobe/core/common/utils.py +17 -0
  8. msprobe/core/compare/diff_analyze/first_diff_analyze.py +4 -16
  9. msprobe/core/compare/find_first/utils.py +1 -1
  10. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
  11. msprobe/core/hook_manager.py +0 -1
  12. msprobe/docs/01.installation.md +2 -0
  13. msprobe/docs/02.config_introduction.md +1 -1
  14. msprobe/docs/14.data_parse_PyTorch.md +2 -0
  15. msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
  16. msprobe/docs/21.visualization_PyTorch.md +1 -1
  17. msprobe/docs/26.data_dump_PyTorch_baseline.md +3 -3
  18. msprobe/docs/32.ckpt_compare.md +5 -5
  19. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  20. msprobe/mindspore/compare/utils.py +1 -2
  21. msprobe/mindspore/monitor/module_hook.py +17 -20
  22. msprobe/msprobe.py +6 -4
  23. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  24. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
  25. msprobe/pytorch/common/utils.py +2 -52
  26. msprobe/pytorch/compare/utils.py +1 -2
  27. msprobe/pytorch/dump/module_dump/hook_wrapper.py +24 -0
  28. msprobe/pytorch/dump/module_dump/module_processer.py +27 -6
  29. msprobe/pytorch/hook_module/api_register.py +11 -2
  30. msprobe/pytorch/monitor/module_hook.py +16 -34
  31. msprobe/pytorch/pt_config.py +6 -0
  32. msprobe/visualization/builder/graph_builder.py +3 -2
  33. msprobe/visualization/builder/graph_merger.py +13 -0
  34. msprobe/visualization/graph/graph.py +13 -9
  35. msprobe/visualization/utils.py +11 -1
  36. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +0 -3
  37. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/LICENSE +0 -0
  38. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/WHEEL +0 -0
  39. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/entry_points.txt +0 -0
  40. {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/top_level.txt +0 -0
@@ -708,3 +708,20 @@ def check_process_num(process_num):
708
708
  raise ValueError(f"process_num({process_num}) is not a positive integer")
709
709
  if process_num > Const.MAX_PROCESS_NUM:
710
710
  raise ValueError(f"The maximum supported process_num is {Const.MAX_PROCESS_NUM}, current value: {process_num}.")
711
+
712
+
713
+ def confirm(prompt, default=False):
714
+ if default is True:
715
+ prompt_suffix = " [Y/n] "
716
+ elif default is False:
717
+ prompt_suffix = " [y/N] "
718
+ else:
719
+ prompt_suffix = " [y/n] "
720
+
721
+ full_prompt = prompt + prompt_suffix
722
+
723
+ user_input = input(full_prompt).strip().lower()
724
+ if user_input in ['y', 'yes']:
725
+ return True
726
+ else:
727
+ return default
@@ -26,8 +26,6 @@ from msprobe.core.compare.utils import gen_api_batches
26
26
 
27
27
  cur_dir = os.path.dirname(os.path.realpath(__file__))
28
28
  diff_threshold_yaml_path = os.path.join(cur_dir, 'diff_analyze_threshold.yaml')
29
- ignore_op_list_yaml_path = os.path.join(cur_dir, 'ignore_op_list.yaml')
30
- ignore_list = load_yaml(ignore_op_list_yaml_path)
31
29
  thresholds = load_yaml(diff_threshold_yaml_path)
32
30
  cmp_metrics = thresholds.get('compare_metrics')
33
31
 
@@ -53,7 +51,7 @@ class FirstDiffAnalyze:
53
51
  return True
54
52
  return False
55
53
 
56
- def single_api_check(self, result_slice, header, api_name=None):
54
+ def single_api_check(self, result_slice, header):
57
55
  """
58
56
  单个api差异检查
59
57
 
@@ -67,18 +65,14 @@ class FirstDiffAnalyze:
67
65
  }
68
66
 
69
67
  column_indices = {name: idx for idx, name in enumerate(header)}
70
- output_idx = -1
68
+
71
69
  for line in result_slice:
72
70
  op_item = {
73
71
  column_name: line[column_indices[column_name]]
74
72
  for column_name in header
75
73
  }
76
74
  single_check_result['op_items'].append(op_item)
77
- if op_item['state'] != 'output':
78
- continue
79
- output_idx += 1
80
- if output_idx in ignore_list.get(api_name, []):
81
- continue
75
+
82
76
  # set is_same
83
77
  if self.mode_config.dump_mode == Const.MD5:
84
78
  if line[column_indices[CompareConst.RESULT]] == CompareConst.DIFF:
@@ -123,13 +117,7 @@ class FirstDiffAnalyze:
123
117
  with tqdm(total=len(api_batches), desc=bar_desc_add_rank, unit="api/module", ncols=100) as progress_bar:
124
118
  for api_batch in api_batches:
125
119
  result_slice = result[api_batch.start: api_batch.params_grad_end_index]
126
- api_compo = api_batch.api_name.split('.')
127
- # suppose name is Tensor.MatMul.0.forward
128
- if len(api_compo) < 4:
129
- continue
130
- # get MatMul as api_name
131
- api_name = api_compo[-3]
132
- check_result[api_batch.api_name] = self.single_api_check(result_slice, header, api_name)
120
+ check_result[api_batch.api_name] = self.single_api_check(result_slice, header)
133
121
  progress_bar.update(1)
134
122
 
135
123
  return check_result
@@ -182,7 +182,7 @@ def analyze_diff_in_group(nodes_group):
182
182
  input_diff_nodes = list(filter(lambda node: node.is_diff, src_list))
183
183
  # 如果有异常回溯计算节点找到异常来源
184
184
  # 使用cpu模拟节点进行计算,查看结果是否有问题。需要对所有计算节点录入/映射,暂不实现。
185
- get_compute_ops_from_comm_nodes(nodes_group)
185
+ get_compute_ops_from_comm_nodes(input_diff_nodes)
186
186
  # 筛选入参没问题但出参有问题的通信节点
187
187
  output_diff_nodes = list(filter(lambda node: node.data.is_diff, nodes_group))
188
188
  get_comm_ops(output_diff_nodes)
@@ -19,11 +19,11 @@ from tqdm import tqdm
19
19
  from msprobe.core.common.file_utils import save_json, check_path_before_create, check_path_not_exists, \
20
20
  check_file_or_directory_path
21
21
  from msprobe.core.common.log import logger
22
+ from msprobe.core.common.utils import confirm
22
23
  from msprobe.core.config_check.ckpt_compare.megatron_loader import load_megatron_weights
23
24
  from msprobe.core.config_check.ckpt_compare.metrics import METRIC_FUNC
24
25
 
25
26
 
26
-
27
27
  def compare_checkpoints(ckpt_path1, ckpt_path2, output_path) -> Dict:
28
28
  """Compare weights between two checkpoints using cosine similarity and L2 distance.
29
29
 
@@ -45,6 +45,11 @@ def compare_checkpoints(ckpt_path1, ckpt_path2, output_path) -> Dict:
45
45
  """
46
46
 
47
47
  # Load both checkpoints
48
+ if not confirm("You are using torch.load with weights_only is False, it may cause arbitrary code "
49
+ "execution. Do it only if you get the file from a trusted source. Input yes to continue, "
50
+ "otherwise exit", False):
51
+ logger.error("Insecure risks found and exit!")
52
+ raise Exception("Insecure risks found and exit!")
48
53
  check_file_or_directory_path(ckpt_path1, isdir=True)
49
54
  check_file_or_directory_path(ckpt_path2, isdir=True)
50
55
  check_path_before_create(output_path)
@@ -63,7 +63,6 @@ class BaseHookManager(ABC):
63
63
  def reset_status():
64
64
  BaseHookManager.inner_switch = defaultdict(bool)
65
65
  BaseHookManager.inner_api_count = defaultdict(int)
66
- BaseHookManager.hook_handle_dict.clear()
67
66
  BaseHookManager.params_grad_info.clear()
68
67
 
69
68
  @staticmethod
@@ -16,6 +16,8 @@ pip install mindstudio-probe
16
16
 
17
17
  | 版本 | 发布日期 |支持 PyTorch 版本|支持 MindSpore 版本| 下载链接 |校验码|
18
18
  |:-----:|:----------:|:--:|:--:|:----------------------------------------------------------------------------------------------------------------------------------:|:--:|
19
+ | 8.3.0 | 2025.10.30 |1.11/2.0/2.1/2.2/2.5/2.6/2.7|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.3.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.3/mindstudio_probe-8.3.0-py3-none-any.whl) |e933657b8ceb20774f924865d1f47978bd49cd1d9cec5fb60dec4f18802afbcc|
20
+ | 8.2.1 | 2025.9.29 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.2.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.2/mindstudio_probe-8.2.1-py3-none-any.whl) |2152fadefec3d70148a910a54f2cfb6fccfa60b5ccd59f835f20289a7a5f4764|
19
21
  | 8.2.0 | 2025.9.03 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.2.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.2/mindstudio_probe-8.2.0-py3-none-any.whl) |bbc1577d76754adf987069308177d3e0a04e36de9c7f22e75c34cf4ad0ce1af2|
20
22
  | 8.1.2 | 2025.8.01 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.1.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.1/mindstudio_probe-8.1.2-py3-none-any.whl) |ff07bb81fddd3b8f3096d119ca1481bde8fdb24f10644def5250caad727448ab|
21
23
  | 8.1.1 | 2025.6.20 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.1.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.1/mindstudio_probe-8.1.1-py3-none-any.whl) |2aad10a243575544d7feef552caf4d06aa93028488ebd0bbc9aa350379da859d|
@@ -145,7 +145,7 @@ PyTorch、MSAdapter 以及 MindSpore 动态图场景下,"level"须为"L0"或"L
145
145
  <tr><td>pert_mode</td><td>无标杆扰动因子,str 类型。可选参数:<br/> "improve_precision":对输入做升精度,默认值;<br/> "add_noise":对输入增加噪声;<br/> "no_change":不加扰动直接二次执行;<br/> "bit_noise":输入的末位比特翻转,MindSpore 场景不支持 BF16 类型的向量;<br/> "change_value":输入的张量首尾值调换;<br/> "to_cpu":在 CPU 等价执行(仅 PyTorch 场景支持)。<br/><b>配置示例</b>:"pert_mode": "improve_precision"。</td><td>否</td></tr>
146
146
  <tr><td>handler_type</td><td>处理类型,可选参数:<br/> "check":进行无标杆比对检查,默认值;<br/> "fix":将扰动后的 API 输出结果覆盖原始 API 输出结果,尝试将 Loss 曲线恢复正常,该模式下不支持预热功能与反向过程,且仅支持"improve_precision"、"to_cpu"( PyTorch 场景)两种扰动因子。<br/> <b>配置示例</b>:"handler_type": "check"。</td><td>否</td></tr>
147
147
  <tr><td>fuzz_level</td><td>无标杆数据 dump 级别,即选择比对结果文件应输出的表头属性,当前仅支持取值为:"L1"。输出结果详见 <a href="#161-无标杆比对数据存盘格式">1.6.1 无标杆比对数据存盘格式</a>。</td><td>否</td></tr>
148
- <tr><td>fuzz_stage</td><td>比对过程,选择对 API 前向或反向进行无标杆比对,可选参数:<br/> "forward":前向,默认值;<br/> "backward":反向。当 fuzz_stage 为 "backward" 时,handler_type 只能为 "check"。<br/> <b>配置示例</b>:"fuzz_stage": "backward"。</td><td>否</td></tr>
148
+ <tr><td>fuzz_stage</td><td>比对过程,选择对 API 前向或反向进行无标杆比对,可选参数:<br/> "forward":前向,默认值;<br/> "backward":反向。当 fuzz_stage 为 "backward" 时,handler_type 只能为 "check"。pytorch场景下,当 fuzz_stage 为 "backward" 时, list 参数不能为空。<br/> <b>配置示例</b>:"fuzz_stage": "backward"。</td><td>否</td></tr>
149
149
  <tr><td>if_preheat</td><td>预热功能(仅 PyTorch 场景支持),bool 类型。开启功能后工具可以根据每次迭代的输出调整精度算法的阈值,从而更准确地找出存在精度问题的 API。当"handler_type": "fix"时,不支持预热。可选参数:<br/> true(开启)或 false(关闭),默认关闭。<br/> <b>配置示例</b>:"if_preheat": "true"。</td><td>否</td></tr>
150
150
  <tr><td>preheat_step</td><td>开启预热的迭代数量(仅 PyTorch 场景支持),int 类型,默认值为 15。须配置 "if_preheat": "true"。</td><td>否</td></tr>
151
151
  <tr><td>max_sample</td><td>每个算子预热的采样次数的最大阈值(仅 PyTorch 场景支持),int 类型,默认值为 20。须配置 "if_preheat": "true"。</td><td>否</td></tr>
@@ -8,6 +8,8 @@
8
8
 
9
9
  依赖:CANN 包中的 msaccucmp 工具,需要安装 Ascend-CANN-toolkit,详见《[CANN 软件安装指南](https://www.hiascend.com/document/detail/zh/canncommercial/700/envdeployment/instg/instg_0001.html)》。
10
10
 
11
+ **安全注意事项**: 工具使用时传入的msaccucmp.py文件路径参数请务必确保为CANN中自带的msaccucmp.py原始路径且文件内容未被篡改,工具会在解析数据时直接执行该文件,用户需自行保证该文件内容安全性。
12
+
11
13
  ## 2 数据解析操作指导
12
14
 
13
15
  ### 2.1 进入 parse 交互式界面
@@ -87,7 +87,7 @@ D-->config.json配置
87
87
  <tr><td>scope</td><td>否</td><td>自定义</td><td>需要通过指定算子名来限制算子插桩范围 如:["Torch.matmul.0.forward", "Tensor.pow.4.forward"]。</td></tr>
88
88
  <tr><td>list</td><td>否</td><td>自定义</td><td>需要通过指定算子类型来限制算子插桩范围 如:["relu"] 会匹配所有算子名中包含relu的算子。</td></tr>
89
89
  <tr><td rowspan="2">fuzz_stage</td><td rowspan="2">否</td><td>"forward"(默认)</td><td>需要进行算子<b>前向</b>计算的精度问题排查或<b>验证可疑算子。</b></td></tr>
90
- <tr><td>"backward"</td><td>需要进行算子<b>反向</b>计算的精度问题排查,不支持仅反向验证,前向验证包括反向。</td><td></td></tr>
90
+ <tr><td>"backward"</td><td>需要进行算子<b>反向</b>计算的精度问题排查,不支持仅反向验证,前向验证包括反向。必须设置list参数指定需要检测的算子(详见3.2 config.json配置章节), 指定的算子会暂存前向激活值,增加内存的占用。</td><td></td></tr>
91
91
  </table>
92
92
 
93
93
  #### 3.2.2 选择扰动因子
@@ -2,7 +2,7 @@
2
2
 
3
3
  分级可视化工具将msprobe工具dump的精度数据进行解析,还原模型图结构,实现模型各个层级的精度数据比对,方便用户理解模型结构、分析精度问题。
4
4
 
5
- 工具支持PyTorch版本:2.1/2.2
5
+ 工具支持PyTorch版本:2.1/2.2/2.5/2.6/2.7
6
6
 
7
7
  ## 工具特性
8
8
 
@@ -6,9 +6,9 @@
6
6
 
7
7
  | 采集模式 | 无工具 (耗时) | 加工具但未使能 Dump (耗时) | 加工具并使能 Dump (耗时) | 加工具并使能 Md5 Dump (耗时) |
8
8
  |:--------:|:--------:|:-------------------:|:--------------------:|:--------------------:|
9
- | L0 | ≈95.1 ms | ≈95.5 ms (无膨胀) | ≈420.0 ms (膨胀4.5倍) | ≈1011.3 ms (膨胀10倍) |
10
- | L1 | ≈95.1 ms | ≈115.8 ms (膨胀1.2倍) | ≈2469.0 ms (膨胀26倍) | ≈8636.0 ms (膨胀90倍) |
11
- | mix | ≈95.1 ms | ≈117.8 ms (膨胀1.2倍) | ≈3635.4 ms (膨胀38 倍) | ≈10698.3 ms (膨胀112倍) |
9
+ | L0 | ≈95.1 ms | ≈95.5 ms (无膨胀) | ≈420.0 ms (膨胀4.5倍) | ≈1011.3 s (膨胀10倍) |
10
+ | L1 | ≈95.1 ms | ≈115.8 ms (膨胀1.2倍) | ≈2469.0 ms (膨胀26倍) | ≈8636.0 s (膨胀90倍) |
11
+ | mix | ≈95.1 ms | ≈117.8 ms (膨胀1.2倍) | ≈3635.4 ms (膨胀38 倍) | ≈10698.3 s (膨胀112倍) |
12
12
 
13
13
 
14
14
  ## "tensor"模式采集数据量参考基线
@@ -19,11 +19,11 @@ Megatron、MindSpeed的ckpt加载依赖megatron,请确保megatron在python环
19
19
  msprobe --framework pytorch config_check --compare path1 path2 -o output_path.json
20
20
  ```
21
21
 
22
- | 参数名 | 解释 | 是否必选 |
23
- |--------|-------|--------|
24
- | -f 或 --framework | 深度学习框架,str类型。支持参数:pytorch,mindspore,注意:msadaptor场景传入mindspore。 | 是 |
25
- | -c 或 --compare | 2个ckpt的路径 | 是 |
26
- | -o 或 --output | 比对结果输出路径,默认为 ./ckpt_similarity.json。输出路径存在时将报错终止。 | 否 |
22
+ | 参数名 | 解释 | 是否必选 |
23
+ |--------|----------------------------------------------------------------------------------------------|--------|
24
+ | -f 或 --framework | 深度学习框架,str类型。支持参数:pytorch,mindspore,注意:msadaptor场景传入mindspore。 | 是 |
25
+ | -c 或 --compare | 2个ckpt的路径。在ckpt传给工具加载前,用户需要确保ckpt是安全可信的,若ckpt来源官方有提供SHA256等校验值,用户必须要进行校验,以确保ckpt没有被篡改。 | 是 |
26
+ | -o 或 --output | 比对结果输出路径,默认为 ./ckpt_similarity.json。输出路径存在时将报错终止。 | 否 |
27
27
 
28
28
  Megatron-LM 和 MindSpeed 的 ckpt 目录结构如下:
29
29
 
@@ -45,7 +45,7 @@ API_INFO = 2
45
45
  FOUR_SEGMENT = 4
46
46
  FIVE_SEGMENT = 5
47
47
  DATA_NAME = "data_name"
48
- API_MAX_LENGTH = 30
48
+ API_MAX_LENGTH = 300
49
49
  PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
50
50
  DATAMODE_LIST = ["random_data", "real_data"]
51
51
  ITER_MAX_TIMES = 1000
@@ -26,8 +26,7 @@ def read_npy_data(dir_path, file_name):
26
26
  return None
27
27
 
28
28
  data_path = os.path.join(dir_path, file_name)
29
- path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
30
- FileCheckConst.NUMPY_SUFFIX, False)
29
+ path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.NUMPY_SUFFIX)
31
30
  data_path = path_checker.common_check()
32
31
  data_value = load_npy(data_path)
33
32
  return data_value
@@ -163,15 +163,11 @@ class GradContext:
163
163
  def __init__(self) -> None:
164
164
  self.pre = {}
165
165
  self.post = {}
166
- self.acc_metric = {}
167
- self.acc = {}
168
166
  self.actv = {}
169
167
 
170
168
  def reset(self):
171
169
  self.pre.clear()
172
170
  self.post.clear()
173
- self.acc_metric.clear()
174
- self.acc.clear()
175
171
  self.actv.clear()
176
172
 
177
173
 
@@ -312,7 +308,6 @@ class TrainerMon:
312
308
  self.recording_l2_features = self.config.get('recording_l2_features', False)
313
309
  self.sa_order = self.config.get('sa_order', "s,b,h,d")
314
310
 
315
-
316
311
  if not self.cc_distribution.get('enable', False):
317
312
  self.cc_log_only = False
318
313
  else:
@@ -403,7 +398,7 @@ class TrainerMon:
403
398
  if self.monitoring:
404
399
  module_rank_valid = self.is_target_rank()
405
400
  step_condition = (context.step >= self.start_step and (
406
- context.step - self.start_step) % self.step_interval == 0)
401
+ context.step - self.start_step) % self.step_interval == 0)
407
402
  if module_rank_valid and step_condition:
408
403
  self.has_collect_times += 1
409
404
 
@@ -447,6 +442,7 @@ class TrainerMon:
447
442
  hook(optimizer, args, kwargs)
448
443
  step_final_hook(optimizer, args, kwargs)
449
444
  return out
445
+
450
446
  return wrapper
451
447
 
452
448
  if self.is_mindtorch:
@@ -541,7 +537,7 @@ class TrainerMon:
541
537
  if self.mv_distribution or self.ur_distribution or self.mg_direction:
542
538
  if self.is_mindtorch:
543
539
  context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, \
544
- context.param_adam_ratio = self.optimizer_mon.fetch_mv(self, self.param2name)
540
+ context.param_adam_ratio = self.optimizer_mon.fetch_mv(self, self.param2name)
545
541
  else:
546
542
  context.param_exp_avg, context.param_exp_avg_sq = self.get_mv_for_ms(optimizer)
547
543
 
@@ -564,7 +560,6 @@ class TrainerMon:
564
560
  context = self.optimizer_context[optimizer]
565
561
  self.generate_param_metrics(context, MonitorConst.POST_PARAM)
566
562
 
567
-
568
563
  if self.optimizer_hooked or not self.is_target_rank():
569
564
  return
570
565
 
@@ -577,7 +572,6 @@ class TrainerMon:
577
572
  if not self.wg_distribution:
578
573
  return
579
574
 
580
- get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
581
575
  get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
582
576
 
583
577
  def generate_param_map(self, tag, param_tensor):
@@ -671,7 +665,7 @@ class TrainerMon:
671
665
  if not self.wg_distribution:
672
666
  return
673
667
 
674
- self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced',
668
+ self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced',
675
669
  use_micro_step=self.monitor_mbs_grad)
676
670
  self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
677
671
 
@@ -810,7 +804,7 @@ class TrainerMon:
810
804
  f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
811
805
  MonitorConst.ACTV, module_input))
812
806
  module_output = [tensor for tensor in module_output if isinstance(tensor, Tensor)] \
813
- if isinstance(module_output, tuple) else module_output
807
+ if isinstance(module_output, tuple) else module_output
814
808
  tbtag_tensor_map.update(
815
809
  self.build_tbtag_tensor_map(
816
810
  f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
@@ -868,11 +862,13 @@ class TrainerMon:
868
862
  if version.parse(mindspore.__version__) >= version.parse('2.6.0'):
869
863
  def wrapper(module, args, kwargs, module_output):
870
864
  return fwd_hook_fun(module, args, kwargs, module_output, name)
865
+
871
866
  return module.register_forward_hook(wrapper, with_kwargs=True)
872
867
 
873
868
  else:
874
869
  def wrapper(module, args, module_output):
875
870
  return fwd_hook_fun(module, args, None, module_output, name)
871
+
876
872
  return module.register_forward_hook(wrapper)
877
873
 
878
874
  def extract_attention_feature_hook(module, args, kwargs, module_output, name):
@@ -880,7 +876,7 @@ class TrainerMon:
880
876
  if kwargs:
881
877
  kwargs_tensors = [tensor for tensor in kwargs.values() if isinstance(tensor, Tensor)]
882
878
  module_input.extend(kwargs_tensors)
883
-
879
+
884
880
  if module not in self.feature_hook_context_by_module:
885
881
  self.feature_hook_context_by_module[module] = FeatureHookContext(name)
886
882
  context: FeatureHookContext = self.feature_hook_context_by_module[module]
@@ -890,7 +886,7 @@ class TrainerMon:
890
886
  logger.warning(
891
887
  "Calculate attention feature failed, the length of module_input in attention hook's module should "
892
888
  "be greater than or equal to 2.")
893
-
889
+
894
890
  q_h = module_input[0]
895
891
  k_h = module_input[1]
896
892
  qkt = cal_qkt(q_h, k_h, order=self.sa_order)
@@ -985,25 +981,26 @@ class TrainerMon:
985
981
  self._hook_weights()
986
982
 
987
983
  def _hook_weights(self):
988
- context = self.grad_context
989
984
 
990
985
  @_no_grad()
991
- def param_hook(grad, context_dict, param, name):
986
+ def param_hook(grad, param, name):
992
987
  key = name
993
988
  if self.monitor_mbs_grad:
994
989
  key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
995
990
  key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
996
991
  self.register_param_call_id("param_hook", key)
997
992
  param.micro_step += 1
998
-
993
+ grad_dict = {}
999
994
  if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
1000
- context_dict[key] = grad
995
+ grad_dict[key] = grad
996
+ get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
997
+
1001
998
  if param.micro_step == self.micro_batch_number:
1002
999
  param.micro_step = 0
1003
1000
 
1004
- def param_hook_wrapper(param_hook, context_dict, param, name):
1001
+ def param_hook_wrapper(param_hook, param, name):
1005
1002
  def wrapper(grad):
1006
- return param_hook(grad, context_dict, param, name)
1003
+ return param_hook(grad, param, name)
1007
1004
 
1008
1005
  return wrapper
1009
1006
 
@@ -1011,7 +1008,7 @@ class TrainerMon:
1011
1008
  for param, name in self.param2name.items():
1012
1009
  setattr(param, 'micro_step', 0)
1013
1010
  handle = param.register_hook(
1014
- param_hook_wrapper(param_hook, context_dict=context.acc, param=param, name=name))
1011
+ param_hook_wrapper(param_hook, param=param, name=name))
1015
1012
  self.handles['wgrads'].append(handle)
1016
1013
  self.weight_hooked = True
1017
1014
 
msprobe/msprobe.py CHANGED
@@ -14,16 +14,16 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import argparse
17
- import sys
18
17
  import importlib.util
18
+ import sys
19
19
 
20
20
  from msprobe.core.common.const import Const
21
+ from msprobe.core.common.file_utils import root_privilege_warning
21
22
  from msprobe.core.common.log import logger
22
- from msprobe.core.compare.utils import _compare_parser
23
23
  from msprobe.core.compare.compare_cli import compare_cli
24
24
  from msprobe.core.compare.merge_result.merge_result_cli import _merge_result_parser, merge_result_cli
25
- from msprobe.core.config_check.config_check_cli import _config_checking_parser, \
26
- _run_config_checking_command
25
+ from msprobe.core.compare.utils import _compare_parser
26
+ from msprobe.core.config_check.config_check_cli import _config_checking_parser, _run_config_checking_command
27
27
 
28
28
 
29
29
  def is_module_available(module_name):
@@ -64,6 +64,8 @@ def main():
64
64
  if len(sys.argv) < 4:
65
65
  parser.print_help()
66
66
  sys.exit(0)
67
+
68
+ root_privilege_warning()
67
69
  framework_args = parser.parse_args(sys.argv[1:3])
68
70
  if framework_args.framework == Const.PT_FRAMEWORK:
69
71
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
@@ -47,7 +47,7 @@ API_INFO = 2
47
47
  FOUR_SEGMENT = 4
48
48
  FIVE_SEGMENT = 5
49
49
  DATA_NAME = "data_name"
50
- API_MAX_LENGTH = 30
50
+ API_MAX_LENGTH = 300
51
51
  PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
52
52
  DATAMODE_LIST = ["random_data", "real_data"]
53
53
  ITER_MAX_TIMES = 1000
@@ -39,7 +39,12 @@ from msprobe.core.common.const import FileCheckConst, Const
39
39
  from msprobe.core.common.utils import CompareException
40
40
 
41
41
 
42
- def split_json_file(input_file, num_splits, filter_api):
42
+ def split_json_file(input_file, num_splits, filter_api, device_id):
43
+ max_processes = len(device_id) * 8
44
+ if num_splits > max_processes:
45
+ logger.warning(f"A device supports a maximum of 8 processes. "
46
+ f"The total number of processes exceeds the limit, and it is set to {max_processes}.")
47
+ num_splits = max_processes
43
48
  forward_data, backward_data, real_data_path = parse_json_info_forward_backward(input_file)
44
49
  input_dir = os.path.dirname(os.path.abspath(input_file))
45
50
  if filter_api:
@@ -88,7 +93,7 @@ def split_json_file(input_file, num_splits, filter_api):
88
93
  logger.error(f"File not found or could not be deleted: {file}")
89
94
  msg = 'ERROR: Split json file failed, please check the input file and try again.'
90
95
  raise CompareException(CompareException.PARSE_FILE_ERROR, msg) from e
91
- return split_files, total_items
96
+ return split_files, total_items, num_splits
92
97
 
93
98
 
94
99
  def signal_handler(signum, frame):
@@ -127,7 +132,8 @@ def run_parallel_ut(config):
127
132
  def read_process_output(process):
128
133
  try:
129
134
  while True:
130
- if process.poll() is not None:
135
+ # 子进程标准输出流与进程本身状态是分开的,因此增加判断。子进程返回值非None表示子进程结束,标准输出为None表示结束。
136
+ if process.poll() is not None or process.stdout is None:
131
137
  break
132
138
  output = process.stdout.readline()
133
139
  if output == '':
@@ -175,12 +181,17 @@ def run_parallel_ut(config):
175
181
 
176
182
  try:
177
183
  for process in processes:
178
- process.communicate(timeout=None)
184
+ process.wait() # wait仅阻塞,不捕获标准输出和标准错误,原communicate不仅阻塞,而且捕获标准输出和标准错误
179
185
  except KeyboardInterrupt:
180
186
  logger.warning("Interrupted by user, terminating processes and cleaning up...")
181
187
  except Exception as e:
182
188
  logger.error(f"An unexpected error occurred: {e}")
183
189
  finally:
190
+ # 最后再更新一次进度条,避免因缓存写入等原因子进程结束而进度未刷新的问题
191
+ if wait_for_file_write_complete(config.result_csv_path):
192
+ result_file = read_csv(config.result_csv_path)
193
+ completed_items = len(result_file)
194
+ progress_bar.update(completed_items - progress_bar.n)
184
195
  if progress_bar.n < config.total_items:
185
196
  logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to " \
186
197
  "the result CSV file will be utilized to resume the UT task.")
@@ -195,6 +206,22 @@ def run_parallel_ut(config):
195
206
  logger.error(f"An unexpected error occurred: {e}")
196
207
 
197
208
 
209
+ def wait_for_file_write_complete(file_path, timeout=3600):
210
+ last_size = 0
211
+ start_time = time.time() # 记录开始时间
212
+ while True:
213
+ current_size = os.path.getsize(file_path)
214
+ # 检查是否文件大小未变化
215
+ if current_size == last_size:
216
+ return True # 文件写入完成,返回 True
217
+ last_size = current_size
218
+ # 检查是否超时
219
+ if time.time() - start_time > timeout:
220
+ logger.error("write the result csv file timeout.")
221
+ return False # 超时,返回 False
222
+ time.sleep(0.1) # 适当的延时
223
+
224
+
198
225
  def prepare_config(args):
199
226
  api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
200
227
  ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
@@ -203,7 +230,9 @@ def prepare_config(args):
203
230
  create_directory(out_path)
204
231
  out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
205
232
  out_path = out_path_checker.common_check()
206
- split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
233
+ split_files, total_items, modified_num_splits = split_json_file(api_info, args.num_splits,
234
+ args.filter_api, args.device_id)
235
+ args.num_splits = modified_num_splits
207
236
  config_path = args.config_path if args.config_path else None
208
237
  if config_path:
209
238
  config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
@@ -39,7 +39,6 @@ except ImportError:
39
39
  else:
40
40
  is_gpu = False
41
41
 
42
-
43
42
  torch_without_guard_version = torch.__version__ >= '2.1'
44
43
  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
45
44
 
@@ -338,56 +337,6 @@ def save_pt(tensor, filepath):
338
337
  change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
339
338
 
340
339
 
341
- class TypeCheckingUnpickler(pickle.Unpickler):
342
- """
343
- This class is a subclass of pickle.Unpickler, which is used to unpickle pickled objects.
344
- It overrides the find_class method to add type checking functionality.
345
- """
346
- allowed_types = [
347
- "str",
348
- "ApiData",
349
- "OrderedDict",
350
- "_rebuild_tensor_v2", # from torch.utils
351
- "_load_from_bytes" # from torch.storage
352
- ]
353
-
354
- def find_class(self, module, name):
355
- """
356
- Method to find the class of the object to be unpickled.
357
- Throws pickle.UnpicklingError If the object type is not in the allowed types list.
358
- """
359
- if name in self.allowed_types:
360
- return super().find_class(module, name)
361
- raise pickle.UnpicklingError("Unsupported object type: {}.{}".format(module, name))
362
-
363
-
364
- def save_pkl(tensor, filepath):
365
- """Save ApiData or str objection by pickle"""
366
- check_path_before_create(filepath)
367
- filepath = os.path.realpath(filepath)
368
- try:
369
- with FileOpen(filepath, 'wb') as f:
370
- pickle.dump(tensor, f)
371
- except Exception as e:
372
- logger.error("Save pt file failed, please check according possible error causes: "
373
- "1. out of disk space or disk error, "
374
- "2. no permission to write files, etc.")
375
- raise RuntimeError(f"save pt file {filepath} failed") from e
376
- change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
377
-
378
-
379
- def load_pkl(pt_path):
380
- """Load ApiData or str objection by pickle for accuracy_checker_online"""
381
- check_file_or_directory_path(pt_path)
382
- pt_path = os.path.realpath(pt_path)
383
- try:
384
- with FileOpen(pt_path, 'rb') as f:
385
- pt = TypeCheckingUnpickler(f).load()
386
- except Exception as e:
387
- raise RuntimeError(f"load pt file {pt_path} failed: {e}") from e
388
- return pt
389
-
390
-
391
340
  def is_recomputation():
392
341
  """Check if the current operation is in the re-computation phase.
393
342
 
@@ -416,7 +365,8 @@ def is_recomputation():
416
365
 
417
366
  # Identify indices in the call stack where the specific function is being executed
418
367
  for idx, frame_info in enumerate(call_stack):
419
- if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward':
368
+ if (frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward' and
369
+ "megatron" in frame_info.filename):
420
370
  backward_function_indices.append(idx)
421
371
 
422
372
  # Check if the execution is within 'torch/autograd/function.py' file
@@ -27,8 +27,7 @@ def read_pt_data(dir_path, file_name):
27
27
  return None
28
28
 
29
29
  data_path = os.path.join(dir_path, file_name)
30
- path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
31
- FileCheckConst.PT_SUFFIX, False)
30
+ path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.PT_SUFFIX)
32
31
  data_path = path_checker.common_check()
33
32
  try:
34
33
  # detach because numpy can not process gradient information
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  from functools import wraps
17
+ from typing import Any, Callable
17
18
 
18
19
  import torch
19
20
  from torch.utils.hooks import BackwardHook
@@ -21,6 +22,9 @@ from torch.utils.hooks import BackwardHook
21
22
  from msprobe.core.common.const import Const
22
23
  from msprobe.core.common.decorator import recursion_depth_decorator
23
24
  from msprobe.pytorch.common.log import logger
25
+ from msprobe.pytorch.hook_module.api_register import get_api_register
26
+
27
+ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
24
28
 
25
29
 
26
30
  def wrap_setup_backward_hook(func):
@@ -92,3 +96,23 @@ def wrap_setup_backward_hook(func):
92
96
  def wrap_setup_input_output_hook():
93
97
  BackwardHook.setup_input_hook = wrap_setup_backward_hook(BackwardHook.setup_input_hook)
94
98
  BackwardHook.setup_output_hook = wrap_setup_backward_hook(BackwardHook.setup_output_hook)
99
+
100
+
101
+ def get_apply_func_wrapper(original_func: Callable) -> Callable:
102
+ @wraps(original_func)
103
+ def wrapped_apply(*args, **kwargs) -> Any:
104
+ api_register = get_api_register()
105
+ if api_register:
106
+ api_register.restore_inner_used_api()
107
+ result = original_func(*args, **kwargs)
108
+ if api_register:
109
+ api_register.register_inner_used_api()
110
+ return result
111
+
112
+ return wrapped_apply
113
+
114
+
115
+ def wrap_backward_hook_function_apply():
116
+ if torch_version_above_or_equal_2:
117
+ original_apply = torch.nn.modules._functions.BackwardHookFunction.apply
118
+ torch.nn.modules._functions.BackwardHookFunction.apply = get_apply_func_wrapper(original_apply)