mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -21,9 +21,9 @@
21
21
  ```json
22
22
  {
23
23
  "task": "grad_probe",
24
- "dump_path": "./dump_path",
25
- "rank": [],
26
- "step": [],
24
+ "dump_path": "./dump_path",
25
+ "rank": [],
26
+ "step": [],
27
27
  "grad_probe": {
28
28
  "grad_level": "L1",
29
29
  "param_list": [],
@@ -43,7 +43,7 @@
43
43
  | step | step列表,表示需要导出数据的step列表。列表为空就表示导出所有step的数据。默认为空。(MindSpore静态图模式下,当前暂不支持指定step功能) | List[int] | 否 |
44
44
  | grad_level | 输出级别。决定导出数据的详细程度,级别越大导出数据越详细。可取值:L0, L1, L2。默认L1。|str | 否 |
45
45
  | param_list | 权重名称列表,表示需要监控的权重。列表为空就表示监控所有权重。默认为空。 | List[str] | 否 |
46
- | bounds | 区间列表,用来划分区间以统计数值的分布。需要保证由数据小到大排列。可以使用默认值[-1, 0, 1]。 | List[float, int] | 否 |
46
+ | bounds | 区间列表,用来划分区间以统计数值的分布。需要保证由数据小到大排列,并且列表中的元素需要在int64取值范围内。可以使用默认值[-1, 0, 1]。 | List[float, int] | 否 |
47
47
 
48
48
  **不同级别的level的导出数据**
49
49
 
@@ -53,29 +53,29 @@
53
53
  | L0 | ("param_name", "MD5", "max", "min", "norm", "shape") | 否 |
54
54
  | L1 | ("param_name", "max", "min", "norm", "shape") | 是 |
55
55
  | L2 | ("param_name", *intervals, "=0", "max", "min", "norm", "shape") | 是 |
56
-
56
+
57
57
  intervals就是根据值分布bounds划分出的区间。
58
58
  MindSpore静态图模式下,L0级别中暂不支持"MD5"
59
-
59
+
60
60
  **方向数据解释**
61
-
61
+
62
62
  因为模型的参数往往非常大,所以存储真实数据是不可接受的,这里折衷一下,只存储梯度数据的正负号(一个布尔值),也就是方向。
63
-
63
+
64
64
  **bounds和值分布解释**
65
-
65
+
66
66
  + 值分布:梯度数据落在各个区间的元素个数占总元素个数的比例。
67
- + bounds:一个列表,用来划分出区间以统计值分布。例如传入bounds = [-10, 0, 10],此时有一个 grad_value: Tensor = [9.3 , 5.4, -1.0, -12.3],依据 bounds 划分出 (-inf, -10]、(-10, 0]、(0, 10]、(10, inf) 四个区间,然后统计grad_value里的数据落在每个区间内的个数,得到 1、1、2、0。如下图所示:
67
+ + bounds:一个列表,用来划分出区间以统计值分布。例如传入bounds = [-10, 0, 10],此时有一个 grad_value: Tensor = [9.3 , 5.4, -1.0, -12.3],依据 bounds 划分出 (-inf, -10]、(-10, 0]、(0, 10]、(10, inf) 四个区间,然后统计grad_value里的数据落在每个区间内的个数,得到 1、1、2、0。如下图所示:
68
68
  ![Alt text](./img/grad_probe_image-1.png)
69
69
 
70
70
  2. 插入代码。示例代码如下:
71
71
 
72
- - PyTorch框架:模型构造完成后,传入config.json的路径实例化一个GradientMonitor对象,然后调用gm.monitor并将`模型`作为参数传入。
72
+ - PyTorch框架:模型构造完成后,传入config.json的路径实例化一个PrecisionDebugger对象,然后调用debugger.monitor并将`模型`作为参数传入。
73
73
  ```python
74
74
  from msprobe.pytorch import PrecisionDebugger
75
75
  debugger = PrecisionDebugger("config_json_path")
76
76
  debugger.monitor(model)
77
77
  ```
78
- - MindSpore框架:优化器构造完成后,传入config.json的路径实例化一个GradientMonitor对象,然后调用gm.monitor并将`优化器`作为参数传入。
78
+ - MindSpore框架:优化器构造完成后,传入config.json的路径实例化一个PrecisionDebugger对象,然后调用debugger.monitor并将`优化器`作为参数传入。
79
79
  ```python
80
80
  from msprobe.mindspore import PrecisionDebugger
