mindstudio-probe 8.2.1__py3-none-any.whl → 8.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/RECORD +46 -37
- msprobe/README.md +3 -1
- msprobe/core/common/file_utils.py +80 -25
- msprobe/core/common/framework_adapter.py +7 -6
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +4 -16
- msprobe/core/compare/find_first/utils.py +1 -1
- msprobe/core/hook_manager.py +16 -3
- msprobe/core/service.py +16 -5
- msprobe/docs/02.config_introduction.md +14 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +295 -0
- msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +3 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/mindspore/compare/utils.py +1 -2
- msprobe/msprobe.py +6 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +36 -3
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +24 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +12 -2
- msprobe/pytorch/api_accuracy_checker/config.yaml +6 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +132 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +205 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +378 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +239 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +250 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +198 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/common/utils.py +22 -2
- msprobe/pytorch/compare/utils.py +1 -2
- msprobe/pytorch/debugger/debugger_config.py +10 -0
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +24 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +9 -3
- msprobe/pytorch/hook_module/api_register.py +6 -1
- msprobe/pytorch/pt_config.py +57 -2
- msprobe/pytorch/pytorch_service.py +11 -2
- msprobe/visualization/builder/graph_builder.py +1 -0
- msprobe/visualization/utils.py +11 -1
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +0 -3
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.2.1.dist-info → mindstudio_probe-8.3.0.dist-info}/top_level.txt +0 -0
msprobe/core/service.py
CHANGED
|
@@ -35,6 +35,7 @@ class BaseService(ABC):
|
|
|
35
35
|
self.config.level = getattr(config, 'level_ori', config.level) # 兼容MindSpore配置
|
|
36
36
|
self.model = None
|
|
37
37
|
self.data_collector = build_data_collector(self.config)
|
|
38
|
+
self.attl_manager = None
|
|
38
39
|
self.current_iter = 0
|
|
39
40
|
self.loop = 0
|
|
40
41
|
self.init_step = 0
|
|
@@ -90,6 +91,10 @@ class BaseService(ABC):
|
|
|
90
91
|
self.config.task in self.data_collector.tasks_need_tensor_data or
|
|
91
92
|
(self.config.task == Const.STATISTICS and self.config.tensor_list)
|
|
92
93
|
)
|
|
94
|
+
|
|
95
|
+
@property
|
|
96
|
+
def _is_online_run_ut(self):
|
|
97
|
+
return getattr(self.config, "online_run_ut", False)
|
|
93
98
|
|
|
94
99
|
@property
|
|
95
100
|
@abstractmethod
|
|
@@ -141,9 +146,11 @@ class BaseService(ABC):
|
|
|
141
146
|
self.primitive_switch = True
|
|
142
147
|
self._change_jit_switch(True)
|
|
143
148
|
self.logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
149
|
+
if self._is_online_run_ut:
|
|
150
|
+
self._run_ut_dispatch(True)
|
|
151
|
+
else:
|
|
152
|
+
self.create_dirs()
|
|
153
|
+
self.logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
|
|
147
154
|
|
|
148
155
|
def stop(self):
|
|
149
156
|
"""通用stop模板"""
|
|
@@ -158,7 +165,8 @@ class BaseService(ABC):
|
|
|
158
165
|
self._change_jit_switch(False)
|
|
159
166
|
if self._is_l2_level:
|
|
160
167
|
return
|
|
161
|
-
|
|
168
|
+
if self._is_online_run_ut:
|
|
169
|
+
self._run_ut_dispatch(False)
|
|
162
170
|
self._process_async_dump()
|
|
163
171
|
self.data_collector.write_json()
|
|
164
172
|
|
|
@@ -258,6 +266,8 @@ class BaseService(ABC):
|
|
|
258
266
|
end_service = self.config.step and self.current_iter > max(self.config.step) or \
|
|
259
267
|
self.data_collector and self.data_collector.data_processor.is_terminated
|
|
260
268
|
if end_service:
|
|
269
|
+
if self._is_online_run_ut and self.attl_manager:
|
|
270
|
+
self.attl_manager.attl_stop()
|
|
261
271
|
self.primitive_switch = False
|
|
262
272
|
self._change_jit_switch(False)
|
|
263
273
|
Runtime.is_running = False
|
|
@@ -300,7 +310,8 @@ class BaseService(ABC):
|
|
|
300
310
|
if root_model and isinstance(root_model, list):
|
|
301
311
|
root_model = root_model[0]
|
|
302
312
|
self.logger.warning("Infer model can only input one to support token_range, choose the first one.")
|
|
303
|
-
|
|
313
|
+
if self._is_online_run_ut:
|
|
314
|
+
return
|
|
304
315
|
root_model.register_forward_pre_hook(infer_hook)
|
|
305
316
|
|
|
306
317
|
def _create_l2_dirs(self, cur_rank):
|
|
@@ -73,7 +73,14 @@
|
|
|
73
73
|
| data_mode | 与[ 1.2 task 配置为 statistics ](#12-task-配置为-statistics)中的解释相同 | 否 |
|
|
74
74
|
| file_format | tensor 数据的保存格式,str 类型,仅支持 MindSpore 静态图场景的 L2 级别配置该字段,其他场景不生效。可选参数:<br/> "bin":dump 的 tensor 文件为二进制格式;<br/>"npy":dump 的 tensor 文件后缀为 .npy,默认值。 | 否 |
|
|
75
75
|
| summary_mode | 控制 dump 文件输出的模式,str 类型,支持 PyTorch、MSAdapter、MindSpore 动态图。可选参数:<br/> md5:dump 输出包含 CRC-32 值以及 API 统计信息的 dump.json 文件,用于验证数据的完整性;<br/> statistics:dump 仅输出包含 API 统计信息的 dump.json 文件,默认值。| 否 |
|
|
76
|
+
| online_run_ut<sup>a</sup> | 在线预检模式开关,bool 类型,可选参数 true(开启)、false(关闭),默认未配置,表示关闭。配置为 true 表示开启在线预检。| 否 |
|
|
77
|
+
| nfs_path<sup>a</sup> | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。仅在 online_run_ut 字段配置为 true 时生效,配置该参数后 host 和 port 不生效。 | 否 |
|
|
78
|
+
| host<sup>a</sup> | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的局域网 IP 地址。仅在 online_run_ut 字段配置为 true 时生效,局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
|
|
79
|
+
| port<sup>a</sup> | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的端口号。仅在 online_run_ut 字段配置为 true 时生效,局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。| 否 |
|
|
76
80
|
|
|
81
|
+
**说明**:
|
|
82
|
+
|
|
83
|
+
1. online_run_ut、nfs_path、host、port 等字段仅在线预检场景 NPU 机器生效。
|
|
77
84
|
|
|
78
85
|
**示例**:
|
|
79
86
|
- [PyTorch场景](03.config_examples.md#12-task-配置为-tensor)
|
|
@@ -88,11 +95,17 @@
|
|
|
88
95
|
| white_list<sup>a</sup> | API dump 白名单,仅对指定的 API 进行 dump。<br/>**配置示例**:"white_list": ["conv1d", "conv2d"]。默认未配置白名单,即 dump 全量 API 数据。 | 否 |
|
|
89
96
|
| black_list<sup>a</sup> | API dump 黑名单,被指定的 API 不进行 dump。<br/>**配置示例**:"black_list": ["conv1d", "conv2d"]。默认未配置黑名单,即 dump 全量 API 数据。 | 否 |
|
|
90
97
|
| error_data_path | 配置保存精度未达标的 API 输入输出数据路径,默认为当前路径。<br/>**配置示例**:"error_data_path": "./"。 | 否 |
|
|
98
|
+
| is_online<sup>b</sup> | 在线预检模式开关,bool 类型,可选参数 true(开启)、false(关闭),默认关闭。 | 否 |
|
|
99
|
+
| nfs_path<sup>b</sup> | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。配置该参数后 host 和 port 不生效,仅在 is_online 字段配置为 true 时生效。 | 否 |
|
|
100
|
+
| host<sup>b</sup> | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机地址 127.0.0.1 或本机局域网 IP。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。仅在 is_online 字段配置为 true 时生效。 | 否 |
|
|
101
|
+
| port<sup>b</sup> | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机可用端口。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。仅在 is_online 字段配置为 true 时生效。| 否 |
|
|
102
|
+
| rank_list<sup>b</sup> | 指定在线预检的 Rank ID,默认值为 [0],list[int] 类型,应配置为大于等于 0 的整数,且须根据实际卡的 Rank ID 配置,若所配置的值大于实际训练所运行的卡的 Rank ID,则在线预检输出数据为空。GPU 和 NPU 须配置一致。仅在 is_online 字段配置为 true 时生效。 | 否 |
|
|
91
103
|
|
|
92
104
|
**说明**:
|
|
93
105
|
|
|
94
106
|
1. white_list 和 black_list 同时配置时,二者配置的 API 名单若无交集,则白名单生效,若 API 名单存在交集,则白名单排除的部分以及交集的 API 不进行 dump。
|
|
95
107
|
|
|
108
|
+
2. is_online、nfs_path、host、port、rank_list 等字段仅在线预检场景 GPU 机器生效。
|
|
96
109
|
|
|
97
110
|
**示例**:
|
|
98
111
|
```json
|
|
@@ -145,7 +158,7 @@ PyTorch、MSAdapter 以及 MindSpore 动态图场景下,"level"须为"L0"或"L
|
|
|
145
158
|
<tr><td>pert_mode</td><td>无标杆扰动因子,str 类型。可选参数:<br/> "improve_precision":对输入做升精度,默认值;<br/> "add_noise":对输入增加噪声;<br/> "no_change":不加扰动直接二次执行;<br/> "bit_noise":输入的末位比特翻转,MindSpore 场景不支持 BF16 类型的向量;<br/> "change_value":输入的张量首尾值调换;<br/> "to_cpu":在 CPU 等价执行(仅 PyTorch 场景支持)。<br/><b>配置示例</b>:"pert_mode": "improve_precision"。</td><td>否</td></tr>
|
|
146
159
|
<tr><td>handler_type</td><td>处理类型,可选参数:<br/> "check":进行无标杆比对检查,默认值;<br/> "fix":将扰动后的 API 输出结果覆盖原始 API 输出结果,尝试将 Loss 曲线恢复正常,该模式下不支持预热功能与反向过程,且仅支持"improve_precision"、"to_cpu"( PyTorch 场景)两种扰动因子。<br/> <b>配置示例</b>:"handler_type": "check"。</td><td>否</td></tr>
|
|
147
160
|
<tr><td>fuzz_level</td><td>无标杆数据 dump 级别,即选择比对结果文件应输出的表头属性,当前仅支持取值为:"L1"。输出结果详见 <a href="#161-无标杆比对数据存盘格式">1.6.1 无标杆比对数据存盘格式</a>。</td><td>否</td></tr>
|
|
148
|
-
<tr><td>fuzz_stage</td><td>比对过程,选择对 API 前向或反向进行无标杆比对,可选参数:<br/> "forward":前向,默认值;<br/> "backward":反向。当 fuzz_stage 为 "backward" 时,handler_type 只能为 "check"
|
|
161
|
+
<tr><td>fuzz_stage</td><td>比对过程,选择对 API 前向或反向进行无标杆比对,可选参数:<br/> "forward":前向,默认值;<br/> "backward":反向。当 fuzz_stage 为 "backward" 时,handler_type 只能为 "check"。pytorch场景下,当 fuzz_stage 为 "backward" 时, list 参数不能为空。<br/> <b>配置示例</b>:"fuzz_stage": "backward"。</td><td>否</td></tr>
|
|
149
162
|
<tr><td>if_preheat</td><td>预热功能(仅 PyTorch 场景支持),bool 类型。开启功能后工具可以根据每次迭代的输出调整精度算法的阈值,从而更准确地找出存在精度问题的 API。当"handler_type": "fix"时,不支持预热。可选参数:<br/> true(开启)或 false(关闭),默认关闭。<br/> <b>配置示例</b>:"if_preheat": "true"。</td><td>否</td></tr>
|
|
150
163
|
<tr><td>preheat_step</td><td>开启预热的迭代数量(仅 PyTorch 场景支持),int 类型,默认值为 15。须配置 "if_preheat": "true"。</td><td>否</td></tr>
|
|
151
164
|
<tr><td>max_sample</td><td>每个算子预热的采样次数的最大阈值(仅 PyTorch 场景支持),int 类型,默认值为 20。须配置 "if_preheat": "true"。</td><td>否</td></tr>
|
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
# PyTorch 场景的在线精度预检
|
|
2
|
+
|
|
3
|
+
## 1 简介
|
|
4
|
+
|
|
5
|
+
为了应对大模型场景下,通过离线预检方式 dump API 输入输出数据导致的存储资源紧张问题,提供在线精度预检功能。本功能实现在执行 NPU 训练操作的过程中,通过 TCP/IP 协议在 NPU
|
|
6
|
+
Host 与 GPU Host 设备间建立连接,将 NPU 上对应 API 的输入数据在 GPU 设备上运行,将两份输出数据进行比对,得到预检比对结果,从而减少数据 dump 的步骤,降低存储资源的占用。针对偏差较大的算子,两方比对(NPU vs. GPU)的方法缺少裁判进行裁定。 参考离线预检,在线预检场景同时支持两方比对和三方比对方式,按照 api 的精度标准要求,选择两方比对或三方比对。
|
|
7
|
+
|
|
8
|
+
## 2 在线精度预检流程
|
|
9
|
+
|
|
10
|
+
在线精度预检当前支持**局域网场景**和**共享存储场景**,请根据不同的场景选择对应的配置。
|
|
11
|
+
|
|
12
|
+
在线精度预检操作流程如下:
|
|
13
|
+
|
|
14
|
+
1. 准备 GPU 和 NPU 可正常运行的训练环境,PyTorch 版本大于等于2.0,并保证两台 Host 在同一局域网内可正常通信或能通过共享存储进行通信。
|
|
15
|
+
2. GPU 和 NPU Host 设备上同时安装msprobe工具,详见[ msprobe 安装](./01.installation.md)章节,其中在线预检要安装 twisted、pyOpenSSL,这些包为 Python 模块。
|
|
16
|
+
3. 分别配置 GPU 侧、NPU 侧的 config.json 文件。
|
|
17
|
+
4. 在 GPU 侧运行 `msprobe -f pytorch run_ut -config ./config.json`。
|
|
18
|
+
5. 在 NPU 侧配置训练脚本。
|
|
19
|
+
6. 在 NPU 侧执行训练。
|
|
20
|
+
|
|
21
|
+
## 3 在线精度预检操作指导
|
|
22
|
+
|
|
23
|
+
### 3.1 配置 config.json 文件
|
|
24
|
+
|
|
25
|
+
预检工具安装完成后,需要在 GPU 和 NPU 环境下分别配置 config.json。其中需要重点关注文件中的 is_online、is_benchmark_device、host 和 port 参数的配置,保障在线预检时 GPU 和 NPU 两台设备间的通信正常。
|
|
26
|
+
|
|
27
|
+
#### 3.1.1 GPU 侧在线预检配置说明
|
|
28
|
+
|
|
29
|
+
| 参数名称 | 说明 | 是否必选 |
|
|
30
|
+
|-----------------|--------------|------|
|
|
31
|
+
| task | 任务名称,str 类型,配置为 run_ut 表示预检任务。通过其他字段 is_online 判断离线预检、在线预检任务。 | 是 |
|
|
32
|
+
| white_list | 预检的 API 白名单,list[str] 类型。<br/>**配置示例**:white_list=["conv1d", "conv2d"]。默认未配置白名单,即预检全量 API 数据。 | 否 |
|
|
33
|
+
| black_list | 预检的 API 黑名单,list[str] 类型。<br/>**配置示例**:white_list=["conv1d", "conv2d"]。默认未配置黑名单,即预检全量 API 数据。 | 否 |
|
|
34
|
+
| error_data_path | 配置保存精度未达标的 API 输入输出数据路径,str 类型。在线预检模式下该参数不生效。 | 否 |
|
|
35
|
+
| is_online | 在线预检模式开关,bool 类型,可取值 True(开启)、False(关闭),默认关闭。 | 是 |
|
|
36
|
+
| nfs_path | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。配置该参数后 host、port 和 tls_path 不生效。 | 否 |
|
|
37
|
+
| host | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机地址 127.0.0.1 或本机局域网 IP。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
|
|
38
|
+
| port | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机可用端口。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
|
|
39
|
+
| rank_list | 指定在线预检的 Rank ID,默认值为 [0],list[int] 类型,应配置为大于等于 0 的整数,且须根据实际卡的 Rank ID 配置,若所配置的值大于实际训练所运行的卡的 Rank ID,则在线预检输出数据为空。GPU 和 NPU 须配置一致。 | 是 |
|
|
40
|
+
| tls_path | 在线预检模式局域网场景 SSL 证书路径,该路径下包含私钥 server.key、证书 server.crt、自建CA证书 ca.crt、CRL吊销证书 crl.pem,str 类型,未配置该参数时默认取值当前路径。tls_path配置为空字符串时,采用TCP协议明文传输api数据;当配置为路径时,采用TLS1.2协议加密传输数据,加密传输时安全性较高,传输速率较低。其中 crl.pem 为非必需文件,仅当用户存在吊销记录时使用。 | 否 |
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
#### 3.1.2 NPU 侧在线预检配置说明
|
|
44
|
+
|
|
45
|
+
| 参数名称 | 说明 | 是否必选 |
|
|
46
|
+
|------------------|-------------|------|
|
|
47
|
+
| task | 任务名称,str 类型,配置为 tensor 表示 dump API 统计信息和完全复刻整网的 API 运行情况的真实数据。通过字段 online_run_ut 判断是否使用在线预检功能。 | 是 |
|
|
48
|
+
| dump_path | dump 路径,str 类型,配置为合法路径即可,兼容 tensor 任务静态检查。 | 是 |
|
|
49
|
+
| level | dump 级别,str 类型,在线预检时配置为 L1,表示 dump API 级精度数据。在线预检可不配置,默认取值 L1。 | 是 |
|
|
50
|
+
| rank | 指定对某张卡上的数据进行 dump,list[int] 类型,默认未配置(表示 dump所有卡的数据),需要与 GPU 侧配置项 rank_list 保持一致。 | 否 |
|
|
51
|
+
| step | 指定 dump 某个 step 的数据,list[int] 类型,默认未配置,表示 dump 所有 step 的数据。dump 特定 step 时,须指定为训练脚本中存在的 step。 | 否 |
|
|
52
|
+
| scope | dump 范围,list[str] 类型,默认未配置(list 也未配置时表示 dump 所有 api 的数据),配置方式参考 [config.json 配置介绍](./02.config_introduction.md)。 | 否 |
|
|
53
|
+
| list | dump 范围,list[str] 类型,默认未配置(scope 也未配置时表示 dump 所有 api 的数据),配置方式参考 [config.json 配置介绍](./02.config_introduction.md)。 | 否 |
|
|
54
|
+
| online_run_ut | 在线预检模式开关,bool 类型,可取值 True(开启)、False(关闭),默认关闭。 | 是 |
|
|
55
|
+
| nfs_path | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。配置该参数后 host 和 port 不生效。 | 否 |
|
|
56
|
+
| host | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的局域网 IP 地址。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
|
|
57
|
+
| port | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的端口号。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
|
|
58
|
+
| tls_path | 在线预检模式局域网场景 SSL 证书路径,该路径下包含私钥 client.key、证书 client.crt、自建CA证书 ca.crt、CRL吊销证书 crl.pem,str 类型,未配置该参数时默认取值当前路径。tls_path配置为空字符串时,采用TCP协议明文传输api数据;当配置为路径时,采用TLS1.2协议加密传输数据,加密传输时安全性较高,传输速率较低。其中 crl.pem 为非必需文件,仅当用户存在吊销记录时使用。 | 否 |
|
|
59
|
+
| online_run_ut_recompute | 模型训练是否使用重计算机制,bool类型,默认为False,表示模型没有使用重计算。在线预检暂不支持重计算机制下反向算子的预检,当模型训练使用重计算时,跳过反向算子预检,默认模型关闭重计算。 | 否 |
|
|
60
|
+
|
|
61
|
+
#### 3.1.3 局域网场景配置示例
|
|
62
|
+
|
|
63
|
+
若采用 TLS1.2 协议加密传输 api 数据,需配置 SSL 证书,可参考如下生成自签名证书方法。
|
|
64
|
+
|
|
65
|
+
以下秘钥生成方法仅为简单示例,客户应使用与自己需求相符的秘钥生成和存储机制并保证秘钥安全性与机密性,必要时可采用分层秘钥机制。
|
|
66
|
+
以下示例中加密口令仅供参考,使用时请更换为复杂口令,并保护口令安全。
|
|
67
|
+
```shell
|
|
68
|
+
# 生成CA证书的根私钥和证书签名请求,其中ca_password为CA私钥加密口令,仅作演示,请更换使用
|
|
69
|
+
openssl req -new -newkey rsa:3072 -passout pass:ca_password -subj "/CN=*ca.com/O=ca.Inc./C=CN/ST=Zhejiang/L=Hangzhou" -keyout ca.key -out ca.csr
|
|
70
|
+
# 自签发根证书
|
|
71
|
+
openssl x509 -req -days 365 -in ca.csr -signkey ca.key -passin pass:ca_password -out ca.crt -extensions v3_ca -extfile <(cat <<-EOF
|
|
72
|
+
[v3_ca]
|
|
73
|
+
basicConstraints = critical,CA:true
|
|
74
|
+
keyUsage = critical, keyCertSign, cRLSign
|
|
75
|
+
EOF
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
# 生成client公私钥,其中client_password为私钥加密口令,仅作演示,请更换使用
|
|
79
|
+
openssl genrsa -aes256 -passout pass:client_password -out client.key 3072
|
|
80
|
+
# 基于client公私钥生成签名请求
|
|
81
|
+
openssl req -new -key client.key -passin pass:client_password -subj "/CN=*example.com/O=Test, Inc./C=CN/ST=Zhejiang/L=Hangzhou" -out client.csr
|
|
82
|
+
# 利用自签发的根证书,签发client证书
|
|
83
|
+
openssl x509 -req -days 180 -CA ca.crt -CAkey ca.key -passin pass:ca_password -in client.csr -out client.crt -CAcreateserial -extfile <(cat <<-EOF
|
|
84
|
+
[v3_server]
|
|
85
|
+
basicConstraints = CA:FALSE
|
|
86
|
+
keyUsage = critical, digitalSignature, keyEncipherment
|
|
87
|
+
extendedKeyUsage = serverAuth
|
|
88
|
+
EOF
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
# 生成server公私钥,其中server_password为私钥加密口令,仅作演示,请更换使用
|
|
92
|
+
openssl genrsa -aes256 -passout pass:server_password -out server.key 3072
|
|
93
|
+
# 基于server公私钥生成签名请求
|
|
94
|
+
openssl req -new -key server.key -passin pass:server_password -subj "/CN=*example.com/O=Test, Inc./C=CN/ST=Zhejiang/L=Hangzhou" -out server.csr
|
|
95
|
+
# 利用自签发的根证书,签发server证书
|
|
96
|
+
openssl x509 -req -days 180 -CA ca.crt -CAkey ca.key -passin pass:ca_password -in server.csr -out server.crt -CAcreateserial -extfile <(cat <<-EOF
|
|
97
|
+
[v3_server]
|
|
98
|
+
basicConstraints = CA:FALSE
|
|
99
|
+
keyUsage = critical, digitalSignature, keyEncipherment
|
|
100
|
+
extendedKeyUsage = serverAuth
|
|
101
|
+
EOF
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
```
|
|
105
|
+
|
|
106
|
+
当需要吊销已创建的SSL证书时,通过openssl命令生成CRL证书 crl.pem,示例如下:
|
|
107
|
+
```shell
|
|
108
|
+
# 创建证书信息的文本数据库,空文件即可
|
|
109
|
+
touch index.txt
|
|
110
|
+
|
|
111
|
+
# 创建ca配置文件ca.cnf,内容如下,用于吊销证书使用
|
|
112
|
+
[ca]
|
|
113
|
+
default_ca = CA_default
|
|
114
|
+
[CA_default]
|
|
115
|
+
database = ./index.txt
|
|
116
|
+
default_md = sha256
|
|
117
|
+
|
|
118
|
+
# 吊销证书 client.crt,其中ca_password为CA私钥加密口令,与CA创建时保持一致
|
|
119
|
+
openssl ca -revoke client.crt -config ca.cnf -cert ca.crt -keyfile ca.key -passin pass:ca_password
|
|
120
|
+
# 生成CRL文件
|
|
121
|
+
openssl ca -gencrl -config ca.cnf -cert ca.crt -keyfile ca.key -passin pass:ca_password -out crl.pem -crldays 30
|
|
122
|
+
# 查看生成的CRL文件内容:
|
|
123
|
+
openssl工具的命令: openssl crl -inform PEM -in crl.pem -text
|
|
124
|
+
|
|
125
|
+
```
|
|
126
|
+
|
|
127
|
+
注意:配置TLS协议时,传输性能受机器环境和网络质量的影响,可能触发NPU超时中断模型训练,为避免训练和预检中断,丢弃长时间未传输的api数据,同时NPU侧配置HCCL环境变量,配置方式如下:
|
|
128
|
+
|
|
129
|
+
a) 调整HCCL环境变量,关闭看门狗,避免WorkHCCL超时中断模型训练:
|
|
130
|
+
```shell
|
|
131
|
+
export HCCL_DESYNC_DEBUG=0
|
|
132
|
+
export HCCL_ASYNC_ERROR_HANDLING=0
|
|
133
|
+
```
|
|
134
|
+
b) 调整通信算子超时设置(以1800s举例):
|
|
135
|
+
```shell
|
|
136
|
+
export HCCL_CONNECT_TIMEOUT=1800
|
|
137
|
+
export HCCL_EXEC_TIMEOUT=1800
|
|
138
|
+
```
|
|
139
|
+
|
|
140
|
+
GPU 侧:
|
|
141
|
+
|
|
142
|
+
```json
|
|
143
|
+
{
|
|
144
|
+
"task": "run_ut",
|
|
145
|
+
"run_ut": {
|
|
146
|
+
"white_list": [],
|
|
147
|
+
"black_list": [],
|
|
148
|
+
"error_data_path": "./",
|
|
149
|
+
"is_online": true,
|
|
150
|
+
"nfs_path": "",
|
|
151
|
+
"host": "127.0.0.1",
|
|
152
|
+
"port": 59208,
|
|
153
|
+
"rank_list": [0],
|
|
154
|
+
"tls_path": ""
|
|
155
|
+
}
|
|
156
|
+
}
|
|
157
|
+
```
|
|
158
|
+
|
|
159
|
+
NPU 侧:
|
|
160
|
+
|
|
161
|
+
```json
|
|
162
|
+
{
|
|
163
|
+
"task": "tensor",
|
|
164
|
+
"dump_path": "./dump_path",
|
|
165
|
+
"rank": [0],
|
|
166
|
+
"step": [0],
|
|
167
|
+
"level": "L1",
|
|
168
|
+
"tensor": {
|
|
169
|
+
"scope": [],
|
|
170
|
+
"list": [],
|
|
171
|
+
"online_run_ut": true,
|
|
172
|
+
"nfs_path": "",
|
|
173
|
+
"host": "xx.xx.xx.x",
|
|
174
|
+
"port": 59208,
|
|
175
|
+
"tls_path": ""
|
|
176
|
+
}
|
|
177
|
+
}
|
|
178
|
+
```
|
|
179
|
+
|
|
180
|
+
#### 3.1.4 共享存储场景配置示例
|
|
181
|
+
|
|
182
|
+
GPU 侧:
|
|
183
|
+
|
|
184
|
+
```json
|
|
185
|
+
{
|
|
186
|
+
"task": "run_ut",
|
|
187
|
+
"run_ut": {
|
|
188
|
+
"white_list": [],
|
|
189
|
+
"black_list": [],
|
|
190
|
+
"error_data_path": "./",
|
|
191
|
+
"is_online": true,
|
|
192
|
+
"nfs_path": "/nfs/xxx/data",
|
|
193
|
+
"host": "",
|
|
194
|
+
"port": -1,
|
|
195
|
+
"rank_list": [0],
|
|
196
|
+
"tls_path": ""
|
|
197
|
+
}
|
|
198
|
+
}
|
|
199
|
+
```
|
|
200
|
+
|
|
201
|
+
NPU 侧:
|
|
202
|
+
|
|
203
|
+
```json
|
|
204
|
+
{
|
|
205
|
+
"task": "tensor",
|
|
206
|
+
"dump_path": "./dump_path",
|
|
207
|
+
"rank": [0],
|
|
208
|
+
"step": [0],
|
|
209
|
+
"level": "L1",
|
|
210
|
+
"tensor": {
|
|
211
|
+
"scope": [],
|
|
212
|
+
"list": [],
|
|
213
|
+
"online_run_ut": true,
|
|
214
|
+
"nfs_path": "/nfs/xxx/data",
|
|
215
|
+
"host": "",
|
|
216
|
+
"port": -1,
|
|
217
|
+
"tls_path": ""
|
|
218
|
+
}
|
|
219
|
+
}
|
|
220
|
+
```
|
|
221
|
+
|
|
222
|
+
### 3.2 在 GPU 侧运行 run_ut
|
|
223
|
+
|
|
224
|
+
由于 GPU 侧为通信接收端,需先于 NPU 侧执行 run_ut 操作,命令如下:
|
|
225
|
+
|
|
226
|
+
```bash
|
|
227
|
+
msprobe -f pytorch run_ut -config ./config.json
|
|
228
|
+
```
|
|
229
|
+
|
|
230
|
+
GPU 侧配置好 config.json 文件后执行 run_ut 命令,此时 GPU 处于预检等待状态:
|
|
231
|
+
|
|
232
|
+
- 局域网场景:当 NPU 侧启动训练后将预检的 API 输入和输出数据发送到 GPU 侧时,GPU 启动预检操作。
|
|
233
|
+
- 共享存储场景:当 NPU 侧启动训练后将预检的 API 输入和输出数据发送到共享存储时,GPU 启动预检操作。
|
|
234
|
+
|
|
235
|
+
### 3.3 在 NPU 侧配置训练脚本
|
|
236
|
+
|
|
237
|
+
在 NPU 训练脚本中添加如下代码以获取 run_ut 操作的预检 API 输入和输出数据:
|
|
238
|
+
|
|
239
|
+
```python
|
|
240
|
+
from msprobe.pytorch import PrecisionDebugger
|
|
241
|
+
|
|
242
|
+
debugger = PrecisionDebugger("config.json")
|
|
243
|
+
...
|
|
244
|
+
|
|
245
|
+
debugger.start()
|
|
246
|
+
|
|
247
|
+
...
|
|
248
|
+
|
|
249
|
+
debugger.stop()
|
|
250
|
+
debugger.step()
|
|
251
|
+
```
|
|
252
|
+
|
|
253
|
+
### 3.4 在 NPU 侧执行训练脚本
|
|
254
|
+
|
|
255
|
+
配置完 NPU 侧训练脚本后即可执行训练脚本,命令示例如下:
|
|
256
|
+
|
|
257
|
+
```bash
|
|
258
|
+
bash train.sh
|
|
259
|
+
```
|
|
260
|
+
|
|
261
|
+
训练脚本执行完毕后,在GPU侧dump_path目录下生成比对结果文件,`accuracy_checking_result_{timestamp}_rank{rank_id}.csv`和`accuracy_checking_details_{timestamp}_rank{rank_id}.csv`记录两方比对结果,`api_precision_compare_result_{timestamp}_rank{rank_id}.csv`和`api_precision_compare_details_{timestamp}_rank{rank_id}.csv`记录三方比对结果。详细介绍请参见[离线精度预检中的 **4 预检结果**](./07.accuracy_checker_PyTorch.md#4-预检结果)。
|
|
262
|
+
|
|
263
|
+
## 4 支持的融合算子列表
|
|
264
|
+
|
|
265
|
+
预检工具当前支持的融合算子如下:
|
|
266
|
+
|
|
267
|
+
- npu_apply_adam_w
|
|
268
|
+
|
|
269
|
+
- npu_confusion_transpose
|
|
270
|
+
|
|
271
|
+
- fast_gelu
|
|
272
|
+
|
|
273
|
+
- npu_layer_norm_eval
|
|
274
|
+
|
|
275
|
+
- npu_linear
|
|
276
|
+
|
|
277
|
+
- npu_fusion_attention(该算子在 GPU 上预检时,需要额外安装 flash_attn,请用户自行安装。)
|
|
278
|
+
|
|
279
|
+
- npu_rms_norm
|
|
280
|
+
|
|
281
|
+
- npu_rotary_mul
|
|
282
|
+
|
|
283
|
+
- npu_scaled_masked_softmax
|
|
284
|
+
|
|
285
|
+
- npu_swiglu
|
|
286
|
+
|
|
287
|
+
- npu_apply_adam
|
|
288
|
+
|
|
289
|
+
- npu_group_norm_silu
|
|
290
|
+
|
|
291
|
+
- npu_mish
|
|
292
|
+
|
|
293
|
+
- npu_moe_gating_top_k_softmax
|
|
294
|
+
|
|
295
|
+
- npu_sort_v2
|
|
@@ -87,7 +87,7 @@ D-->config.json配置
|
|
|
87
87
|
<tr><td>scope</td><td>否</td><td>自定义</td><td>需要通过指定算子名来限制算子插桩范围 如:["Torch.matmul.0.forward", "Tensor.pow.4.forward"]。</td></tr>
|
|
88
88
|
<tr><td>list</td><td>否</td><td>自定义</td><td>需要通过指定算子类型来限制算子插桩范围 如:["relu"] 会匹配所有算子名中包含relu的算子。</td></tr>
|
|
89
89
|
<tr><td rowspan="2">fuzz_stage</td><td rowspan="2">否</td><td>"forward"(默认)</td><td>需要进行算子<b>前向</b>计算的精度问题排查或<b>验证可疑算子。</b></td></tr>
|
|
90
|
-
<tr><td>"backward"</td><td>需要进行算子<b>反向</b
|
|
90
|
+
<tr><td>"backward"</td><td>需要进行算子<b>反向</b>计算的精度问题排查,不支持仅反向验证,前向验证包括反向。必须设置list参数指定需要检测的算子(详见3.2 config.json配置章节), 指定的算子会暂存前向激活值,增加内存的占用。</td><td></td></tr>
|
|
91
91
|
</table>
|
|
92
92
|
|
|
93
93
|
#### 3.2.2 选择扰动因子
|
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
| [数据采集<br>(dump)](./05.data_dump_PyTorch.md) | 采集模型训练过程中的API或Module层级的前反向输入输出数据,包括层次关系、统计值信息、真实数据和调用栈等。 | 1、将模型中训练的API或Module的前反向输入输出数据保存下来分析<br> 2、模型出现溢出时,可用于查看哪些API或Module出现了溢出 | 1、API级数据采集仅支持白名单列表上的API<br>2、工具会做一些同步操作,引入工具可能会导致一些同步问题消失<br>3、当前对inplace操作API或Module的支持度有限<br>4、暂不支持参数及参数梯度的采集 |
|
|
8
8
|
| [离线预检<br>(api_accuracy_checker)](./07.accuracy_checker_PyTorch.md) | 为网络中每个API创建用例,检验其精度,并根据不同比对算法综合判定API在NPU上的精度是否达标,快速找出精度差异API。 | 1、对模型中所有的API做精度初步排查<br>2、精度排查不受模型累计误差影响 | 1、依赖GPU环境<br>2、不支持通信算子<br>3、仅支持部分融合算子 |
|
|
9
9
|
| [整网比对<br>(compare)](./10.accuracy_compare_PyTorch.md) | 计算模型整网NPU和标杆设备的精度误差指标,标记精度异常API或Module,助力快速定位精度问题根因。 | 1、整网比对定位精度可疑算子 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响<br>2、当模型规模较大时,比对所需时间较长 |
|
|
10
|
+
| [在线预检<br>(online_api_accuracy_checker)](./08.accuracy_checker_online_PyTorch.md) | 通过TCP通信或共享存储空间的方式,进行在线精度预检,解决离线预检大数据量落盘、传输困难痛点。 | 1、使用离线预检,数据量较大落盘困难或传输耗时长时,可通过在线预检进行精度排查 | 1、依赖GPU环境,NPU和GPU能够通信<br>2、重计算模式下,不支持反向aten算子预检 |
|
|
10
11
|
| [溢出检查<br>(overflow_checker)](./12.overflow_check_PyTorch.md) | 检测模型计算过程的输入输出,并在溢出时落盘数据,助力用户快速定位溢出位置。 | 1、当模型出现溢出时,用于快速定位最先溢出的API或Module<br>2、相比数据采集,性能更优,磁盘压力更小 | 1、局限性同数据采集 |
|
|
11
12
|
| [数据解析<br>(parse_tool)](./14.data_parse_PyTorch.md) | 交互式界面处理解析kernel层级dump数据,便于查看分析。 | 1、比对kernel层级dump数据的一致性 | 1、仅限于NPU |
|
|
12
13
|
| [无标杆比对<br>(free_benchmark)](./15.free_benchmarking_PyTorch.md) | 不依赖标杆数据,通过对算子输入增加微小扰动,计算扰动后输出与原始输出的相对误差,识别有精度风险算子。 | 1、无标杆数据场景下的算子精度排查<br>2、对个别算子进行升精度、“to cpu”等操作,以验证其对模型loss的影响 | 1、由于需要拷贝输入进行二次执行,所以在遇到大张量的输入时容易发生显存OOM的问题, 特别是反向比对过程。建议结合白名单使用<br>2、比对会延长训练时间,整网比对可能会造成严重的耗时膨胀,建议结合白名单使用 |
|
|
@@ -6,9 +6,9 @@
|
|
|
6
6
|
|
|
7
7
|
| 采集模式 | 无工具 (耗时) | 加工具但未使能 Dump (耗时) | 加工具并使能 Dump (耗时) | 加工具并使能 Md5 Dump (耗时) |
|
|
8
8
|
|:--------:|:--------:|:-------------------:|:--------------------:|:--------------------:|
|
|
9
|
-
| L0 | ≈95.1 ms | ≈95.5 ms (无膨胀) | ≈420.0 ms (膨胀4.5倍) | ≈1011.3
|
|
10
|
-
| L1 | ≈95.1 ms | ≈115.8 ms (膨胀1.2倍) | ≈2469.0 ms (膨胀26倍) | ≈8636.0
|
|
11
|
-
| mix | ≈95.1 ms | ≈117.8 ms (膨胀1.2倍) | ≈3635.4 ms (膨胀38 倍) | ≈10698.3
|
|
9
|
+
| L0 | ≈95.1 ms | ≈95.5 ms (无膨胀) | ≈420.0 ms (膨胀4.5倍) | ≈1011.3 s (膨胀10倍) |
|
|
10
|
+
| L1 | ≈95.1 ms | ≈115.8 ms (膨胀1.2倍) | ≈2469.0 ms (膨胀26倍) | ≈8636.0 s (膨胀90倍) |
|
|
11
|
+
| mix | ≈95.1 ms | ≈117.8 ms (膨胀1.2倍) | ≈3635.4 ms (膨胀38 倍) | ≈10698.3 s (膨胀112倍) |
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
## "tensor"模式采集数据量参考基线
|
|
@@ -26,8 +26,7 @@ def read_npy_data(dir_path, file_name):
|
|
|
26
26
|
return None
|
|
27
27
|
|
|
28
28
|
data_path = os.path.join(dir_path, file_name)
|
|
29
|
-
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
30
|
-
FileCheckConst.NUMPY_SUFFIX, False)
|
|
29
|
+
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.NUMPY_SUFFIX)
|
|
31
30
|
data_path = path_checker.common_check()
|
|
32
31
|
data_value = load_npy(data_path)
|
|
33
32
|
return data_value
|
msprobe/msprobe.py
CHANGED
|
@@ -14,16 +14,16 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import argparse
|
|
17
|
-
import sys
|
|
18
17
|
import importlib.util
|
|
18
|
+
import sys
|
|
19
19
|
|
|
20
20
|
from msprobe.core.common.const import Const
|
|
21
|
+
from msprobe.core.common.file_utils import root_privilege_warning
|
|
21
22
|
from msprobe.core.common.log import logger
|
|
22
|
-
from msprobe.core.compare.utils import _compare_parser
|
|
23
23
|
from msprobe.core.compare.compare_cli import compare_cli
|
|
24
24
|
from msprobe.core.compare.merge_result.merge_result_cli import _merge_result_parser, merge_result_cli
|
|
25
|
-
from msprobe.core.
|
|
26
|
-
|
|
25
|
+
from msprobe.core.compare.utils import _compare_parser
|
|
26
|
+
from msprobe.core.config_check.config_check_cli import _config_checking_parser, _run_config_checking_command
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
def is_module_available(module_name):
|
|
@@ -64,6 +64,8 @@ def main():
|
|
|
64
64
|
if len(sys.argv) < 4:
|
|
65
65
|
parser.print_help()
|
|
66
66
|
sys.exit(0)
|
|
67
|
+
|
|
68
|
+
root_privilege_warning()
|
|
67
69
|
framework_args = parser.parse_args(sys.argv[1:3])
|
|
68
70
|
if framework_args.framework == Const.PT_FRAMEWORK:
|
|
69
71
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
|
|
@@ -24,7 +24,8 @@ from msprobe.pytorch.pt_config import RunUTConfig
|
|
|
24
24
|
|
|
25
25
|
RunUtConfig = namedtuple('RunUtConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
|
|
26
26
|
'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
|
|
27
|
-
'black_list', 'error_data_path'])
|
|
27
|
+
'black_list', 'error_data_path', 'online_config'])
|
|
28
|
+
OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
|
|
28
29
|
|
|
29
30
|
|
|
30
31
|
class Config:
|
|
@@ -45,7 +46,13 @@ class Config:
|
|
|
45
46
|
'white_list': list,
|
|
46
47
|
'black_list': list,
|
|
47
48
|
'error_data_path': str,
|
|
48
|
-
'precision': int
|
|
49
|
+
'precision': int,
|
|
50
|
+
'is_online': bool,
|
|
51
|
+
'nfs_path': str,
|
|
52
|
+
'host': str,
|
|
53
|
+
'port': int,
|
|
54
|
+
'rank_list': list,
|
|
55
|
+
'tls_path': str
|
|
49
56
|
}
|
|
50
57
|
if key not in validators:
|
|
51
58
|
raise ValueError(f"{key} must be one of {validators.keys()}")
|
|
@@ -61,6 +68,10 @@ class Config:
|
|
|
61
68
|
RunUTConfig.check_filter_list_config(key, value)
|
|
62
69
|
if key == 'error_data_path':
|
|
63
70
|
RunUTConfig.check_error_data_path_config(value)
|
|
71
|
+
if key == 'nfs_path':
|
|
72
|
+
RunUTConfig.check_nfs_path_config(value)
|
|
73
|
+
if key == 'tls_path':
|
|
74
|
+
RunUTConfig.check_tls_path_config(value)
|
|
64
75
|
return value
|
|
65
76
|
|
|
66
77
|
|
|
@@ -74,6 +85,12 @@ class CheckerConfig:
|
|
|
74
85
|
self.white_list = msCheckerConfig.white_list
|
|
75
86
|
self.black_list = msCheckerConfig.black_list
|
|
76
87
|
self.error_data_path = msCheckerConfig.error_data_path
|
|
88
|
+
self.is_online = msCheckerConfig.is_online
|
|
89
|
+
self.nfs_path = msCheckerConfig.nfs_path
|
|
90
|
+
self.host = msCheckerConfig.host
|
|
91
|
+
self.port = msCheckerConfig.port
|
|
92
|
+
self.rank_list = msCheckerConfig.rank_list
|
|
93
|
+
self.tls_path = msCheckerConfig.tls_path
|
|
77
94
|
|
|
78
95
|
if task_config:
|
|
79
96
|
self.load_config(task_config)
|
|
@@ -82,7 +99,22 @@ class CheckerConfig:
|
|
|
82
99
|
self.white_list = task_config.white_list
|
|
83
100
|
self.black_list = task_config.black_list
|
|
84
101
|
self.error_data_path = task_config.error_data_path
|
|
102
|
+
self.is_online = task_config.is_online
|
|
103
|
+
self.nfs_path = task_config.nfs_path
|
|
104
|
+
self.host = task_config.host
|
|
105
|
+
self.port = task_config.port
|
|
106
|
+
self.rank_list = task_config.rank_list
|
|
107
|
+
self.tls_path = task_config.tls_path
|
|
85
108
|
|
|
109
|
+
def get_online_config(self):
|
|
110
|
+
return OnlineConfig(
|
|
111
|
+
is_online=self.is_online,
|
|
112
|
+
nfs_path=self.nfs_path,
|
|
113
|
+
host=self.host,
|
|
114
|
+
port=self.port,
|
|
115
|
+
rank_list=self.rank_list,
|
|
116
|
+
tls_path=self.tls_path
|
|
117
|
+
)
|
|
86
118
|
|
|
87
119
|
def get_run_ut_config(self, **config_params):
|
|
88
120
|
return RunUtConfig(
|
|
@@ -95,5 +127,6 @@ class CheckerConfig:
|
|
|
95
127
|
real_data_path=config_params.get('real_data_path'),
|
|
96
128
|
white_list=self.white_list.copy() if self.white_list else [],
|
|
97
129
|
black_list=self.black_list.copy() if self.black_list else [],
|
|
98
|
-
error_data_path=config_params.get('error_data_path')
|
|
130
|
+
error_data_path=config_params.get('error_data_path'),
|
|
131
|
+
online_config=self.get_online_config()
|
|
99
132
|
)
|
|
@@ -117,6 +117,30 @@ def api_precision_compare(config):
|
|
|
117
117
|
change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
118
118
|
|
|
119
119
|
|
|
120
|
+
def online_api_precision_compare(online_config):
|
|
121
|
+
rank = online_config.rank
|
|
122
|
+
result_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.result_csv_path).replace(
|
|
123
|
+
"_rank*.csv", f"_rank{rank}.csv")
|
|
124
|
+
details_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.details_csv_path).replace(
|
|
125
|
+
"_rank*.csv", f"_rank{rank}.csv")
|
|
126
|
+
detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
|
|
127
|
+
result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
|
|
128
|
+
if not os.path.exists(result_csv_path):
|
|
129
|
+
write_csv(result_csv_title, result_csv_path)
|
|
130
|
+
if not os.path.exists(details_csv_path):
|
|
131
|
+
write_csv(detail_csv_title, details_csv_path)
|
|
132
|
+
config = CompareConfig("", "", result_csv_path, details_csv_path)
|
|
133
|
+
try:
|
|
134
|
+
npu_data, gpu_data = online_config.npu_data, online_config.gpu_data
|
|
135
|
+
check_csv_columns(npu_data.columns, "npu_csv")
|
|
136
|
+
check_csv_columns(gpu_data.columns, "gpu_csv")
|
|
137
|
+
analyse_csv(npu_data, gpu_data, config)
|
|
138
|
+
except Exception as err:
|
|
139
|
+
logger.error(f"Online api precision compare Error: {str(err)}")
|
|
140
|
+
change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
141
|
+
change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
142
|
+
|
|
143
|
+
|
|
120
144
|
def analyse_csv(npu_data, gpu_data, config):
|
|
121
145
|
forward_status, backward_status = [], []
|
|
122
146
|
last_api_name, last_api_dtype, last_api_full_name = None, None, None
|
|
@@ -66,6 +66,13 @@ class Comparator:
|
|
|
66
66
|
self.save_path_list = [result_csv_path]
|
|
67
67
|
self.detail_save_path_list = [details_csv_path]
|
|
68
68
|
|
|
69
|
+
if config and config.online_config.is_online:
|
|
70
|
+
self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
|
|
71
|
+
self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv")
|
|
72
|
+
self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list]
|
|
73
|
+
self.detail_save_path_list = \
|
|
74
|
+
[self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
|
|
75
|
+
|
|
69
76
|
self.registry = self._register_compare_func()
|
|
70
77
|
|
|
71
78
|
if not is_continue_run_ut:
|
|
@@ -238,8 +245,9 @@ class Comparator:
|
|
|
238
245
|
self.write_detail_csv(args)
|
|
239
246
|
|
|
240
247
|
|
|
241
|
-
def compare_output(self, full_api_name, data_info):
|
|
248
|
+
def compare_output(self, full_api_name, data_info, is_online=False):
|
|
242
249
|
"""Get compare result and write to result and detail csv.
|
|
250
|
+
is_online: bool, default False. True: called by online api precision compare, only compare without write to csv.
|
|
243
251
|
"""
|
|
244
252
|
_, api_name = extract_basic_api_segments(full_api_name)
|
|
245
253
|
if not api_name:
|
|
@@ -272,7 +280,9 @@ class Comparator:
|
|
|
272
280
|
fwd_compare_alg_results,
|
|
273
281
|
bwd_compare_alg_results,
|
|
274
282
|
data_info.rank)
|
|
275
|
-
|
|
283
|
+
if is_online:
|
|
284
|
+
# get run_ut compare detail
|
|
285
|
+
return self._get_run_ut_detail(result_info)
|
|
276
286
|
self.record_results(result_info)
|
|
277
287
|
return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \
|
|
278
288
|
or bwd_success_status == CompareConst.SPACE
|