mindstudio-probe 8.3.0__py3-none-any.whl → 8.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/RECORD +37 -47
  3. msprobe/README.md +8 -5
  4. msprobe/core/common/const.py +17 -3
  5. msprobe/core/common/file_utils.py +64 -13
  6. msprobe/core/common/framework_adapter.py +10 -1
  7. msprobe/core/common/utils.py +17 -0
  8. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
  9. msprobe/core/hook_manager.py +2 -16
  10. msprobe/core/service.py +5 -16
  11. msprobe/docs/01.installation.md +2 -0
  12. msprobe/docs/02.config_introduction.md +0 -13
  13. msprobe/docs/14.data_parse_PyTorch.md +2 -0
  14. msprobe/docs/21.visualization_PyTorch.md +1 -1
  15. msprobe/docs/25.tool_function_introduction.md +0 -1
  16. msprobe/docs/32.ckpt_compare.md +5 -5
  17. msprobe/mindspore/monitor/module_hook.py +17 -20
  18. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  19. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  20. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  21. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  22. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
  23. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  24. msprobe/pytorch/common/utils.py +0 -70
  25. msprobe/pytorch/debugger/debugger_config.py +0 -10
  26. msprobe/pytorch/dump/module_dump/module_processer.py +18 -3
  27. msprobe/pytorch/hook_module/api_register.py +5 -1
  28. msprobe/pytorch/monitor/module_hook.py +16 -34
  29. msprobe/pytorch/pt_config.py +2 -51
  30. msprobe/pytorch/pytorch_service.py +2 -11
  31. msprobe/visualization/builder/graph_builder.py +2 -2
  32. msprobe/visualization/builder/graph_merger.py +13 -0
  33. msprobe/visualization/graph/graph.py +13 -9
  34. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  35. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  36. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  37. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  38. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  39. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  40. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  41. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  42. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  43. msprobe/pytorch/attl_manager.py +0 -65
  44. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/LICENSE +0 -0
  45. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/WHEEL +0 -0
  46. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/entry_points.txt +0 -0
  47. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/top_level.txt +0 -0
@@ -146,6 +146,7 @@ class BaseGraphMerger:
146
146
  GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS,
147
147
  id_accumulation=True)
148
148
  all_collection_node = main_graph_result.graph.get_node(all_collection_node_id)
149
+ all_collection_node.upnode = main_graph_result.graph.root
149
150
  new_main_root_sub_nodes.append(all_collection_node)
150
151
  # Apis_Between_Modules.0 --> Apis_Between_Modules_Rank0.0
151
152
  origin_main_node_id = main_node.id
@@ -377,6 +378,12 @@ class PPMerger(BaseGraphMerger):
377
378
  logger.info('Unable to get pp groups based on Distributed Api (batch_isend_irecv, send, or isend), '
378
379
  'generate pp groups using parallel param "rank_size", "tp" and "pp".')
379
380
  _, pp_groups = self.get_default_groups()
381
+ elif len(pp_groups[0]) != self.parallel_param.pp:
382
+ logger.warning(f'Based on Distributed Api (atch_isend_irecv, send, or isend), '
383
+ f'the resulting pp groups={pp_groups}, '
384
+ f'its length is not equal to the parallel param "pp"({self.parallel_param.pp}) you defined, '
385
+ f'generate pp groups using parallel param "rank_size", "tp" and "pp".')
386
+ _, pp_groups = self.get_default_groups()
380
387
  logger.info(f'{self.log_prefix} All pp groups is {pp_groups}.')
381
388
  return pp_groups
382
389
 
@@ -657,6 +664,12 @@ class TPMerger(BaseGraphMerger):
657
664
  logger.info('Unable to get tp groups based on Distributed Api (reduce_scatter or all_reduce), '
658
665
  'generate tp groups using parallel param "rank_size", "tp" and "pp".')
659
666
  tp_groups, _ = self.get_default_groups()
667
+ elif len(tp_groups[0]) != self.parallel_param.tp:
668
+ logger.warning(f'Based on Distributed Api (reduce_scatter or all_reduce), '
669
+ f'the resulting tp groups={tp_groups}, '
670
+ f'its length is not equal to the parallel param "tp"({self.parallel_param.tp}) you defined, '
671
+ f'generate tp groups using parallel param "rank_size", "tp" and "pp".')
672
+ tp_groups, _ = self.get_default_groups()
660
673
  logger.info(f'{self.log_prefix} All tp groups is {tp_groups}.')
