mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
@@ -0,0 +1,207 @@
1
+ # Ascend模型梯度状态监测工具
2
+
3
+ 梯度状态监测工具提供了两种能力:
4
+
5
+ - 将模型权重的梯度数据导出。这种功能可以将模型权重的梯度值以统计量的形式采集出来,用以分析问题。
6
+ - 将两份梯度数据进行相似度对比。在有标杆问题中,可以确认训练过程中精度问题出现的step,以及抓取反向过程中的问题。
7
+
8
+ 工具支持PyTorch版本:2.0/2.1/2.2;支持MindSpore版本:r2.3。
9
+
10
+ ## 工具特性
11
+
12
+ - 使用便捷,无需在训练流程里插入代码
13
+ - 可以精准定位问题出现的step
14
+
15
+ ## 使用方式
16
+
17
+ ### 梯度数据导出
18
+
19
+ 1. 创建配置文件config.json,样例如下:
20
+
21
+ ```json
22
+ {
23
+ "task": "grad_probe",
24
+ "dump_path": "./dump_path",
25
+ "rank": [],
26
+ "step": [],
27
+ "grad_probe": {
28
+ "grad_level": "L1",
29
+ "param_list": [],
30
+ "bounds": [-1, 0, 1]
31
+ }
32
+ }
33
+ ```
34
+ > step指的是优化器被调用的次数(并非模型跑的step,某些step,例如loss为nan时,不会调用优化器)
35
+
36
+ **参数说明**
37
+
38
+ | 参数 | 说明 | 输入类型 | 是否必选 |
39
+ |--------------------------------|-----------------------------------|-----------------|----------|
40
+ | task | 填为"grad_probe"。 | str | 是 |
41
+ | grad_level | 输出级别。决定导出数据的详细程度,级别越大导出数据越详细。可取值:L0, L1, L2|str | 是 |
42
+ | param_list | 权重名称列表,表示需要监控的权重。列表为空就表示监控所有权重。 | List[str] | 是 |
43
+ | rank | rank id列表,在多卡场景下,表示需要导出梯度数据的进程的rank id。列表为空就表示导出所有rank的数据。(MindSpore静态图模式下,当前暂不支持指定rank功能) | List[int] | 是 |
44
+ | step | step列表,表示需要导出数据的step列表。列表为空就表示导出所有step的数据。(MindSpore静态图模式下,当前暂不支持指定step功能) | List[int] | 是 |
45
+ | bounds | 区间列表,用来划分区间以统计数值的分布。需要保证数据由小到大排列。可以使用默认值[-1, 0, 1] | List[float] | 是 |
46
+ | dump_path | 输出目录。如果不存在就会创建一个新目录。 | str | 是 |
47
+
48
+ **不同级别的level的导出数据**
49
+
50
+
51
+ | 级别 | 特征数据表头 | 是否有方向数据 |
52
+ | ---- | ------------------------------------------------------------ | -------------- |
53
+ | L0 | ("param_name", "MD5", "max", "min", "norm", "shape") | 否 |
54
+ | L1 | ("param_name", "max", "min", "norm", "shape") | 是 |
55
+ | L2 | ("param_name", *intervals, "=0", "max", "min", "norm", "shape") | 是 |
56
+
57
+ intervals就是根据值分布bounds划分出的区间。
58
+ MindSpore静态图模式下,L0级别中暂不支持"MD5"
59
+
60
+ **方向数据解释**
61
+
62
+ 因为模型的参数往往非常大,所以存储真实数据是不可接受的,这里折衷一下,只存储梯度数据的正负号(一个布尔值),也就是方向。
63
+
64
+ **bounds和值分布解释**
65
+
66
+ + 值分布:梯度数据落在各个区间的元素个数占总元素个数的比例。
67
+ + bounds:一个列表,用来划分出区间以统计值分布。例如传入bounds = [-10, 0, 10],此时有一个 grad_value: Tensor = [9.3 , 5.4, -1.0, -12.3],依据 bounds 划分出 (-inf, -10]、(-10, 0]、(0, 10]、(10, inf) 四个区间,然后统计grad_value里的数据落在每个区间内的个数,得到 1、1、2、0。如下图所示:
68
+ ![Alt text](img/image-1.png)
69
+
70
+ 2. 插入代码。示例代码如下:
71
+
72
+ - PyTorch框架:模型构造完成后,传入config.json的路径实例化一个PrecisionDebugger对象,然后调用debugger.monitor并将`模型`作为参数传入。
73
+ ```python
74
+ from msprobe.pytorch import PrecisionDebugger
75
+ debugger = PrecisionDebugger("config_json_path")
76
+ debugger.monitor(model)
77
+ ```
78
+ - MindSpore框架:优化器构造完成后,传入config.json的路径实例化一个PrecisionDebugger对象,然后调用debugger.monitor并将`优化器`作为参数传入。
79
+ ```python
80
+ from msprobe.mindspore import PrecisionDebugger
81
+ debugger = PrecisionDebugger("config_json_path")
82
+ debugger.monitor(optimizer)
83
+ ```
84
+
85
+ 3. 结束监控(MindSpore静态图模式下需要)
86
+
87
+ 在训练结束之后,调用stop接口
88
+
89
+ ```python
90
+ debugger.stop()
91
+ ```
92
+
93
+ ### 输出结果
94
+ **输出目录结构**(以level配置L2为例)
95
+
96
+ ```bash
97
+ {dump_path}
98
+ ├── rank{rank_id}
99
+ │ ├── grad_summary_{step}.csv
100
+ │ ├── step{step}
101
+ │ │ ├── {param_name}.npy
102
+ ```
103
+ + {timestamp}:梯度工具导出数据的时候会在output_path下生成一个时间戳目录,然后在这个时间戳目录下输出结果。
104
+ + rank{rank_id}:在分布式场景下,会记录卡的rank_id。非分布式场景下,如果是CPU则记录进程号,如果是NPU或GPU则记录卡号。
105
+ + grad_summary_{step}.csv:会分step记录每一步的梯度数据统计值。
106
+ + step{step}:这个目录下会存放该step的梯度的方向数据。
107
+ + {param_name}.pt(npy):模型参数的梯度方向数据,PyTorch保存的是pt文件,MindSpore是npy文件。
108
+
109
+ **grad_summary_{step}.csv**
110
+
111
+ 样例如下:
112
+
113
+ ![Alt text](img/image.png)
114
+
115
+ | 字段 | 含义 |
116
+ | --------------------- | ------------------------------------------------------------|
117
+ | Param_name | 模型参数名称。 |
118
+ | MD5 | 梯度数据的MD5值。 |
119
+ | (-inf, -0.01]...(0.01, inf) | 梯度值落在区间内的元素个数占总元素的比例。 |
120
+ | =0 | 梯度为0的元素个数占总元素的比例。 |
121
+ | Max | 最大值。 |
122
+ | Min | 最小值。 |
123
+ | Norm | L2norm值。 |
124
+ | Shape | 形状。 |
125
+
126
+ ### 梯度相似度比对
127
+
128
+ 会根据所导出的权重,分step比对梯度相似度,输出每个权重的梯度相似度和总的梯度相似度。单个权重的梯度相似度为两份方向数据的重合度,总的梯度相似度为每个权重的梯度相似度按元素个数加权。
129
+
130
+ #### 前提条件
131
+
132
+ - 相同配置下,以Level为L1或L2分别采集npu和gpu环境下的梯度数据。
133
+ - 将两份梯度数据传到同一环境下。
134
+
135
+ #### 使用方式
136
+
137
+
138
+ 新建如下Python脚本,传入npu和gpu的dump_path以及输出目录,比对结果输出目录不存在的话会新建:
139
+
140
+ ```python
141
+ from msprobe import *
142
+ GradComparator.compare_distributed("配置文件里写的dump_path",
143
+ "配置文件里写的dump_path",
144
+ "比对结果输出目录")
145
+ ```
146
+
147
+
148
+ ### 比对结果
149
+
150
+ **输出目录结构**
151
+
152
+ 如下为多卡比对结果,单卡则没有rank{rank_id}这一级目录。
153
+
154
+ ```bash
155
+ 比对结果输出目录
156
+ ├── rank{rank_id}
157
+ │ ├── similarities.csv
158
+ │ └── similarities_picture
159
+ │ ├── {param_name}.png
160
+ │ └── summary_similarities.png
161
+ ```
162
+
163
+ **问题界定**
164
+
165
+ 原则:对于任意权重,第0步的梯度相似度低于0.97,或者某一步的梯度相似度下降超过0.03,认为这一步存在精度问题。例子如下:
166
+
167
+ - 第0步相似度低于0.97
168
+
169
+ ![Alt text](img/image-3.png)
170
+
171
+ - 第3步相似度下降超过0.03
172
+
173
+ ![Alt text](img/image-4.png)
174
+
175
+ - 正常情况
176
+
177
+ ![Alt text](img/image-2.png)
178
+
179
+ 这个原则是一个经验性的指标,并不是严格的标准,还需要结合实际情况具体分析。
180
+
181
+ ## 公开接口
182
+
183
+ **接口说明**
184
+
185
+ ```python
186
+ PrecisionDebugger.monitor(module)
187
+ ```
188
+
189
+ | 参数 | 说明 | 是否必选 |
190
+ | ----- | -------------------- | -------- |
191
+ | module |PyTorch框架下传入模型,必须是torch.nn.Module;MindSpore框架下传入优化器。 | 是 |
192
+
193
+
194
+ **接口说明**
195
+
196
+ ```python
197
+ GradComparator.compare_distributed(dump_path1, dump_path2, output_path)
198
+ ```
199
+
200
+ | 参数 | 说明 | 是否必选 |
201
+ | ----- | -------------------- | -------- |
202
+ | dump_path1 |需要比对的其中一个dump目录,也就是配置文件里写的dump_path。 | 是 |
203
+ | dump_path2 |需要比对的其中一个dump目录,也就是配置文件里写的dump_path,与dump_path1可以互换。 | 是 |
204
+ | output_path |输出结果目录,不存在会新建。 | 是 |
205
+
206
+
207
+ # FAQ
Binary file
Binary file
Binary file
Binary file
Binary file
File without changes
@@ -0,0 +1,246 @@
1
+ import json
2
+ import os
3
+
4
+ from msprobe.core.common.file_check import FileOpen
5
+ from msprobe.core.common.utils import write_csv, add_time_as_suffix
6
+ from msprobe.core.common.const import Const, CompareConst, MsCompareConst
7
+ from msprobe.core.common.log import logger
8
+ from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
9
+ from msprobe.mindspore.api_accuracy_checker.api_runner import api_runner, ApiInputAggregation
10
+ from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
11
+ from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
12
+
13
+
14
class BasicInfoAndStatus:
    """Descriptive fields plus the verdict recorded for one compared output."""

    def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
        # What was compared: the name and the dtypes seen on each side.
        self.api_name, self.bench_dtype, self.tested_dtype = api_name, bench_dtype, tested_dtype
        # Shape of the compared output, the resulting status and its message.
        self.shape, self.status, self.err_msg = shape, status, err_msg
