mindstudio-probe 8.2.1__py3-none-any.whl → 8.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/RECORD +39 -40
- msprobe/README.md +7 -2
- msprobe/core/common/const.py +17 -3
- msprobe/core/common/file_utils.py +138 -32
- msprobe/core/common/framework_adapter.py +16 -6
- msprobe/core/common/utils.py +17 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +4 -16
- msprobe/core/compare/find_first/utils.py +1 -1
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
- msprobe/core/hook_manager.py +0 -1
- msprobe/docs/01.installation.md +2 -0
- msprobe/docs/02.config_introduction.md +1 -1
- msprobe/docs/14.data_parse_PyTorch.md +2 -0
- msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
- msprobe/docs/21.visualization_PyTorch.md +1 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +3 -3
- msprobe/docs/32.ckpt_compare.md +5 -5
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/mindspore/compare/utils.py +1 -2
- msprobe/mindspore/monitor/module_hook.py +17 -20
- msprobe/msprobe.py +6 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
- msprobe/pytorch/common/utils.py +2 -52
- msprobe/pytorch/compare/utils.py +1 -2
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +24 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +27 -6
- msprobe/pytorch/hook_module/api_register.py +11 -2
- msprobe/pytorch/monitor/module_hook.py +16 -34
- msprobe/pytorch/pt_config.py +6 -0
- msprobe/visualization/builder/graph_builder.py +3 -2
- msprobe/visualization/builder/graph_merger.py +13 -0
- msprobe/visualization/graph/graph.py +13 -9
- msprobe/visualization/utils.py +11 -1
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +0 -3
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.1.dist-info}/top_level.txt +0 -0
msprobe/core/common/utils.py
CHANGED
|
@@ -708,3 +708,20 @@ def check_process_num(process_num):
|
|
|
708
708
|
raise ValueError(f"process_num({process_num}) is not a positive integer")
|
|
709
709
|
if process_num > Const.MAX_PROCESS_NUM:
|
|
710
710
|
raise ValueError(f"The maximum supported process_num is {Const.MAX_PROCESS_NUM}, current value: {process_num}.")
|
|
711
|
+
|
|
712
|
+
|
|
713
|
+
def confirm(prompt, default=False):
    """Ask a yes/no question on stdin and return the user's answer.

    A hint is appended to ``prompt``: " [Y/n] " when ``default is True``,
    " [y/N] " when ``default is False`` and " [y/n] " for any other value.

    Args:
        prompt: Question text shown to the user.
        default: Value returned when the user just presses Enter or types
            something that is neither a yes nor a no answer.

    Returns:
        True for "y"/"yes", False for "n"/"no" (case-insensitive, surrounding
        whitespace ignored), otherwise ``default``.
    """
    if default is True:
        prompt_suffix = " [Y/n] "
    elif default is False:
        prompt_suffix = " [y/N] "
    else:
        prompt_suffix = " [y/n] "

    user_input = input(prompt + prompt_suffix).strip().lower()
    if user_input in ('y', 'yes'):
        return True
    # An explicit "no" must not fall through to ``default``: with default=True
    # that would silently turn the user's refusal into consent.
    if user_input in ('n', 'no'):
        return False
    return default