81
81
  debugger = PrecisionDebugger("config_json_path")
@@ -143,7 +143,7 @@ GradComparator.compare_distributed("配置文件里写的dump_path",
143
143
  "配置文件里写的dump_path",
144
144
  "比对结果输出目录")
145
145
  ```
146
-
146
+
147
147
 
148
148
  ### 比对结果
149
149
 
@@ -190,6 +190,7 @@ PrecisionDebugger.monitor(module)
190
190
  | ----- | -------------------- | -------- |
191
191
  | module |Pytorch框架下传入模型,必须是torch.nn.Module;MindSpore框架下传入优化器。 | 是 |
192
192
 
193
+ Pytorch场景,传入的模型不能被torch.jit.trace修饰。MindSpore动态图场景,传入的优化器不能被mindspore.jit修饰。
193
194
 
194
195
  **接口说明**
195
196
 
@@ -202,6 +203,3 @@ GradComparator.compare_distributed(dump_path1, dump_path2, output_path)
202
203
  | dump_path1 |需要比对的其中一个dump目录,也就是配置文件里写的dump_path。 | str | 是 |
203
204
  | dump_path2 |需要比对的其中一个dump目录,也就是配置文件里写的dump_path,与dump_path1可以互换。 | str | 是 |
204
205
  | output_path |输出结果目录,不存在会新建。 | str | 是 |
205
-
206
-
207
- # FAQ
@@ -0,0 +1,89 @@
1
+ # **PyTorch NPU在线精度比对工具使用指南**
2
+
3
+ PyTorch NPU在线精度比对是msprobe工具实现在PyTorch训练过程中直接完成精度比对并输出比对结果的功能。
4
+
5
+ 在线精度比对实现的是NPU与CPU之间的精度比对。
6
+
7
+ ## PyTorch NPU在线精度比对总体流程
8
+
9
+ 1. 准备NPU训练工程。
10
+
11
+ 2. 在NPU环境下安装msprobe工具。
12
+
13
+ 3. 在训练脚本内插入msprobe工具在线精度比对接口。
14
+
15
+ 4. 执行训练并获取在线精度比对NPU和CPU分别执行后的精度比对结果。
16
+
17
+ 5. 比对结果分析。
18
+
19
+ ## PyTorch NPU在线精度比对
20
+ ### 总体说明
21
+ - 本节主要介绍NPU精度比对所需要的函数以及示例。
22
+ - 在线精度比对工具通过截获PyTorch框架中部分Aten Ir及其输入输出,并将输入数据转到CPU执行,最后将NPU和CPU的执行结果进行精度比对得到比对结果。
23
+
24
+ ### 约束
25
+
26
+ - Pytorch 只支持2.0及其以上版本。
27
+ - 只支持Aten Ir级在线精度比对,所有Aten Ir可以通过dir(torch.ops.aten)查看,其中部分IR不支持在线比对:Aten Ir无对应CPU实现、NPU和CPU同AtenIR实现逻辑不一致,导致同输入不同输出。
28
+ - 正反向不支持同时在线精度比对,不支持跨step在线精度比对。
29
+
30
+
31
+ ### 场景示例
32
+ 1. 在NPU训练脚本中添加在线精度比对接口,示例如下:
33
+
34
+ ```python
35
+ from msprobe.pytorch.common import seed_all
36
+ from msprobe.pytorch.online_dispatch import PtdbgDispatch
37
+
38
+ # 在main函数开始前固定随机数
39
+ seed_all()
40
+
41
+
42
+ ...
43
+
44
+ # 在需要调试精度的正向或反向代码前设置
45
+ # 正向示例
46
+ with PtdbgDispatch(dump_mode="auto", dump_path="/home/dump"):
47
+ output = model_cpu(inputs)
48
+ # 反向示例
49
+ with PtdbgDispatch(dump_mode="auto", dump_path="/home/dump"):
50
+ loss.backward()
51
+ ```
52
+
53
+ 2. 执行训练。
54
+
55
+ 3. 找出精度不达标的Aten IR。
56
+
57
+ 执行过程中会打屏Failed,Failed在比对结果csv中的Accuracy Reached or Not列标记为No,并在Dump目录下存盘精度不达标Aten IR的输入输出。
58
+
59
+ ### 计算精度评价指标
60
+
61
+ 1. Cosine < 0.99 且 MaxAbsError > 0.001时,精度不达标;
62
+ 2. Cosine < 0.9,精度不达标;
63
+ 3. MaxAbsError > 1,精度不达标。
64
+
65
+ ### 在线精度比对参数设置说明
66
+
67
+ | 参数名称 | 说明 | 是否必选 |
68
+ | -------- |-------------------------------------------------------------------------------------------------| -------- |
69
+ | dump_mode| dump模式,可取值"all"、"list"、"auto"、"OFF",默认值为OFF(表示不Dump数据)。 | 否 |
70
+ | api_list | dump范围,dump_mode="list"时设置,需要Dump Aten Ir API名称,默认为None,Aten Ir API名称可以通过dir(torch.ops.aten)查看。 | 否 |
71
+ | dump_path| dump文件生成的路径。 | 是 |
72
+ | tag | 传入tag字符串,成为dump文件夹名一部分,默认为None。 | 否 |
73
+ | process_num | 多进程并发数,默认为0。 | 否 |
74
+ | debug | debug信息打印,默认为False。 | 否 |
75
+ ### dump数据存盘说明
76
+ dump数据存盘目录名格式:`atat_tag_rankid_{timestamp}`。
77
+
78
+ 子目录下包含1个比对结果csv文件、cpu和npudump数据目录,npu目录下包含Aten IR在NPU上的输入输出的dump数据,由于CPU的输入是直接使用NPU的输入执行,因此cpu目录下只包含执行输出的dump数据。
79
+
80
+ ```bash
81
+ atat_rank4_20230911170521
82
+ ├── compare_result_rank4_20230911170521.csv
83
+ ├── cpu
84
+ │ ├── native_batch_norm_backward_10_output.0.npy
85
+ │ ............
86
+ └── npu
87
+ ├── native_batch_norm_backward_10_input.0.npy
88
+ ............
89
+ ```
@@ -1,4 +1,16 @@
1
- # 1 精度预检工具
1
+
2
+
3
+ # 1 数据采集
4
+
5
+ 1. dump.json中API或Module统计信息里出现null或None值的原因是什么?
6
+
7
+ dump.json里出现null或None值的可能性较多,常见的场景有:
8
+
9
+ - 输入或者输出参数本身是一个None值。
10
+ - 输入参数或输出参数类型当前工具不支持,会有日志打印提醒。
11
+ - 输入或者输出tensor的dtype为bool时,Mean和Norm等字段为null。
12
+
13
+ # 2 精度预检(PyTorch)
2
14
 
3
15
  1. 预检工具在 dump 和 run_ut 的过程中,是否需要同时开启或关闭 jit 编译(jit_compile)?
4
16
 
@@ -52,20 +64,20 @@
52
64
  | `__matmul__` | 矩阵乘法 |
53
65
  | `__mod__` | % |
54
66
  | `__mul__` | * |
55
- | `__nonzero__` | 同`__bool__` |
67
+ | `__nonzero__` | 同 `__bool__` |
56
68
  | `__or__` | \| |
57
69
  | `__radd__` | +(反向) |
58
70
  | `__rmul__` | *(反向) |
59
71
  | `__rshift__` | >> |
60
72
  | `__sub__` | - |
61
- | `__truediv__` | 同`__div__` |
73
+ | `__truediv__` | 同 `__div__` |
62
74
  | `__xor__` | ^ |
63
75
 
64
- # 2 精度比对工具
76
+ # 3 精度比对(PyTorch)
65
77
 
66
- ## 2.1 工具使用
78
+ ## 3.1 工具使用
67
79
 
68
- ### 2.1.1 dump 指定融合算子
80
+ ### 3.1.1 dump 指定融合算子
69
81
 
70
82
  数据采集当前支持融合算子的输入输出,需要在 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 中添加,比如以下代码段调用的 softmax 融合算子。
71
83
 
@@ -83,7 +95,7 @@ def npu_forward_fused_softmax(self, input_, mask):
83
95
 
84
96
  (npu_scaled_masked_softmax 融合算子工具已支持 dump,本例仅供参考)。
85
97
 
86
- ## 2.2 常见问题
98
+ ## 3.2 常见问题
87
99
 
88
100
  1. 在同一个目录多次执行 dump 会冲突吗?
89
101
 
@@ -97,7 +109,7 @@ def npu_forward_fused_softmax(self, input_, mask):
97
109
 
98
110
  答:torch 版本和硬件差异属于正常情况。
99
111
 
100
- ## 2.3 异常情况
112
+ ## 3.3 异常情况
101
113
 
102
114
  1. HCCL 报错: error code: EI0006。
103
115
 
@@ -168,9 +180,9 @@ def npu_forward_fused_softmax(self, input_, mask):
168
180
 
169
181
  答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 `Tensor: ` 下的 `- __getitem__`,工具会跳过采集该 API。如果是需要采集关键位置 API 也可以考虑根据报错堆栈信息注释引发报错的类型检查。
170
182
 
171
- 11. 添加 msprobe 工具后 F.gelu 触发 ValueError 报错:`activation_func must be F.gelu`等。
183
+ 11. 添加 msprobe 工具后 F.gelu 触发 ValueError 报错:`activation_func must be F.gelu` 等。以及采集 Megatron 数据时报错:`ValueError(Only support fusion of gelu and swiglu)`。
172
184
 
173
- 答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 `functional: ` 下的 `-gelu`,工具会跳过采集该 API。如果需要采集关键位置 api 也可以考虑根据报错堆栈信息注释引发报错的类型检查。
185
+ 答:这一类问题是因为工具本身封装了 torch 算子,所以校验算子名时会报错。注释 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中的 `-gelu` 或者 `-silu`,工具会跳过采集该 API。如果需要采集关键位置 API 也可以考虑根据报错堆栈信息注释引发报错的类型检查。
174
186
 
175
187
  12. 添加 msprobe 工具后触发与 AsStrided 算子相关、或者编译相关的报错,如:`Failed to compile Op [AsStrided]`。
176
188
 
Binary file
Binary file
Binary file
@@ -1 +1,2 @@
1
1
  from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
2
+ from msprobe.mindspore.common.utils import seed_all
@@ -92,6 +92,23 @@ class ApiAccuracyChecker:
92
92
  output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
93
93
  return output_list
94
94
 
95
+ @staticmethod
96
+ def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD):
97
+ '''
98
+ Args:
99
+ api_info: ApiInfo
100
+ forward_or_backward: str
101
+ Returns:
102
+ ApiInputAggregation
103
+ '''
104
+ forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
105
+ kwargs = api_info.get_kwargs()
106
+ if forward_or_backward == Const.FORWARD:
107
+ gradient_inputs = None
108
+ else:
109
+ gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
110
+ return ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
111
+
95
112
  def parse(self, api_info_path):
96
113
  with FileOpen(api_info_path, "r") as f:
97
114
  api_info_dict = json.load(f)
@@ -131,32 +148,39 @@ class ApiAccuracyChecker:
131
148
  def run_and_compare(self):
132
149
  for api_name_str, api_info in self.api_infos.items():
133
150
  if not api_info.check_forward_info():
134
- logger.warning(f"api: {api_name_str} is lack of forward infomation, skip forward and backward check")
151
+ logger.warning(f"api: {api_name_str} is lack of forward infomation, skip forward and backward check.")
152
+ continue
153
+ try:
154
+ forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
155
+ except Exception as e:
156
+ logger.warning(f"exception occurs when getting inputs for {api_name_str} forward api. "
157
+ f"skip forward and backward check. detailed exception information: {e}.")
135
158
  continue
136
- forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
137
- kwargs = api_info.get_kwargs()
138
- forward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, None)
139
159
  forward_output_list = None
140
160
  try:
141
161
  forward_output_list = \
142
162
  self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
143
163
  except Exception as e:
144
- logger.warning(f"exception occurs when running and comparing {api_name_str} forward api"
145
- f"detailed exception information: {e}")
164
+ logger.warning(f"exception occurs when running and comparing {api_name_str} forward api. "
165
+ f"detailed exception information: {e}.")
146
166
  self.record(forward_output_list)
147
167
 
148
168
  if not api_info.check_backward_info():
149
- logger.warning(f"api: {api_name_str} is lack of backward infomation, skip backward check")
169
+ logger.warning(f"api: {api_name_str} is lack of backward infomation, skip backward check.")
170
+ continue
171
+ try:
172
+ backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
173
+ except Exception as e:
174
+ logger.warning(f"exception occurs when getting inputs for {api_name_str} backward api. "
175
+ f"skip backward check. detailed exception information: {e}.")
150
176
  continue
151
- gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
152
- backward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
153
177
  backward_output_list = None
154
178
  try:
155
179
  backward_output_list = \
156
180
  self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)
157
181
  except Exception as e:
158
- logger.warning(f"exception occurs when running and comparing {api_name_str} backward api"
159
- f"detailed exception information: {e}")
182
+ logger.warning(f"exception occurs when running and comparing {api_name_str} backward api. "
183
+ f"detailed exception information: {e}.")
160
184
  self.record(backward_output_list)
161
185
 
162
186
  def record(self, output_list):
@@ -3,9 +3,16 @@ from msprobe.core.common.const import Const
3
3
  from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
4
4
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
5
5
  from msprobe.mindspore.common.log import logger
6
+ from msprobe.core.common.utils import is_invalid_pattern
6
7
 
7
8
  class ApiInfo:
8
9
  def __init__(self, api_name):
10
+ if not isinstance(api_name, str):
11
+ err_msg = "ApiInfo.__init__ failed: api_name is not a string"
12
+ logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
13
+ if is_invalid_pattern(api_name):
14
+ err_msg = "ApiInfo.__init__ failed: api_name contain illegal character"
15
+ logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
9
16
  self.api_name = api_name
10
17
  self.forward_info = None
11
18
  self.backward_info = None
@@ -1,17 +1,19 @@
1
1
  from msprobe.core.data_dump.scope import ModuleRangeScope
2
2
  from msprobe.core.common.const import Const
3
- from msprobe.mindspore.common.log import logger
4
3
 
5
4
 
6
5
  class CellProcessor:
7
6
  cell_count = {}
7
+ cell_stack = []
8
+ api_parent_node = ""
9
+ module_node = {}
8
10
 
9
11
  def __init__(self, scope):
10
12
  if isinstance(scope, ModuleRangeScope):
11
13
  self.scope = scope
12
14
  else:
13
15
  self.scope = None
14
-
16
+
15
17
  @staticmethod
16
18
  def set_cell_count(cell_name):
17
19
  if cell_name not in CellProcessor.cell_count:
@@ -20,14 +22,36 @@ class CellProcessor:
20
22
  CellProcessor.cell_count[cell_name] += 1
21
23
  return CellProcessor.cell_count[cell_name]
22
24
 
25
+ @classmethod
26
+ def reset_cell_stats(cls):
27
+ cls.cell_count = {}
28
+ cls.cell_stack = []
29
+ cls.api_parent_node = ""
30
+ cls.module_node = {}
31
+
23
32
  def node_hook(self, name_prefix, start_or_stop, **kwargs):
24
33
  def begin_hook(cell, input):
25
34
  index = self.set_cell_count(name_prefix)
26
35
  cell.mindstudio_reserved_name = full_name = name_prefix + Const.SEP + str(index)
36
+ if CellProcessor.cell_stack:
37
+ CellProcessor.module_node[full_name] = CellProcessor.cell_stack[-1]
38
+ else:
39
+ CellProcessor.module_node[full_name] = None
40
+
41
+ CellProcessor.cell_stack.append(full_name)
42
+ CellProcessor.api_parent_node = full_name
43
+
27
44
  if self.scope:
28
45
  self.scope.begin_module(full_name)
29
-
46
+
30
47
  def end_hook(cell, input, output):
48
+ if CellProcessor.cell_stack:
49
+ CellProcessor.cell_stack.pop()
50
+ if CellProcessor.cell_stack:
51
+ CellProcessor.api_parent_node = CellProcessor.cell_stack[-1]
52
+ else:
53
+ CellProcessor.api_parent_node = None
54
+
31
55
  if self.scope:
32
56
  self.scope.end_module(cell.mindstudio_reserved_name)
33
57
 
@@ -39,12 +39,14 @@ class Const:
39
39
  OPS_DATA_PREFIX = "Functional."
40
40
  MINT_DATA_PREFIX = "Mint."
41
41
  MINT_NN_FUNC_DATA_PREFIX = "MintFunctional."
42
+ DISTRIBUTED_DATA_PREFIX = "Distributed."
42
43
 
43
44
  SUPPORTED_API_LIST_FILE = "support_wrap_ops.yaml"
44
45
  SUPPORTED_TENSOR_LIST_KEY = "tensor"
45
46
  SUPPORTED_OPS_LIST_KEY = "ops"
46
47
  SUPPORTED_MINT_LIST_KEY = "mint.ops"
47
48
  SUPPORTED__MINT_NN_FUNC_LIST_KEY = "mint.nn.functional"
49
+ SUPPORTED_COMM_LIST_KEY = "communication.comm_func"
48
50
 
49
51
  DROPOUT_API_NAME_PREFIX = "dropout"
50
52
 
@@ -12,12 +12,17 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
+
15
16
  import os
17
+ import random
18
+
16
19
  import mindspore as ms
17
20
 
18
21
  from msprobe.core.common.exceptions import DistributedNotInitializedError
19
22
  from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy
20
23
  from msprobe.core.common.log import logger
24
+ from msprobe.core.common.const import Const
25
+ from msprobe.core.common.utils import CompareException, check_seed_all
21
26
 
22
27
 
23
28
  def get_rank_if_initialized():
@@ -53,12 +58,15 @@ def list_lowest_level_directories(root_dir):
53
58
  check_path_exists(root_dir)
54
59
  lowest_level_dirs = []
55
60
 
56
- def recurse_dirs(current_dir):
61
+ def recurse_dirs(current_dir, depth=0):
62
+ if depth > Const.MAX_DEPTH:
63
+ logger.error(f'The directory {current_dir} has more than {Const.MAX_DEPTH} levels.')
64
+ raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
57
65
  for entry in os.listdir(current_dir):
58
66
  full_path = os.path.join(current_dir, entry)
59
67
  if os.path.isdir(full_path):
60
68
  if any(os.path.isdir(os.path.join(full_path, subentry)) for subentry in os.listdir(full_path)):
61
- recurse_dirs(full_path)
69
+ recurse_dirs(full_path, depth=depth+1)
62
70
  else:
63
71
  lowest_level_dirs.append(full_path)
64
72
 
@@ -66,6 +74,14 @@ def list_lowest_level_directories(root_dir):
66
74
  return lowest_level_dirs
67
75
 
68
76
 
77
+ def seed_all(seed=1234, mode=False):
78
+ check_seed_all(seed, mode)
79
+ os.environ['PYTHONHASHSEED'] = str(seed)
80
+ ms.set_seed(seed)
81
+ random.seed(seed)
82
+ ms.set_context(deterministic="ON" if mode else "OFF")
83
+ os.environ['HCCL_DETERMINISTIC'] = str(mode)
84
+
69
85
 
70
86
  class MsprobeStep(ms.train.Callback):
71
87
 
@@ -1,19 +1,3 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
1
  import os
18
2
  from msprobe.core.common.utils import CompareException, check_compare_param, \
19
3
  check_configuration_param, task_dumppath_get
@@ -24,6 +8,7 @@ from msprobe.mindspore.compare.ms_compare import MSComparator
24
8
  from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
25
9
  from msprobe.mindspore.compare.ms_graph_compare import GraphMSComparator
26
10
 
11
+
27
12
  def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
28
13
  if kwargs.get('suffix'):
29
14
  logger.error("Argument 'suffix' is not supported for compare_distributed.")
@@ -54,15 +39,17 @@ def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
54
39
  }
55
40
  try:
56
41
  summary_compare, md5_compare = task_dumppath_get(dump_result_param)
57
- check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
42
+ check_configuration_param(stack_mode, auto_analyze, fuzzy_match,
43
+ dump_result_param.get('is_print_compare_log', True))
58
44
  create_directory(output_path)
59
- check_compare_param(dump_result_param, output_path, summary_compare=summary_compare, md5_compare=md5_compare)
45
+ check_compare_param(dump_result_param, output_path,
46
+ summary_compare=summary_compare, md5_compare=md5_compare)
60
47
  except (CompareException, FileCheckException) as error:
61
48
  logger.error('Compare failed. Please check the arguments and do it again!')
62
49
  raise CompareException(error.code) from error
63
50
  ms_comparator = MSComparator()
64
- ms_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare,
65
- md5_compare=md5_compare, **kwargs)
51
+ ms_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}',
52
+ summary_compare=summary_compare, md5_compare=md5_compare, **kwargs)
66
53
 
67
54
 
68
55
  def ms_graph_compare(inputs, outputs):
@@ -71,5 +58,5 @@ def ms_graph_compare(inputs, outputs):
71
58
  except (CompareException, FileCheckException) as error:
72
59
  logger.error('Compare failed. Please check the arguments and do it again!')
73
60
  return
74
- msComparator = GraphMSComparator(inputs, outputs)
75
- msComparator.compare_core()
61
+ ms_comparator = GraphMSComparator(inputs, outputs)
62
+ ms_comparator.compare_core()