22
+
23
class ResultCsvEntry:
    """Per-API accumulator for the forward/backward verdicts and messages."""

    def __init__(self) -> None:
        # Statuses stay None until a forward/backward result is filled in.
        self.forward_pass_status = self.backward_pass_status = None
        # Messages start empty so they can be concatenated unconditionally.
        self.forward_err_msg = self.backward_err_msg = ""
        self.overall_err_msg = None
30
+
31
+
32
class ApiAccuracyChecker:
    """Load dumped API records, run/compare them and export the results as CSV.

    Typical flow: ``parse`` -> ``run_and_compare`` -> ``to_detail_csv`` /
    ``to_result_csv``.
    """

    def __init__(self):
        # api name (with the trailing forward/backward segment removed) -> ApiInfo
        self.api_infos = dict()
        # (api name, forward_or_backward) -> list of (BasicInfoAndStatus, {algorithm name: compare result})
        self.results = dict()

    @staticmethod
    def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
        """Obtain tested and benchmark outputs for one API direction and compare them.

        Args:
            api_info: ApiInfo
            api_name_str: str
            api_input_aggregation: ApiInputAggregation
            forward_or_backward: str: Union["forward", "backward"]

        Return:
            output_list: List[tuple(str, str, BasicInfoAndStatus, dict{str: CompareResult})]

        Description:
            get mindspore api output, run torch api and get output.
            compare output.
            record compare result.
        """
        # get output
        if global_context.get_is_constructed():
            # constructed situation, need use constructed input to run mindspore api getting tested_output
            tested_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.MS_FRAMEWORK)
        else:
            # otherwise the tested outputs are taken from the loaded api_info record
            tested_outputs = api_info.get_compute_element_list(forward_or_backward, Const.OUTPUT)
        bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)

        # compare output
        output_list = []
        for i, (bench_out, tested_out) in enumerate(zip(bench_outputs, tested_outputs)):
            api_name_with_slot = Const.SEP.join([api_name_str, forward_or_backward, Const.OUTPUT, str(i)])
            bench_dtype = bench_out.get_dtype()
            tested_dtype = tested_out.get_dtype()
            shape = bench_out.get_shape()

            # run every registered algorithm on this output pair
            compare_result_dict = dict()
            for compare_algorithm_name, compare_algorithm in compare_algorithms.items():
                compare_result = compare_algorithm(bench_out, tested_out)
                compare_result_dict[compare_algorithm_name] = compare_result

            # the verdict only depends on the cosine and max-abs-error results
            cosine_result = compare_result_dict.get(CompareConst.COSINE)
            max_abs_err_result = compare_result_dict.get(CompareConst.MAX_ABS_ERR)
            if cosine_result.pass_status == CompareConst.PASS and \
                    max_abs_err_result.pass_status == CompareConst.PASS:
                status = CompareConst.PASS
                err_msg = ""
            else:
                status = CompareConst.ERROR
                err_msg = cosine_result.err_msg + max_abs_err_result.err_msg
            basic_info_status = \
                BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
            output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
        return output_list

    def parse(self, api_info_path):
        """Load the api_info JSON file and populate ``self.api_infos``.

        Args:
            api_info_path: path of the JSON file, opened through FileOpen.

        Only entries whose first name segment is MsCompareConst.MINT or
        MsCompareConst.MINT_FUNCTIONAL are kept. The last name segment selects
        whether the entry is stored as forward or backward info; entries with
        any other last segment are skipped with a warning.
        """
        with FileOpen(api_info_path, "r") as f:
            api_info_dict = json.load(f)

        # init global context
        task = check_and_get_from_json_dict(api_info_dict, MsCompareConst.TASK_FIELD,
                                            "task field in api_info.json", accepted_type=str,
                                            accepted_value=(MsCompareConst.STATISTICS_TASK,
                                                            MsCompareConst.TENSOR_TASK))
        is_constructed = task == MsCompareConst.STATISTICS_TASK
        if not is_constructed:
            dump_data_dir = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DUMP_DATA_DIR_FIELD,
                                                         "dump_data_dir field in api_info.json", accepted_type=str)
        else:
            dump_data_dir = ""
        global_context.init(is_constructed, dump_data_dir)

        api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
                                                     "data field in api_info.json", accepted_type=dict)
        for api_name, api_info in api_info_data.items():
            name_segments = api_name.split(Const.SEP)
            is_mint = name_segments[0] in \
                (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL)
            if not is_mint:
                continue
            forbackward_str = name_segments[-1]
            if forbackward_str not in (Const.FORWARD, Const.BACKWARD):
                logger.warning(f"api: {api_name} is not recognized as forward api or backward api, skip this.")
                # actually skip it: without this the entry used to fall through
                # and be loaded as backward info
                continue
            api_name = Const.SEP.join(name_segments[:-1])  # www.xxx.yyy.zzz --> www.xxx.yyy
            if api_name not in self.api_infos:
                self.api_infos[api_name] = ApiInfo(api_name)

            if forbackward_str == Const.FORWARD:
                self.api_infos[api_name].load_forward_info(api_info)
            else:
                self.api_infos[api_name].load_backward_info(api_info)

    def run_and_compare(self):
        """Run the forward and, when available, backward comparison for every parsed API.

        An API without forward info is skipped entirely. An exception raised
        while running/comparing one direction is logged as a warning and
        nothing is recorded for that direction.
        """
        for api_name_str, api_info in self.api_infos.items():
            if not api_info.check_forward_info():
                logger.warning(f"api: {api_name_str} is lack of forward infomation, skip forward and backward check")
                continue
            forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
            kwargs = api_info.get_kwargs()
            forward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, None)
            forward_output_list = None
            try:
                forward_output_list = \
                    self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
            except Exception as e:
                # best effort per API: keep going with the remaining checks
                logger.warning(f"exception occurs when running and comparing {api_name_str} forward api, "
                               f"detailed exception information: {e}")
            self.record(forward_output_list)

            if not api_info.check_backward_info():
                logger.warning(f"api: {api_name_str} is lack of backward infomation, skip backward check")
                continue
            gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
            backward_inputs_aggregation = ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
            backward_output_list = None
            try:
                backward_output_list = \
                    self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)
            except Exception as e:
                logger.warning(f"exception occurs when running and comparing {api_name_str} backward api, "
                               f"detailed exception information: {e}")
            self.record(backward_output_list)

    def record(self, output_list):
        """Append compare results to ``self.results``.

        Args:
            output_list: None (nothing is recorded) or an iterable of
                (api name, forward_or_backward, basic info, compare result dict)
                tuples; results are grouped under the key
                (api name, forward_or_backward).
        """
        if output_list is None:
            return
        for output in output_list:
            api_real_name, forward_or_backward, basic_info, compare_result_dict = output
            key = tuple([api_real_name, forward_or_backward])
            self.results.setdefault(key, []).append(tuple([basic_info, compare_result_dict]))

    def to_detail_csv(self, csv_dir):
        """Write one CSV row per recorded output into a time-suffixed file under ``csv_dir``.

        Columns: basic info, one value per compare algorithm (in the key order
        of ``compare_algorithms``), then pass status and message.
        """
        # detail_csv
        detail_csv = []
        detail_csv_header_basic_info = [
            MsCompareConst.DETAIL_CSV_API_NAME,
            MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
            MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
            MsCompareConst.DETAIL_CSV_SHAPE,
        ]
        detail_csv_header_compare_result = list(compare_algorithms.keys())
        detail_csv_header_status = [
            MsCompareConst.DETAIL_CSV_PASS_STATUS,
            MsCompareConst.DETAIL_CSV_MESSAGE,
        ]

        detail_csv_header = detail_csv_header_basic_info + detail_csv_header_compare_result + detail_csv_header_status
        detail_csv.append(detail_csv_header)

        for results in self.results.values():
            # detail csv
            for basic_info, compare_result_dict in results:
                csv_row_basic_info = \
                    [basic_info.api_name, basic_info.bench_dtype, basic_info.tested_dtype, basic_info.shape]
                csv_row_compare_result = [compare_result_dict.get(algorithm_name).compare_value
                                          for algorithm_name in detail_csv_header_compare_result]
                csv_row_status = [basic_info.status, basic_info.err_msg]
                csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
                detail_csv.append(csv_row)

        file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.DETAIL_CSV_FILE_NAME))
        write_csv(detail_csv, file_name, mode="w")

    def to_result_csv(self, csv_dir):
        """Write one CSV row per API into a time-suffixed file under ``csv_dir``.

        Each row holds the aggregated forward status, backward status and the
        concatenated error messages. A direction is marked as an error as soon
        as any of its recorded outputs did not pass.
        """
        result_csv_dict = dict()
        for key, results in self.results.items():
            api_real_name, forward_or_backward = key
            forward_or_backward_pass_status = CompareConst.PASS
            forward_or_backward_overall_err_msg = ""
            # detail csv
            for res in results:
                basic_info, _ = res
                if basic_info.status != CompareConst.PASS:
                    forward_or_backward_pass_status = CompareConst.ERROR
                forward_or_backward_overall_err_msg += basic_info.err_msg
            forward_or_backward_overall_err_msg = \
                "" if forward_or_backward_pass_status == CompareConst.PASS else forward_or_backward_overall_err_msg

            # result_csv_dict
            if api_real_name not in result_csv_dict:
                result_csv_dict[api_real_name] = ResultCsvEntry()
            if forward_or_backward == Const.FORWARD:
                result_csv_dict[api_real_name].forward_pass_status = forward_or_backward_pass_status
                result_csv_dict[api_real_name].forward_err_msg = forward_or_backward_overall_err_msg
            else:
                result_csv_dict[api_real_name].backward_pass_status = forward_or_backward_pass_status
                result_csv_dict[api_real_name].backward_err_msg = forward_or_backward_overall_err_msg

        # result_csv
        result_csv = []
        result_csv_header = [
            MsCompareConst.DETAIL_CSV_API_NAME,
            MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
            MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
            MsCompareConst.DETAIL_CSV_MESSAGE,
        ]
        result_csv.append(result_csv_header)

        for api_name, result_csv_entry in result_csv_dict.items():
            if result_csv_entry.forward_pass_status == CompareConst.PASS and \
                    result_csv_entry.backward_pass_status == CompareConst.PASS:
                overall_err_msg = ""
            else:
                overall_err_msg = result_csv_entry.forward_err_msg + result_csv_entry.backward_err_msg
            row = [api_name, result_csv_entry.forward_pass_status,
                   result_csv_entry.backward_pass_status, overall_err_msg]
            result_csv.append(row)

        file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
        write_csv(result_csv, file_name, mode="w")