|
|
@@ -26,8 +26,6 @@ from msprobe.core.compare.utils import gen_api_batches
|
|
|
26
26
|
|
|
27
27
|
cur_dir = os.path.dirname(os.path.realpath(__file__))
|
|
28
28
|
diff_threshold_yaml_path = os.path.join(cur_dir, 'diff_analyze_threshold.yaml')
|
|
29
|
-
ignore_op_list_yaml_path = os.path.join(cur_dir, 'ignore_op_list.yaml')
|
|
30
|
-
ignore_list = load_yaml(ignore_op_list_yaml_path)
|
|
31
29
|
thresholds = load_yaml(diff_threshold_yaml_path)
|
|
32
30
|
cmp_metrics = thresholds.get('compare_metrics')
|
|
33
31
|
|
|
@@ -53,7 +51,7 @@ class FirstDiffAnalyze:
|
|
|
53
51
|
return True
|
|
54
52
|
return False
|
|
55
53
|
|
|
56
|
-
def single_api_check(self, result_slice, header
|
|
54
|
+
def single_api_check(self, result_slice, header):
|
|
57
55
|
"""
|
|
58
56
|
单个api差异检查
|
|
59
57
|
|
|
@@ -67,18 +65,14 @@ class FirstDiffAnalyze:
|
|
|
67
65
|
}
|
|
68
66
|
|
|
69
67
|
column_indices = {name: idx for idx, name in enumerate(header)}
|
|
70
|
-
|
|
68
|
+
|
|
71
69
|
for line in result_slice:
|
|
72
70
|
op_item = {
|
|
73
71
|
column_name: line[column_indices[column_name]]
|
|
74
72
|
for column_name in header
|
|
75
73
|
}
|
|
76
74
|
single_check_result['op_items'].append(op_item)
|
|
77
|
-
|
|
78
|
-
continue
|
|
79
|
-
output_idx += 1
|
|
80
|
-
if output_idx in ignore_list.get(api_name, []):
|
|
81
|
-
continue
|
|
75
|
+
|
|
82
76
|
# set is_same
|
|
83
77
|
if self.mode_config.dump_mode == Const.MD5:
|
|
84
78
|
if line[column_indices[CompareConst.RESULT]] == CompareConst.DIFF:
|
|
@@ -123,13 +117,7 @@ class FirstDiffAnalyze:
|
|
|
123
117
|
with tqdm(total=len(api_batches), desc=bar_desc_add_rank, unit="api/module", ncols=100) as progress_bar:
|
|
124
118
|
for api_batch in api_batches:
|
|
125
119
|
result_slice = result[api_batch.start: api_batch.params_grad_end_index]
|
|
126
|
-
|
|
127
|
-
# suppose name is Tensor.MatMul.0.forward
|
|
128
|
-
if len(api_compo) < 4:
|
|
129
|
-
continue
|
|
130
|
-
# get MatMul as api_name
|
|
131
|
-
api_name = api_compo[-3]
|
|
132
|
-
check_result[api_batch.api_name] = self.single_api_check(result_slice, header, api_name)
|
|
120
|
+
check_result[api_batch.api_name] = self.single_api_check(result_slice, header)
|
|
133
121
|
progress_bar.update(1)
|
|
134
122
|
|
|
135
123
|
return check_result
|
|
@@ -182,7 +182,7 @@ def analyze_diff_in_group(nodes_group):
|
|
|
182
182
|
input_diff_nodes = list(filter(lambda node: node.is_diff, src_list))
|
|
183
183
|
# 如果有异常回溯计算节点找到异常来源
|
|
184
184
|
# 使用cpu模拟节点进行计算,查看结果是否有问题。需要对所有计算节点录入/映射,暂不实现。
|
|
185
|
-
get_compute_ops_from_comm_nodes(
|
|
185
|
+
get_compute_ops_from_comm_nodes(input_diff_nodes)
|
|
186
186
|
# 筛选入参没问题但出参有问题的通信节点
|
|
187
187
|
output_diff_nodes = list(filter(lambda node: node.data.is_diff, nodes_group))
|
|
188
188
|
get_comm_ops(output_diff_nodes)
|
|
@@ -19,11 +19,11 @@ from tqdm import tqdm
|
|
|
19
19
|
from msprobe.core.common.file_utils import save_json, check_path_before_create, check_path_not_exists, \
|
|
20
20
|
check_file_or_directory_path
|
|
21
21
|
from msprobe.core.common.log import logger
|
|
22
|
+
from msprobe.core.common.utils import confirm
|
|
22
23
|
from msprobe.core.config_check.ckpt_compare.megatron_loader import load_megatron_weights
|
|
23
24
|
from msprobe.core.config_check.ckpt_compare.metrics import METRIC_FUNC
|
|
24
25
|
|
|
25
26
|
|
|
26
|
-
|
|
27
27
|
def compare_checkpoints(ckpt_path1, ckpt_path2, output_path) -> Dict:
|
|
28
28
|
"""Compare weights between two checkpoints using cosine similarity and L2 distance.
|
|
29
29
|
|
|
@@ -45,6 +45,11 @@ def compare_checkpoints(ckpt_path1, ckpt_path2, output_path) -> Dict:
|
|
|
45
45
|
"""
|
|
46
46
|
|
|
47
47
|
# Load both checkpoints
|
|
48
|
+
if not confirm("You are using torch.load with weights_only is False, it may cause arbitrary code "
|
|
49
|
+
"execution. Do it only if you get the file from a trusted source. Input yes to continue, "
|
|
50
|
+
"otherwise exit", False):
|
|
51
|
+
logger.error("Insecure risks found and exit!")
|
|
52
|
+
raise Exception("Insecure risks found and exit!")
|
|
48
53
|
check_file_or_directory_path(ckpt_path1, isdir=True)
|
|
49
54
|
check_file_or_directory_path(ckpt_path2, isdir=True)
|
|
50
55
|
check_path_before_create(output_path)
|
msprobe/core/hook_manager.py
CHANGED
|
@@ -63,7 +63,6 @@ class BaseHookManager(ABC):
|
|
|
63
63
|
def reset_status():
|
|
64
64
|
BaseHookManager.inner_switch = defaultdict(bool)
|
|
65
65
|
BaseHookManager.inner_api_count = defaultdict(int)
|
|
66
|
-
BaseHookManager.hook_handle_dict.clear()
|
|
67
66
|
BaseHookManager.params_grad_info.clear()
|
|
68
67
|
|
|
69
68
|
@staticmethod
|
msprobe/docs/01.installation.md
CHANGED
|
@@ -16,6 +16,8 @@ pip install mindstudio-probe
|
|
|
16
16
|
|
|
17
17
|
| 版本 | 发布日期 |支持 PyTorch 版本|支持 MindSpore 版本| 下载链接 |校验码|
|
|
18
18
|
|:-----:|:----------:|:--:|:--:|:----------------------------------------------------------------------------------------------------------------------------------:|:--:|
|
|
19
|
+
| 8.3.0 | 2025.10.30 |1.11/2.0/2.1/2.2/2.5/2.6/2.7|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.3.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.3/mindstudio_probe-8.3.0-py3-none-any.whl) |e933657b8ceb20774f924865d1f47978bd49cd1d9cec5fb60dec4f18802afbcc|
|
|
20
|
+
| 8.2.1 | 2025.9.29 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.2.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.2/mindstudio_probe-8.2.1-py3-none-any.whl) |2152fadefec3d70148a910a54f2cfb6fccfa60b5ccd59f835f20289a7a5f4764|
|
|
19
21
|
| 8.2.0 | 2025.9.03 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.2.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.2/mindstudio_probe-8.2.0-py3-none-any.whl) |bbc1577d76754adf987069308177d3e0a04e36de9c7f22e75c34cf4ad0ce1af2|
|
|
20
22
|
| 8.1.2 | 2025.8.01 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.1.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.1/mindstudio_probe-8.1.2-py3-none-any.whl) |ff07bb81fddd3b8f3096d119ca1481bde8fdb24f10644def5250caad727448ab|
|
|
21
23
|
| 8.1.1 | 2025.6.20 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.1.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.1/mindstudio_probe-8.1.1-py3-none-any.whl) |2aad10a243575544d7feef552caf4d06aa93028488ebd0bbc9aa350379da859d|
|
|
@@ -145,7 +145,7 @@ PyTorch、MSAdapter 以及 MindSpore 动态图场景下,"level"须为"L0"或"L
|
|
|
145
145
|
<tr><td>pert_mode</td><td>无标杆扰动因子,str 类型。可选参数:<br/> "improve_precision":对输入做升精度,默认值;<br/> "add_noise":对输入增加噪声;<br/> "no_change":不加扰动直接二次执行;<br/> "bit_noise":输入的末位比特翻转,MindSpore 场景不支持 BF16 类型的向量;<br/> "change_value":输入的张量首尾值调换;<br/> "to_cpu":在 CPU 等价执行(仅 PyTorch 场景支持)。<br/><b>配置示例</b>:"pert_mode": "improve_precision"。</td><td>否</td></tr>
|
|
146
146
|
<tr><td>handler_type</td><td>处理类型,可选参数:<br/> "check":进行无标杆比对检查,默认值;<br/> "fix":将扰动后的 API 输出结果覆盖原始 API 输出结果,尝试将 Loss 曲线恢复正常,该模式下不支持预热功能与反向过程,且仅支持"improve_precision"、"to_cpu"( PyTorch 场景)两种扰动因子。<br/> <b>配置示例</b>:"handler_type": "check"。</td><td>否</td></tr>
|
|
147
147
|
<tr><td>fuzz_level</td><td>无标杆数据 dump 级别,即选择比对结果文件应输出的表头属性,当前仅支持取值为:"L1"。输出结果详见 <a href="#161-无标杆比对数据存盘格式">1.6.1 无标杆比对数据存盘格式</a>。</td><td>否</td></tr>
|
|
148
|
-
<tr><td>fuzz_stage</td><td>比对过程,选择对 API 前向或反向进行无标杆比对,可选参数:<br/> "forward":前向,默认值;<br/> "backward":反向。当 fuzz_stage 为 "backward" 时,handler_type 只能为 "check"
|
|
148
|
+
<tr><td>fuzz_stage</td><td>比对过程,选择对 API 前向或反向进行无标杆比对,可选参数:<br/> "forward":前向,默认值;<br/> "backward":反向。当 fuzz_stage 为 "backward" 时,handler_type 只能为 "check"。PyTorch 场景下,当 fuzz_stage 为 "backward" 时,list 参数不能为空。<br/> <b>配置示例</b>:"fuzz_stage": "backward"。</td><td>否</td></tr>
|
|
149
149
|
<tr><td>if_preheat</td><td>预热功能(仅 PyTorch 场景支持),bool 类型。开启功能后工具可以根据每次迭代的输出调整精度算法的阈值,从而更准确地找出存在精度问题的 API。当"handler_type": "fix"时,不支持预热。可选参数:<br/> true(开启)或 false(关闭),默认关闭。<br/> <b>配置示例</b>:"if_preheat": "true"。</td><td>否</td></tr>
|
|
150
150
|
<tr><td>preheat_step</td><td>开启预热的迭代数量(仅 PyTorch 场景支持),int 类型,默认值为 15。须配置 "if_preheat": "true"。</td><td>否</td></tr>
|
|
151
151
|
<tr><td>max_sample</td><td>每个算子预热的采样次数的最大阈值(仅 PyTorch 场景支持),int 类型,默认值为 20。须配置 "if_preheat": "true"。</td><td>否</td></tr>
|
|
@@ -8,6 +8,8 @@
|
|
|
8
8
|
|
|
9
9
|
依赖:CANN 包中的 msaccucmp 工具,需要安装 Ascend-CANN-toolkit,详见《[CANN 软件安装指南](https://www.hiascend.com/document/detail/zh/canncommercial/700/envdeployment/instg/instg_0001.html)》。
|
|
10
10
|
|
|
11
|
+
**安全注意事项**: 工具使用时传入的msaccucmp.py文件路径参数请务必确保为CANN中自带的msaccucmp.py原始路径且文件内容未被篡改,工具会在解析数据时直接执行该文件,用户需自行保证该文件内容安全性。
|
|
12
|
+
|
|
11
13
|
## 2 数据解析操作指导
|
|
12
14
|
|
|
13
15
|
### 2.1 进入 parse 交互式界面
|
|
@@ -87,7 +87,7 @@ D-->config.json配置
|
|
|
87
87
|
<tr><td>scope</td><td>否</td><td>自定义</td><td>需要通过指定算子名来限制算子插桩范围 如:["Torch.matmul.0.forward", "Tensor.pow.4.forward"]。</td></tr>
|
|
88
88
|
<tr><td>list</td><td>否</td><td>自定义</td><td>需要通过指定算子类型来限制算子插桩范围 如:["relu"] 会匹配所有算子名中包含relu的算子。</td></tr>
|
|
89
89
|
<tr><td rowspan="2">fuzz_stage</td><td rowspan="2">否</td><td>"forward"(默认)</td><td>需要进行算子<b>前向</b>计算的精度问题排查或<b>验证可疑算子。</b></td></tr>
|
|
90
|
-
<tr><td>"backward"</td><td>需要进行算子<b>反向</b
|
|
90
|
+
<tr><td>"backward"</td><td>需要进行算子<b>反向</b>计算的精度问题排查,不支持仅反向验证,前向验证包括反向。必须设置 list 参数指定需要检测的算子(详见 3.2 config.json 配置章节),指定的算子会暂存前向激活值,从而增加内存占用。</td></tr>
|
|
91
91
|
</table>
|
|
92
92
|
|
|
93
93
|
#### 3.2.2 选择扰动因子
|
|
@@ -6,9 +6,9 @@
|
|
|
6
6
|
|
|
7
7
|
| 采集模式 | 无工具 (耗时) | 加工具但未使能 Dump (耗时) | 加工具并使能 Dump (耗时) | 加工具并使能 Md5 Dump (耗时) |
|
|
8
8
|
|:--------:|:--------:|:-------------------:|:--------------------:|:--------------------:|
|
|
9
|
-
| L0 | ≈95.1 ms | ≈95.5 ms (无膨胀) | ≈420.0 ms (膨胀4.5倍) | ≈1011.3
|
|
10
|
-
| L1 | ≈95.1 ms | ≈115.8 ms (膨胀1.2倍) | ≈2469.0 ms (膨胀26倍) | ≈8636.0
|
|
11
|
-
| mix | ≈95.1 ms | ≈117.8 ms (膨胀1.2倍) | ≈3635.4 ms (膨胀38 倍) | ≈10698.3
|
|
9
|
+
| L0 | ≈95.1 ms | ≈95.5 ms (无膨胀) | ≈420.0 ms (膨胀4.5倍) | ≈1011.3 ms (膨胀10倍) |
|
|
10
|
+
| L1 | ≈95.1 ms | ≈115.8 ms (膨胀1.2倍) | ≈2469.0 ms (膨胀26倍) | ≈8636.0 ms (膨胀90倍) |
|
|
11
|
+
| mix | ≈95.1 ms | ≈117.8 ms (膨胀1.2倍) | ≈3635.4 ms (膨胀38 倍) | ≈10698.3 ms (膨胀112倍) |
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
## "tensor"模式采集数据量参考基线
|
msprobe/docs/32.ckpt_compare.md
CHANGED
|
@@ -19,11 +19,11 @@ Megatron、MindSpeed的ckpt加载依赖megatron,请确保megatron在python环
|
|
|
19
19
|
msprobe --framework pytorch config_check --compare path1 path2 -o output_path.json
|
|
20
20
|
```
|
|
21
21
|
|
|
22
|
-
| 参数名 | 解释
|
|
23
|
-
|
|
24
|
-
| -f 或 --framework | 深度学习框架,str类型。支持参数:pytorch,mindspore,注意:msadaptor场景传入mindspore。
|
|
25
|
-
| -c 或 --compare | 2个ckpt
|
|
26
|
-
| -o 或 --output | 比对结果输出路径,默认为 ./ckpt_similarity.json。输出路径存在时将报错终止。
|
|
22
|
+
| 参数名 | 解释 | 是否必选 |
|
|
23
|
+
|--------|----------------------------------------------------------------------------------------------|--------|
|
|
24
|
+
| -f 或 --framework | 深度学习框架,str类型。支持参数:pytorch,mindspore,注意:MSAdapter 场景传入 mindspore。 | 是 |
|
|
25
|
+
| -c 或 --compare | 2个ckpt的路径。在ckpt传给工具加载前,用户需要确保ckpt是安全可信的,若ckpt来源官方有提供SHA256等校验值,用户必须要进行校验,以确保ckpt没有被篡改。 | 是 |
|
|
26
|
+
| -o 或 --output | 比对结果输出路径,默认为 ./ckpt_similarity.json。输出路径存在时将报错终止。 | 否 |
|
|
27
27
|
|
|
28
28
|
Megatron-LM 和 MindSpeed 的 ckpt 目录结构如下:
|
|
29
29
|
|
|
@@ -26,8 +26,7 @@ def read_npy_data(dir_path, file_name):
|
|
|
26
26
|
return None
|
|
27
27
|
|
|
28
28
|
data_path = os.path.join(dir_path, file_name)
|
|
29
|
-
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
30
|
-
FileCheckConst.NUMPY_SUFFIX, False)
|
|
29
|
+
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.NUMPY_SUFFIX)
|
|
31
30
|
data_path = path_checker.common_check()
|
|
32
31
|
data_value = load_npy(data_path)
|
|
33
32
|
return data_value
|
|
@@ -163,15 +163,11 @@ class GradContext:
|
|
|
163
163
|
def __init__(self) -> None:
|
|
164
164
|
self.pre = {}
|
|
165
165
|
self.post = {}
|
|
166
|
-
self.acc_metric = {}
|
|
167
|
-
self.acc = {}
|
|
168
166
|
self.actv = {}
|
|
169
167
|
|
|
170
168
|
def reset(self):
|
|
171
169
|
self.pre.clear()
|
|
172
170
|
self.post.clear()
|
|
173
|
-
self.acc_metric.clear()
|
|
174
|
-
self.acc.clear()
|
|
175
171
|
self.actv.clear()
|
|
176
172
|
|
|
177
173
|
|
|
@@ -312,7 +308,6 @@ class TrainerMon:
|
|
|
312
308
|
self.recording_l2_features = self.config.get('recording_l2_features', False)
|
|
313
309
|
self.sa_order = self.config.get('sa_order', "s,b,h,d")
|
|
314
310
|
|
|
315
|
-
|
|
316
311
|
if not self.cc_distribution.get('enable', False):
|
|
317
312
|
self.cc_log_only = False
|
|
318
313
|
else:
|
|
@@ -403,7 +398,7 @@ class TrainerMon:
|
|
|
403
398
|
if self.monitoring:
|
|
404
399
|
module_rank_valid = self.is_target_rank()
|
|
405
400
|
step_condition = (context.step >= self.start_step and (
|
|
406
|
-
|
|
401
|
+
context.step - self.start_step) % self.step_interval == 0)
|
|
407
402
|
if module_rank_valid and step_condition:
|
|
408
403
|
self.has_collect_times += 1
|
|
409
404
|
|
|
@@ -447,6 +442,7 @@ class TrainerMon:
|
|
|
447
442
|
hook(optimizer, args, kwargs)
|
|
448
443
|
step_final_hook(optimizer, args, kwargs)
|
|
449
444
|
return out
|
|
445
|
+
|
|
450
446
|
return wrapper
|
|
451
447
|
|
|
452
448
|
if self.is_mindtorch:
|
|
@@ -541,7 +537,7 @@ class TrainerMon:
|
|
|
541
537
|
if self.mv_distribution or self.ur_distribution or self.mg_direction:
|
|
542
538
|
if self.is_mindtorch:
|
|
543
539
|
context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, \
|
|
544
|
-
|
|
540
|
+
context.param_adam_ratio = self.optimizer_mon.fetch_mv(self, self.param2name)
|
|
545
541
|
else:
|
|
546
542
|
context.param_exp_avg, context.param_exp_avg_sq = self.get_mv_for_ms(optimizer)
|
|
547
543
|
|
|
@@ -564,7 +560,6 @@ class TrainerMon:
|
|
|
564
560
|
context = self.optimizer_context[optimizer]
|
|
565
561
|
self.generate_param_metrics(context, MonitorConst.POST_PARAM)
|
|
566
562
|
|
|
567
|
-
|
|
568
563
|
if self.optimizer_hooked or not self.is_target_rank():
|
|
569
564
|
return
|
|
570
565
|
|
|
@@ -577,7 +572,6 @@ class TrainerMon:
|
|
|
577
572
|
if not self.wg_distribution:
|
|
578
573
|
return
|
|
579
574
|
|
|
580
|
-
get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
|
|
581
575
|
get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
|
|
582
576
|
|
|
583
577
|
def generate_param_map(self, tag, param_tensor):
|
|
@@ -671,7 +665,7 @@ class TrainerMon:
|
|
|
671
665
|
if not self.wg_distribution:
|
|
672
666
|
return
|
|
673
667
|
|
|
674
|
-
self.summary_writer.write_metrics(self.ops, self.grad_context.
|
|
668
|
+
self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced',
|
|
675
669
|
use_micro_step=self.monitor_mbs_grad)
|
|
676
670
|
self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
|
|
677
671
|
|
|
@@ -810,7 +804,7 @@ class TrainerMon:
|
|
|
810
804
|
f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
|
|
811
805
|
MonitorConst.ACTV, module_input))
|
|
812
806
|
module_output = [tensor for tensor in module_output if isinstance(tensor, Tensor)] \
|
|
813
|
-
|
|
807
|
+
if isinstance(module_output, tuple) else module_output
|
|
814
808
|
tbtag_tensor_map.update(
|
|
815
809
|
self.build_tbtag_tensor_map(
|
|
816
810
|
f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
|
|
@@ -868,11 +862,13 @@ class TrainerMon:
|
|
|
868
862
|
if version.parse(mindspore.__version__) >= version.parse('2.6.0'):
|
|
869
863
|
def wrapper(module, args, kwargs, module_output):
|
|
870
864
|
return fwd_hook_fun(module, args, kwargs, module_output, name)
|
|
865
|
+
|
|
871
866
|
return module.register_forward_hook(wrapper, with_kwargs=True)
|
|
872
867
|
|
|
873
868
|
else:
|
|
874
869
|
def wrapper(module, args, module_output):
|
|
875
870
|
return fwd_hook_fun(module, args, None, module_output, name)
|
|
871
|
+
|
|
876
872
|
return module.register_forward_hook(wrapper)
|
|
877
873
|
|
|
878
874
|
def extract_attention_feature_hook(module, args, kwargs, module_output, name):
|
|
@@ -880,7 +876,7 @@ class TrainerMon:
|
|
|
880
876
|
if kwargs:
|
|
881
877
|
kwargs_tensors = [tensor for tensor in kwargs.values() if isinstance(tensor, Tensor)]
|
|
882
878
|
module_input.extend(kwargs_tensors)
|
|
883
|
-
|
|
879
|
+
|
|
884
880
|
if module not in self.feature_hook_context_by_module:
|
|
885
881
|
self.feature_hook_context_by_module[module] = FeatureHookContext(name)
|
|
886
882
|
context: FeatureHookContext = self.feature_hook_context_by_module[module]
|
|
@@ -890,7 +886,7 @@ class TrainerMon:
|
|
|
890
886
|
logger.warning(
|
|
891
887
|
"Calculate attention feature failed, the length of module_input in attention hook's module should "
|
|
892
888
|
"be greater than or equal to 2.")
|
|
893
|
-
|
|
889
|
+
|
|
894
890
|
q_h = module_input[0]
|
|
895
891
|
k_h = module_input[1]
|
|
896
892
|
qkt = cal_qkt(q_h, k_h, order=self.sa_order)
|
|
@@ -985,25 +981,26 @@ class TrainerMon:
|
|
|
985
981
|
self._hook_weights()
|
|
986
982
|
|
|
987
983
|
def _hook_weights(self):
|
|
988
|
-
context = self.grad_context
|
|
989
984
|
|
|
990
985
|
@_no_grad()
|
|
991
|
-
def param_hook(grad,
|
|
986
|
+
def param_hook(grad, param, name):
|
|
992
987
|
key = name
|
|
993
988
|
if self.monitor_mbs_grad:
|
|
994
989
|
key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
|
|
995
990
|
key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
|
|
996
991
|
self.register_param_call_id("param_hook", key)
|
|
997
992
|
param.micro_step += 1
|
|
998
|
-
|
|
993
|
+
grad_dict = {}
|
|
999
994
|
if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
|
|
1000
|
-
|
|
995
|
+
grad_dict[key] = grad
|
|
996
|
+
get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
|
|
997
|
+
|
|
1001
998
|
if param.micro_step == self.micro_batch_number:
|
|
1002
999
|
param.micro_step = 0
|
|
1003
1000
|
|
|
1004
|
-
def param_hook_wrapper(param_hook,
|
|
1001
|
+
def param_hook_wrapper(param_hook, param, name):
|
|
1005
1002
|
def wrapper(grad):
|
|
1006
|
-
return param_hook(grad,
|
|
1003
|
+
return param_hook(grad, param, name)
|
|
1007
1004
|
|
|
1008
1005
|
return wrapper
|
|
1009
1006
|
|
|
@@ -1011,7 +1008,7 @@ class TrainerMon:
|
|
|
1011
1008
|
for param, name in self.param2name.items():
|
|
1012
1009
|
setattr(param, 'micro_step', 0)
|
|
1013
1010
|
handle = param.register_hook(
|
|
1014
|
-
param_hook_wrapper(param_hook,
|
|
1011
|
+
param_hook_wrapper(param_hook, param=param, name=name))
|
|
1015
1012
|
self.handles['wgrads'].append(handle)
|
|
1016
1013
|
self.weight_hooked = True
|
|
1017
1014
|
|
msprobe/msprobe.py
CHANGED
|
@@ -14,16 +14,16 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import argparse
|
|
17
|
-
import sys
|
|
18
17
|
import importlib.util
|
|
18
|
+
import sys
|
|
19
19
|
|
|
20
20
|
from msprobe.core.common.const import Const
|
|
21
|
+
from msprobe.core.common.file_utils import root_privilege_warning
|
|
21
22
|
from msprobe.core.common.log import logger
|
|
22
|
-
from msprobe.core.compare.utils import _compare_parser
|
|
23
23
|
from msprobe.core.compare.compare_cli import compare_cli
|
|
24
24
|
from msprobe.core.compare.merge_result.merge_result_cli import _merge_result_parser, merge_result_cli
|
|
25
|
-
from msprobe.core.
|
|
26
|
-
|
|
25
|
+
from msprobe.core.compare.utils import _compare_parser
|
|
26
|
+
from msprobe.core.config_check.config_check_cli import _config_checking_parser, _run_config_checking_command
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
def is_module_available(module_name):
|
|
@@ -64,6 +64,8 @@ def main():
|
|
|
64
64
|
if len(sys.argv) < 4:
|
|
65
65
|
parser.print_help()
|
|
66
66
|
sys.exit(0)
|
|
67
|
+
|
|
68
|
+
root_privilege_warning()
|
|
67
69
|
framework_args = parser.parse_args(sys.argv[1:3])
|
|
68
70
|
if framework_args.framework == Const.PT_FRAMEWORK:
|
|
69
71
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
|
|
@@ -39,7 +39,12 @@ from msprobe.core.common.const import FileCheckConst, Const
|
|
|
39
39
|
from msprobe.core.common.utils import CompareException
|
|
40
40
|
|
|
41
41
|
|
|
42
|
-
def split_json_file(input_file, num_splits, filter_api):
|
|
42
|
+
def split_json_file(input_file, num_splits, filter_api, device_id):
|
|
43
|
+
max_processes = len(device_id) * 8
|
|
44
|
+
if num_splits > max_processes:
|
|
45
|
+
logger.warning(f"A device supports a maximum of 8 processes. "
|
|
46
|
+
f"The total number of processes exceeds the limit, and it is set to {max_processes}.")
|
|
47
|
+
num_splits = max_processes
|
|
43
48
|
forward_data, backward_data, real_data_path = parse_json_info_forward_backward(input_file)
|
|
44
49
|
input_dir = os.path.dirname(os.path.abspath(input_file))
|
|
45
50
|
if filter_api:
|
|
@@ -88,7 +93,7 @@ def split_json_file(input_file, num_splits, filter_api):
|
|
|
88
93
|
logger.error(f"File not found or could not be deleted: {file}")
|
|
89
94
|
msg = 'ERROR: Split json file failed, please check the input file and try again.'
|
|
90
95
|
raise CompareException(CompareException.PARSE_FILE_ERROR, msg) from e
|
|
91
|
-
return split_files, total_items
|
|
96
|
+
return split_files, total_items, num_splits
|
|
92
97
|
|
|
93
98
|
|
|
94
99
|
def signal_handler(signum, frame):
|
|
@@ -127,7 +132,8 @@ def run_parallel_ut(config):
|
|
|
127
132
|
def read_process_output(process):
|
|
128
133
|
try:
|
|
129
134
|
while True:
|
|
130
|
-
|
|
135
|
+
# 子进程标准输出流与进程本身状态是分开的,因此增加判断。子进程返回值非None表示子进程结束,标准输出为None表示结束。
|
|
136
|
+
if process.poll() is not None or process.stdout is None:
|
|
131
137
|
break
|
|
132
138
|
output = process.stdout.readline()
|
|
133
139
|
if output == '':
|
|
@@ -175,12 +181,17 @@ def run_parallel_ut(config):
|
|
|
175
181
|
|
|
176
182
|
try:
|
|
177
183
|
for process in processes:
|
|
178
|
-
process.
|
|
184
|
+
process.wait() # wait仅阻塞,不捕获标准输出和标准错误,原communicate不仅阻塞,而且捕获标准输出和标准错误
|
|
179
185
|
except KeyboardInterrupt:
|
|
180
186
|
logger.warning("Interrupted by user, terminating processes and cleaning up...")
|
|
181
187
|
except Exception as e:
|
|
182
188
|
logger.error(f"An unexpected error occurred: {e}")
|
|
183
189
|
finally:
|
|
190
|
+
# 最后再更新一次进度条,避免因缓存写入等原因子进程结束而进度未刷新的问题
|
|
191
|
+
if wait_for_file_write_complete(config.result_csv_path):
|
|
192
|
+
result_file = read_csv(config.result_csv_path)
|
|
193
|
+
completed_items = len(result_file)
|
|
194
|
+
progress_bar.update(completed_items - progress_bar.n)
|
|
184
195
|
if progress_bar.n < config.total_items:
|
|
185
196
|
logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to " \
|
|
186
197
|
"the result CSV file will be utilized to resume the UT task.")
|
|
@@ -195,6 +206,22 @@ def run_parallel_ut(config):
|
|
|
195
206
|
logger.error(f"An unexpected error occurred: {e}")
|
|
196
207
|
|
|
197
208
|
|
|
209
|
+
def wait_for_file_write_complete(file_path, timeout=3600, poll_interval=0.1):
    """Poll ``file_path`` until its size is the same on two consecutive samples.

    Args:
        file_path: Path of the file to watch.
        timeout: Maximum number of seconds to keep polling.
        poll_interval: Seconds to sleep between two size samples.

    Returns:
        True once two consecutive samples report the same size; False if the
        file cannot be stat'ed (e.g. it does not exist) or the timeout expires.
    """
    # None (not 0) so that at least two real samples are always compared;
    # otherwise an empty, not-yet-flushed file is reported as complete at once.
    last_size = None
    # Monotonic clock so the timeout is not affected by wall-clock adjustments.
    start_time = time.monotonic()
    while True:
        try:
            current_size = os.path.getsize(file_path)
        except OSError as err:
            # Callers may invoke this from a ``finally`` block; raising here would
            # mask the original exception, so report "not complete" instead.
            logger.warning(f"Can not get the size of the result csv file {file_path}: {err}")
            return False
        if current_size == last_size:
            return True
        last_size = current_size
        if time.monotonic() - start_time > timeout:
            logger.error("write the result csv file timeout.")
            return False
        time.sleep(poll_interval)
|
|
223
|
+
|
|
224
|
+
|
|
198
225
|
def prepare_config(args):
|
|
199
226
|
api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
|
|
200
227
|
ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
@@ -203,7 +230,9 @@ def prepare_config(args):
|
|
|
203
230
|
create_directory(out_path)
|
|
204
231
|
out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
|
|
205
232
|
out_path = out_path_checker.common_check()
|
|
206
|
-
split_files, total_items = split_json_file(api_info, args.num_splits,
|
|
233
|
+
split_files, total_items, modified_num_splits = split_json_file(api_info, args.num_splits,
|
|
234
|
+
args.filter_api, args.device_id)
|
|
235
|
+
args.num_splits = modified_num_splits
|
|
207
236
|
config_path = args.config_path if args.config_path else None
|
|
208
237
|
if config_path:
|
|
209
238
|
config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
|
msprobe/pytorch/common/utils.py
CHANGED
|
@@ -39,7 +39,6 @@ except ImportError:
|
|
|
39
39
|
else:
|
|
40
40
|
is_gpu = False
|
|
41
41
|
|
|
42
|
-
|
|
43
42
|
torch_without_guard_version = torch.__version__ >= '2.1'
|
|
44
43
|
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
45
44
|
|
|
@@ -338,56 +337,6 @@ def save_pt(tensor, filepath):
|
|
|
338
337
|
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
339
338
|
|
|
340
339
|
|
|
341
|
-
class TypeCheckingUnpickler(pickle.Unpickler):
|
|
342
|
-
"""
|
|
343
|
-
This class is a subclass of pickle.Unpickler, which is used to unpickle pickled objects.
|
|
344
|
-
It overrides the find_class method to add type checking functionality.
|
|
345
|
-
"""
|
|
346
|
-
allowed_types = [
|
|
347
|
-
"str",
|
|
348
|
-
"ApiData",
|
|
349
|
-
"OrderedDict",
|
|
350
|
-
"_rebuild_tensor_v2", # from torch.utils
|
|
351
|
-
"_load_from_bytes" # from torch.storage
|
|
352
|
-
]
|
|
353
|
-
|
|
354
|
-
def find_class(self, module, name):
|
|
355
|
-
"""
|
|
356
|
-
Method to find the class of the object to be unpickled.
|
|
357
|
-
Throws pickle.UnpicklingError If the object type is not in the allowed types list.
|
|
358
|
-
"""
|
|
359
|
-
if name in self.allowed_types:
|
|
360
|
-
return super().find_class(module, name)
|
|
361
|
-
raise pickle.UnpicklingError("Unsupported object type: {}.{}".format(module, name))
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
def save_pkl(tensor, filepath):
|
|
365
|
-
"""Save ApiData or str objection by pickle"""
|
|
366
|
-
check_path_before_create(filepath)
|
|
367
|
-
filepath = os.path.realpath(filepath)
|
|
368
|
-
try:
|
|
369
|
-
with FileOpen(filepath, 'wb') as f:
|
|
370
|
-
pickle.dump(tensor, f)
|
|
371
|
-
except Exception as e:
|
|
372
|
-
logger.error("Save pt file failed, please check according possible error causes: "
|
|
373
|
-
"1. out of disk space or disk error, "
|
|
374
|
-
"2. no permission to write files, etc.")
|
|
375
|
-
raise RuntimeError(f"save pt file {filepath} failed") from e
|
|
376
|
-
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
def load_pkl(pt_path):
|
|
380
|
-
"""Load ApiData or str objection by pickle for accuracy_checker_online"""
|
|
381
|
-
check_file_or_directory_path(pt_path)
|
|
382
|
-
pt_path = os.path.realpath(pt_path)
|
|
383
|
-
try:
|
|
384
|
-
with FileOpen(pt_path, 'rb') as f:
|
|
385
|
-
pt = TypeCheckingUnpickler(f).load()
|
|
386
|
-
except Exception as e:
|
|
387
|
-
raise RuntimeError(f"load pt file {pt_path} failed: {e}") from e
|
|
388
|
-
return pt
|
|
389
|
-
|
|
390
|
-
|
|
391
340
|
def is_recomputation():
|
|
392
341
|
"""Check if the current operation is in the re-computation phase.
|
|
393
342
|
|
|
@@ -416,7 +365,8 @@ def is_recomputation():
|
|
|
416
365
|
|
|
417
366
|
# Identify indices in the call stack where the specific function is being executed
|
|
418
367
|
for idx, frame_info in enumerate(call_stack):
|
|
419
|
-
if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward'
|
|
368
|
+
if (frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward' and
|
|
369
|
+
"megatron" in frame_info.filename):
|
|
420
370
|
backward_function_indices.append(idx)
|
|
421
371
|
|
|
422
372
|
# Check if the execution is within 'torch/autograd/function.py' file
|
msprobe/pytorch/compare/utils.py
CHANGED
|
@@ -27,8 +27,7 @@ def read_pt_data(dir_path, file_name):
|
|
|
27
27
|
return None
|
|
28
28
|
|
|
29
29
|
data_path = os.path.join(dir_path, file_name)
|
|
30
|
-
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
31
|
-
FileCheckConst.PT_SUFFIX, False)
|
|
30
|
+
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.PT_SUFFIX)
|
|
32
31
|
data_path = path_checker.common_check()
|
|
33
32
|
try:
|
|
34
33
|
# detach because numpy can not process gradient information
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
from functools import wraps
|
|
17
|
+
from typing import Any, Callable
|
|
17
18
|
|
|
18
19
|
import torch
|
|
19
20
|
from torch.utils.hooks import BackwardHook
|
|
@@ -21,6 +22,9 @@ from torch.utils.hooks import BackwardHook
|
|
|
21
22
|
from msprobe.core.common.const import Const
|
|
22
23
|
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
23
24
|
from msprobe.pytorch.common.log import logger
|
|
25
|
+
from msprobe.pytorch.hook_module.api_register import get_api_register
|
|
26
|
+
|
|
27
|
+
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
24
28
|
|
|
25
29
|
|
|
26
30
|
def wrap_setup_backward_hook(func):
|
|
@@ -92,3 +96,23 @@ def wrap_setup_backward_hook(func):
|
|
|
92
96
|
def wrap_setup_input_output_hook():
    """Patch ``BackwardHook.setup_input_hook`` and ``setup_output_hook`` in place.

    Each method is replaced by the result of ``wrap_setup_backward_hook`` applied
    to the current implementation (input hook first, then output hook).
    """
    for method_name in ("setup_input_hook", "setup_output_hook"):
        current_method = getattr(BackwardHook, method_name)
        setattr(BackwardHook, method_name, wrap_setup_backward_hook(current_method))
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def get_apply_func_wrapper(original_func: Callable) -> Callable:
    """Wrap ``original_func`` so it runs between restore/register calls on the api register.

    If ``get_api_register()`` returns a truthy register, ``restore_inner_used_api()``
    is called before ``original_func`` runs and ``register_inner_used_api()`` after it.

    Args:
        original_func: Callable to wrap (used for ``BackwardHookFunction.apply``).

    Returns:
        A ``functools.wraps``-decorated wrapper that forwards all arguments and
        returns whatever ``original_func`` returns.
    """
    @wraps(original_func)
    def wrapped_apply(*args, **kwargs) -> Any:
        api_register = get_api_register()
        if api_register:
            api_register.restore_inner_used_api()
        try:
            return original_func(*args, **kwargs)
        finally:
            # Re-register even when ``original_func`` raises; otherwise the inner
            # APIs would stay restored for the rest of the process.
            if api_register:
                api_register.register_inner_used_api()

    return wrapped_apply
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def wrap_backward_hook_function_apply():
    """Replace ``BackwardHookFunction.apply`` with ``get_apply_func_wrapper(apply)``.

    Does nothing when ``torch_version_above_or_equal_2`` is false.
    """
    if not torch_version_above_or_equal_2:
        return
    hook_function_cls = torch.nn.modules._functions.BackwardHookFunction
    hook_function_cls.apply = get_apply_func_wrapper(hook_function_cls.apply)
|