661
674
  return tp_groups
662
675
 
@@ -126,21 +126,25 @@ class Graph:
126
126
 
127
127
  def get_sorted_nodes(self):
128
128
  """
129
- 通过深度优先遍历graph,获得排过序的node列表
129
+ 通过深度优先遍历graph,获得排过序的node列表,使用栈实现避免超出递归深度问题
130
130
  """
131
131
  visited = set()
132
132
  order = []
133
+ stack = [(self.root, False)]
133
134
 
134
- @recursion_depth_decorator('msprobe.visualization.graph.graph.Graph.get_nodes_order.visit', max_depth=500)
135
- def visit(node):
135
+ while stack:
136
+ node, processed = stack.pop()
136
137
  if node.id in visited:
137
- return
138
- visited.add(node.id)
139
- for sub_node in node.subnodes:
140
- visit(sub_node)
141
- order.append(node)
138
+ continue
139
+ if processed:
140
+ visited.add(node.id)
141
+ order.append(node)
142
+ else:
143
+ stack.append((node, True))
144
+ for sub_node in reversed(node.subnodes):
145
+ if sub_node.id not in visited:
146
+ stack.append((sub_node, False))
142
147
 
143
- visit(self.root)
144
148
  return order
145
149
 
146
150
  def add_node(self, node_op, node_id, up_node=None, id_accumulation=False):