@@ -0,0 +1,69 @@
1
+ from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
2
+ from msprobe.core.common.const import Const
3
+ from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
4
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
5
+ from msprobe.core.common.log import logger
6
+
7
class ApiInfo:
    """Container for the forward and backward records of one API.

    The records are stored exactly as handed in by the caller; they are only
    interpreted when compute elements are requested.
    """

    def __init__(self, api_name):
        """Remember the API name and start with no forward/backward record."""
        self.api_name = api_name
        self.forward_info = None
        self.backward_info = None

    def load_forward_info(self, forward_info_dict):
        """Store the forward record of this API."""
        self.forward_info = forward_info_dict

    def load_backward_info(self, backward_info_dict):
        """Store the backward record of this API."""
        self.backward_info = backward_info_dict

    def check_forward_info(self):
        """Return True once a forward record has been loaded."""
        return self.forward_info is not None

    def check_backward_info(self):
        """Return True once a backward record has been loaded."""
        return self.backward_info is not None

    def get_compute_element_list(self, forward_or_backward, input_or_output):
        '''
        Build the ComputeElement objects for one side of one pass.

        Args:
            forward_or_backward: str, Union["forward", "backward"]
            input_or_output: str, Union["input", "output"]

        Return:
            compute_element_list: List[ComputeElement]

        Note:
            A (forward_or_backward, input_or_output) pair that is not one of
            the four entries below makes the lookup yield None, so the
            unpacking that follows fails with a TypeError.
        '''
        name = self.api_name
        # (pass, side) -> (record to read, field name, description passed to the getter)
        sources = {
            (Const.FORWARD, Const.INPUT): (
                self.forward_info,
                Const.INPUT_ARGS,
                f"input_args field of {name} forward api in api_info.json",
            ),
            (Const.FORWARD, Const.OUTPUT): (
                self.forward_info,
                Const.OUTPUT,
                f"output field of {name} forward api in api_info.json",
            ),
            (Const.BACKWARD, Const.INPUT): (
                self.backward_info,
                Const.INPUT,
                f"input field of {name} backward api in api_info.json",
            ),
            (Const.BACKWARD, Const.OUTPUT): (
                self.backward_info,
                Const.OUTPUT,
                f"output field of {name} backward api in api_info.json",
            ),
        }
        selected = sources.get((forward_or_backward, input_or_output))
        record, field, field_desc = selected
        raw_elements = check_and_get_from_json_dict(record, field, field_desc, accepted_type=list)

        elements = []
        for raw_element in raw_elements:
            elements.append(ComputeElement(compute_element_info=raw_element))
        return elements

    def get_kwargs(self):
        '''
        Build the keyword-argument ComputeElement objects of the forward call.

        Every entry is checked (string key, list-or-dict value) before any
        ComputeElement is created; a failed check is reported through
        logger.error_log_with_exp together with a ParseJsonFailed exception
        (presumably raised by that helper -- confirm against the logger).

        Return:
            kwargs_compute_element_dict: dict{str: ComputeElement}
        '''
        kwargs_info = check_and_get_from_json_dict(self.forward_info, Const.INPUT_KWARGS,
                                                   "input_kwargs in api_info.json", accepted_type=dict)

        # First pass: validation only, so nothing is built from a malformed record.
        for kwarg_name, kwarg_info in kwargs_info.items():
            if not isinstance(kwarg_name, str):
                key_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string"
                logger.error_log_with_exp(
                    key_msg,
                    ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
            if not isinstance(kwarg_info, (list, dict)):
                value_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or dict"
                logger.error_log_with_exp(
                    value_msg,
                    ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))

        # Second pass: wrap each value, keeping the original key order.
        kwargs_elements = {}
        for kwarg_name, kwarg_info in kwargs_info.items():
            kwargs_elements[kwarg_name] = ComputeElement(compute_element_info=kwarg_info)
        return kwargs_elements
69
+