@@ -1,295 +0,0 @@
1
- # PyTorch 场景的在线精度预检
2
-
3
- ## 1 简介
4
-
5
- 为了应对大模型场景下,通过离线预检方式 dump API 输入输出数据导致的存储资源紧张问题,提供在线精度预检功能。本功能实现在执行 NPU 训练操作的过程中,通过 TCP/IP 协议在 NPU
6
- Host 与 GPU Host 设备间建立连接,将 NPU 上对应 API 的输入数据在 GPU 设备上运行,将两份输出数据进行比对,得到预检比对结果,从而减少数据 dump 的步骤,降低存储资源的占用。针对偏差较大的算子,两方比对(NPU vs. GPU)的方法缺少裁判进行裁定。 参考离线预检,在线预检场景同时支持两方比对和三方比对方式,按照 api 的精度标准要求,选择两方比对或三方比对。
7
-
8
- ## 2 在线精度预检流程
9
-
10
- 在线精度预检当前支持**局域网场景**和**共享存储场景**,请根据不同的场景选择对应的配置。
11
-
12
- 在线精度预检操作流程如下:
13
-
14
- 1. 准备 GPU 和 NPU 可正常运行的训练环境,PyTorch 版本大于等于2.0,并保证两台 Host 在同一局域网内可正常通信或能通过共享存储进行通信。
15
- 2. GPU 和 NPU Host 设备上同时安装msprobe工具,详见[ msprobe 安装](./01.installation.md)章节,其中在线预检要安装 twisted、pyOpenSSL,这些包为 Python 模块。
16
- 3. 分别配置 GPU 侧、NPU 侧的 config.json 文件。
17
- 4. 在 GPU 侧运行 `msprobe -f pytorch run_ut -config ./config.json`。
18
- 5. 在 NPU 侧配置训练脚本。
19
- 6. 在 NPU 侧执行训练。
20
-
21
- ## 3 在线精度预检操作指导
22
-
23
- ### 3.1 配置 config.json 文件
24
-
25
- 预检工具安装完成后,需要在 GPU 和 NPU 环境下分别配置 config.json。其中需要重点关注文件中的 is_online、is_benchmark_device、host 和 port 参数的配置,保障在线预检时 GPU 和 NPU 两台设备间的通信正常。
26
-
27
- #### 3.1.1 GPU 侧在线预检配置说明
28
-
29
- | 参数名称 | 说明 | 是否必选 |
30
- |-----------------|--------------|------|
31
- | task | 任务名称,str 类型,配置为 run_ut 表示预检任务。通过其他字段 is_online 判断离线预检、在线预检任务。 | 是 |
32
- | white_list | 预检的 API 白名单,list[str] 类型。<br/>**配置示例**:white_list=["conv1d", "conv2d"]。默认未配置白名单,即预检全量 API 数据。 | 否 |
33
- | black_list | 预检的 API 黑名单,list[str] 类型。<br/>**配置示例**:white_list=["conv1d", "conv2d"]。默认未配置黑名单,即预检全量 API 数据。 | 否 |
34
- | error_data_path | 配置保存精度未达标的 API 输入输出数据路径,str 类型。在线预检模式下该参数不生效。 | 否 |
35
- | is_online | 在线预检模式开关,bool 类型,可取值 True(开启)、False(关闭),默认关闭。 | 是 |
36
- | nfs_path | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。配置该参数后 host、port 和 tls_path 不生效。 | 否 |
37
- | host | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机地址 127.0.0.1 或本机局域网 IP。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
38
- | port | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机可用端口。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
39
- | rank_list | 指定在线预检的 Rank ID,默认值为 [0],list[int] 类型,应配置为大于等于 0 的整数,且须根据实际卡的 Rank ID 配置,若所配置的值大于实际训练所运行的卡的 Rank ID,则在线预检输出数据为空。GPU 和 NPU 须配置一致。 | 是 |
40
- | tls_path | 在线预检模式局域网场景 SSL 证书路径,该路径下包含私钥 server.key、证书 server.crt、自建CA证书 ca.crt、CRL吊销证书 crl.pem,str 类型,未配置该参数时默认取值当前路径。tls_path配置为空字符串时,采用TCP协议明文传输api数据;当配置为路径时,采用TLS1.2协议加密传输数据,加密传输时安全性较高,传输速率较低。其中 crl.pem 为非必需文件,仅当用户存在吊销记录时使用。 | 否 |
41
-
42
-
43
- #### 3.1.2 NPU 侧在线预检配置说明
44
-
45
- | 参数名称 | 说明 | 是否必选 |
46
- |------------------|-------------|------|
47
- | task | 任务名称,str 类型,配置为 tensor 表示 dump API 统计信息和完全复刻整网的 API 运行情况的真实数据。通过字段 online_run_ut 判断是否使用在线预检功能。 | 是 |
48
- | dump_path | dump 路径,str 类型,配置为合法路径即可,兼容 tensor 任务静态检查。 | 是 |
49
- | level | dump 级别,str 类型,在线预检时配置为 L1,表示 dump API 级精度数据。在线预检可不配置,默认取值 L1。 | 是 |
50
- | rank | 指定对某张卡上的数据进行 dump,list[int] 类型,默认未配置(表示 dump所有卡的数据),需要与 GPU 侧配置项 rank_list 保持一致。 | 否 |
51
- | step | 指定 dump 某个 step 的数据,list[int] 类型,默认未配置,表示 dump 所有 step 的数据。dump 特定 step 时,须指定为训练脚本中存在的 step。 | 否 |
52
- | scope | dump 范围,list[str] 类型,默认未配置(list 也未配置时表示 dump 所有 api 的数据),配置方式参考 [config.json 配置介绍](./02.config_introduction.md)。 | 否 |
53
- | list | dump 范围,list[str] 类型,默认未配置(scope 也未配置时表示 dump 所有 api 的数据),配置方式参考 [config.json 配置介绍](./02.config_introduction.md)。 | 否 |
54
- | online_run_ut | 在线预检模式开关,bool 类型,可取值 True(开启)、False(关闭),默认关闭。 | 是 |
55
- | nfs_path | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。配置该参数后 host 和 port 不生效。 | 否 |
56
- | host | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的局域网 IP 地址。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
57
- | port | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的端口号。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
58
- | tls_path | 在线预检模式局域网场景 SSL 证书路径,该路径下包含私钥 client.key、证书 client.crt、自建CA证书 ca.crt、CRL吊销证书 crl.pem,str 类型,未配置该参数时默认取值当前路径。tls_path配置为空字符串时,采用TCP协议明文传输api数据;当配置为路径时,采用TLS1.2协议加密传输数据,加密传输时安全性较高,传输速率较低。其中 crl.pem 为非必需文件,仅当用户存在吊销记录时使用。 | 否 |
59
- | online_run_ut_recompute | 模型训练是否使用重计算机制,bool类型,默认为False,表示模型没有使用重计算。在线预检暂不支持重计算机制下反向算子的预检,当模型训练使用重计算时,跳过反向算子预检,默认模型关闭重计算。 | 否 |
60
-
61
- #### 3.1.3 局域网场景配置示例
62
-
63
- 若采用 TLS1.2 协议加密传输 api 数据,需配置 SSL 证书,可参考如下生成自签名证书方法。
64
-
65
- 以下秘钥生成方法仅为简单示例,客户应使用与自己需求相符的秘钥生成和存储机制并保证秘钥安全性与机密性,必要时可采用分层秘钥机制。
66
- 以下示例中加密口令仅供参考,使用时请更换为复杂口令,并保护口令安全。
67
- ```shell
68
- # 生成CA证书的根私钥和证书签名请求,其中ca_password为CA私钥加密口令,仅作演示,请更换使用
69
- openssl req -new -newkey rsa:3072 -passout pass:ca_password -subj "/CN=*ca.com/O=ca.Inc./C=CN/ST=Zhejiang/L=Hangzhou" -keyout ca.key -out ca.csr
70
- # 自签发根证书
71
- openssl x509 -req -days 365 -in ca.csr -signkey ca.key -passin pass:ca_password -out ca.crt -extensions v3_ca -extfile <(cat <<-EOF
72
- [v3_ca]
73
- basicConstraints = critical,CA:true
74
- keyUsage = critical, keyCertSign, cRLSign
75
- EOF
76
- )
77
-
78
- # 生成client公私钥,其中client_password为私钥加密口令,仅作演示,请更换使用
79
- openssl genrsa -aes256 -passout pass:client_password -out client.key 3072
80
- # 基于client公私钥生成签名请求
81
- openssl req -new -key client.key -passin pass:client_password -subj "/CN=*example.com/O=Test, Inc./C=CN/ST=Zhejiang/L=Hangzhou" -out client.csr
82
- # 利用自签发的根证书,签发client证书
83
- openssl x509 -req -days 180 -CA ca.crt -CAkey ca.key -passin pass:ca_password -in client.csr -out client.crt -CAcreateserial -extfile <(cat <<-EOF
84
- [v3_server]
85
- basicConstraints = CA:FALSE
86
- keyUsage = critical, digitalSignature, keyEncipherment
87
- extendedKeyUsage = serverAuth
88
- EOF
89
- )
90
-
91
- # 生成server公私钥,其中server_password为私钥加密口令,仅作演示,请更换使用
92
- openssl genrsa -aes256 -passout pass:server_password -out server.key 3072
93
- # 基于server公私钥生成签名请求
94
- openssl req -new -key server.key -passin pass:server_password -subj "/CN=*example.com/O=Test, Inc./C=CN/ST=Zhejiang/L=Hangzhou" -out server.csr
95
- # 利用自签发的根证书,签发server证书
96
- openssl x509 -req -days 180 -CA ca.crt -CAkey ca.key -passin pass:ca_password -in server.csr -out server.crt -CAcreateserial -extfile <(cat <<-EOF
97
- [v3_server]
98
- basicConstraints = CA:FALSE
99
- keyUsage = critical, digitalSignature, keyEncipherment
100
- extendedKeyUsage = serverAuth
101
- EOF
102
- )
103
-
104
- ```
105
-
106
- 当需要吊销已创建的SSL证书时,通过openssl命令生成CRL证书 crl.pem,示例如下:
107
- ```shell
108
- # 创建证书信息的文本数据库,空文件即可
109
- touch index.txt
110
-
111
- # 创建ca配置文件ca.cnf,内容如下,用于吊销证书使用
112
- [ca]
113
- default_ca = CA_default
114
- [CA_default]
115
- database = ./index.txt
116
- default_md = sha256
117
-
118
- # 吊销证书 client.crt,其中ca_password为CA私钥加密口令,与CA创建时保持一致
119
- openssl ca -revoke client.crt -config ca.cnf -cert ca.crt -keyfile ca.key -passin pass:ca_password
120
- # 生成CRL文件
121
- openssl ca -gencrl -config ca.cnf -cert ca.crt -keyfile ca.key -passin pass:ca_password -out crl.pem -crldays 30
122
- # 查看生成的CRL文件内容:
123
- openssl工具的命令: openssl crl -inform PEM -in crl.pem -text
124
-
125
- ```
126
-
127
- 注意:配置TLS协议时,传输性能受机器环境和网络质量的影响,可能触发NPU超时中断模型训练,为避免训练和预检中断,丢弃长时间未传输的api数据,同时NPU侧配置HCCL环境变量,配置方式如下:
128
-
129
- a) 调整HCCL环境变量,关闭看门狗,避免WorkHCCL超时中断模型训练:
130
- ```shell
131
- export HCCL_DESYNC_DEBUG=0
132
- export HCCL_ASYNC_ERROR_HANDLING=0
133
- ```
134
- b) 调整通信算子超时设置(以1800s举例):
135
- ```shell
136
- export HCCL_CONNECT_TIMEOUT=1800
137
- export HCCL_EXEC_TIMEOUT=1800
138
- ```
139
-
140
- GPU 侧:
141
-
142
- ```json
143
- {
144
- "task": "run_ut",
145
- "run_ut": {
146
- "white_list": [],
147
- "black_list": [],
148
- "error_data_path": "./",
149
- "is_online": true,
150
- "nfs_path": "",
151
- "host": "127.0.0.1",
152
- "port": 59208,
153
- "rank_list": [0],
154
- "tls_path": ""
155
- }
156
- }
157
- ```
158
-
159
- NPU 侧:
160
-
161
- ```json
162
- {
163
- "task": "tensor",
164
- "dump_path": "./dump_path",
165
- "rank": [0],
166
- "step": [0],
167
- "level": "L1",
168
- "tensor": {
169
- "scope": [],
170
- "list": [],
171
- "online_run_ut": true,
172
- "nfs_path": "",
173
- "host": "xx.xx.xx.x",
174
- "port": 59208,
175
- "tls_path": ""
176
- }
177
- }
178
- ```
179
-
180
- #### 3.1.4 共享存储场景配置示例
181
-
182
- GPU 侧:
183
-
184
- ```json
185
- {
186
- "task": "run_ut",
187
- "run_ut": {
188
- "white_list": [],
189
- "black_list": [],
190
- "error_data_path": "./",
191
- "is_online": true,
192
- "nfs_path": "/nfs/xxx/data",
193
- "host": "",
194
- "port": -1,
195
- "rank_list": [0],
196
- "tls_path": ""
197
- }
198
- }
199
- ```
200
-
201
- NPU 侧:
202
-
203
- ```json
204
- {
205
- "task": "tensor",
206
- "dump_path": "./dump_path",
207
- "rank": [0],
208
- "step": [0],
209
- "level": "L1",
210
- "tensor": {
211
- "scope": [],
212
- "list": [],
213
- "online_run_ut": true,
214
- "nfs_path": "/nfs/xxx/data",
215
- "host": "",
216
- "port": -1,
217
- "tls_path": ""
218
- }
219
- }
220
- ```
221
-
222
- ### 3.2 在 GPU 侧运行 run_ut
223
-
224
- 由于 GPU 侧为通信接收端,需先于 NPU 侧执行 run_ut 操作,命令如下:
225
-
226
- ```bash
227
- msprobe -f pytorch run_ut -config ./config.json
228
- ```
229
-
230
- GPU 侧配置好 config.json 文件后执行 run_ut 命令,此时 GPU 处于预检等待状态:
231
-
232
- - 局域网场景:当 NPU 侧启动训练后将预检的 API 输入和输出数据发送到 GPU 侧时,GPU 启动预检操作。
233
- - 共享存储场景:当 NPU 侧启动训练后将预检的 API 输入和输出数据发送到共享存储时,GPU 启动预检操作。
234
-
235
- ### 3.3 在 NPU 侧配置训练脚本
236
-
237
- 在 NPU 训练脚本中添加如下代码以获取 run_ut 操作的预检 API 输入和输出数据:
238
-
239
- ```python
240
- from msprobe.pytorch import PrecisionDebugger
241
-
242
- debugger = PrecisionDebugger("config.json")
243
- ...
244
-
245
- debugger.start()
246
-
247
- ...
248
-
249
- debugger.stop()
250
- debugger.step()
251
- ```
252
-
253
- ### 3.4 在 NPU 侧执行训练脚本
254
-
255
- 配置完 NPU 侧训练脚本后即可执行训练脚本,命令示例如下:
256
-
257
- ```bash
258
- bash train.sh
259
- ```
260
-
261
- 训练脚本执行完毕后,在GPU侧dump_path目录下生成比对结果文件,`accuracy_checking_result_{timestamp}_rank{rank_id}.csv`和`accuracy_checking_details_{timestamp}_rank{rank_id}.csv`记录两方比对结果,`api_precision_compare_result_{timestamp}_rank{rank_id}.csv`和`api_precision_compare_details_{timestamp}_rank{rank_id}.csv`记录三方比对结果。详细介绍请参见[离线精度预检中的 **4 预检结果**](./07.accuracy_checker_PyTorch.md#4-预检结果)。
262
-
263
- ## 4 支持的融合算子列表
264
-
265
- 预检工具当前支持的融合算子如下:
266
-
267
- - npu_apply_adam_w
268
-
269
- - npu_confusion_transpose
270
-
271
- - fast_gelu
272
-
273
- - npu_layer_norm_eval
274
-
275
- - npu_linear
276
-
277
- - npu_fusion_attention(该算子在 GPU 上预检时,需要额外安装 flash_attn,请用户自行安装。)
278
-
279
- - npu_rms_norm
280
-
281
- - npu_rotary_mul
282
-
283
- - npu_scaled_masked_softmax
284
-
285
- - npu_swiglu
286
-
287
- - npu_apply_adam
288
-
289
- - npu_group_norm_silu
290
-
291
- - npu_mish
292
-
293
- - npu_moe_gating_top_k_softmax
294
-
295
- - npu_sort_v2
@@ -1,205 +0,0 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import glob
17
- import os.path
18
- import time
19
- from multiprocessing import Queue
20
- from typing import Optional, Union, Dict, Any
21
- from dataclasses import dataclass
22
-
23
- import torch
24
-
25
- from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
26
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
27
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
28
- from msprobe.core.common.file_utils import remove_path
29
- from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
30
- from msprobe.core.common.decorator import recursion_depth_decorator
31
-
32
- BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]]
33
-
34
-
35
- @dataclass
36
- class ATTLConfig:
37
- is_benchmark_device: bool
38
- connect_ip: str
39
- connect_port: int
40
- # storage_config
41
- nfs_path: str = None
42
- tls_path: str = None
43
- check_sum: bool = True
44
- queue_size: int = 50
45
-
46
-
47
- class ATTL:
48
- def __init__(self, session_id: str, session_config: ATTLConfig, need_dump=True) -> None:
49
- self.session_id = session_id
50
- self.session_config = session_config
51
- self.logger = logger
52
- self.socket_manager = None
53
- self.data_queue = Queue(maxsize=50)
54
- self.dequeue_list = []
55
- self.message_end = False
56
- self.kill_progress = False
57
- self.nfs_path = None
58
- if self.session_config.nfs_path:
59
- self.nfs_path = self.session_config.nfs_path
60
- elif self.session_config.is_benchmark_device:
61
-
62
- self.socket_manager = TCPServer(self.session_config.connect_port,
63
- self.data_queue,
64
- self.session_config.check_sum,
65
- self.session_config.tls_path)
66
- self.socket_manager.start()
67
- elif need_dump:
68
- self.socket_manager = TCPClient(self.session_config.connect_ip,
69
- self.session_config.connect_port,
70
- self.session_config.check_sum,
71
- self.session_config.tls_path)
72
- self.socket_manager.start()
73
-
74
- def stop_serve(self):
75
- if isinstance(self.socket_manager, TCPServer):
76
- self.socket_manager.stop()
77
-
78
- def send(self, buffer: BufferType) -> None:
79
- """
80
- npu major in 'send' (client)
81
- """
82
-
83
- # if tcp connection lost,
84
- if self.socket_manager.signal_exit:
85
- raise ConnectionError(f"Failed to connect to {self.session_config.connect_ip}.")
86
-
87
- # know receiver receive and go next
88
- if isinstance(buffer, ApiData):
89
- buffer = move2target_device(buffer, torch.device('cpu'))
90
-
91
- if 'device' in buffer.kwargs:
92
- buffer.kwargs.pop('device')
93
- rank = buffer.rank if hasattr(buffer, "rank") and buffer.rank is not None else 0
94
- step = buffer.step if hasattr(buffer, "step") else 0
95
- try:
96
- io_buff = save_api_data(buffer)
97
- except Exception as e:
98
- self.logger.info(f"{buffer.name} can not be saved, skip: {e}")
99
- return
100
- data = io_buff.getvalue()
101
- self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)
102
-
103
- def recv(self, timeout_ms=0) -> Optional[BufferType]:
104
- buffer = ''
105
- while not buffer:
106
- if timeout_ms > 0:
107
- time.sleep(timeout_ms / 1000.0)
108
- if not buffer and not self.data_queue.empty():
109
- buffer = self.data_queue.get()
110
- break
111
- if not buffer and timeout_ms > 0: # timeout is the only case we give up and return None
112
- break
113
- if self.message_end and self.data_queue.empty():
114
- buffer = b"KILL_CONFIRM"
115
- self.kill_progress = True
116
- break
117
- time.sleep(0.1) # waiting outside the lock before next attempt
118
- if not buffer:
119
- # this is a result of a timeout
120
- self.logger.info(f"RECEIVE API DATA TIMED OUT")
121
- else:
122
- if buffer == b"STOP_":
123
- return "STOP_"
124
- if buffer == b"KILL_":
125
- self.message_end = True
126
- return "STOP_"
127
- if buffer == b"KILL_CONFIRM":
128
- self.kill_progress = True
129
- return "KILL_"
130
- try:
131
- buffer = load_api_data(buffer)
132
- except Exception as e:
133
- self.logger.warning("there is something error. please check it. %s", e)
134
- if isinstance(buffer, bytes):
135
- return ''
136
- if isinstance(buffer, str):
137
- return buffer
138
-
139
- return buffer
140
-
141
- def upload(self, buffer: BufferType):
142
- if isinstance(buffer, ApiData):
143
- buffer = move2target_device(buffer, torch.device('cpu'))
144
- file_path = os.path.join(self.session_config.nfs_path, buffer.name + ".pt")
145
- else:
146
- file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")
147
-
148
- try:
149
- save_pkl(buffer, file_path)
150
- except Exception as e:
151
- self.logger.warning("there is something error in save_pt. please check it. %s", e)
152
-
153
- def download(self):
154
- buffer = None
155
- cur_file = None
156
- for file_type in ("start*", "*.pt", "end*"):
157
- pattern = os.path.join(self.nfs_path, file_type)
158
- files = glob.glob(pattern)
159
- if len(files) > 0:
160
- cur_file = files[0]
161
- break
162
-
163
- if cur_file is not None:
164
- try:
165
- buffer = load_pkl(cur_file)
166
- except Exception as e:
167
- self.logger.warning("there is something error. please check it. %s", e)
168
- remove_path(cur_file)
169
- return buffer
170
-
171
-
172
- @recursion_depth_decorator("move2device_exec")
173
- def move2device_exec(obj, device):
174
- if isinstance(obj, (tuple, list)):
175
- data_list = [move2device_exec(val, device) for val in obj]
176
- return data_list if isinstance(obj, list) else tuple(data_list)
177
- if isinstance(obj, dict):
178
- return {key: move2device_exec(val, device) for key, val in obj.items()}
179
- elif isinstance(obj, torch.Tensor):
180
- obj = obj.detach()
181
- if obj.device.type != device:
182
- obj = obj.to(device)
183
- return obj
184
- elif "return_types" in str(type(obj)):
185
- return move2device_exec(tuple(obj), device)
186
- elif isinstance(obj, torch._C.device):
187
- return torch.device(device)
188
- else:
189
- return obj
190
-
191
-
192
- def move2target_device(buffer: ApiData, target_device):
193
- # handle args
194
- new_args = move2device_exec(buffer.args, target_device)
195
-
196
- # handle kwargs
197
- new_kwargs = move2device_exec(buffer.kwargs, target_device)
198
-
199
- # handle result
200
- new_results = move2device_exec(buffer.result, target_device)
201
-
202
- if target_device == torch.device('cpu') or target_device == "cpu":
203
- return ApiData(buffer.name, tuple(new_args), new_kwargs, new_results, buffer.step, buffer.rank)
204
- else:
205
- return ApiData(buffer.name, tuple(new_args), new_kwargs, buffer.result, buffer.step, buffer.rank)