mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
@@ -12,7 +12,7 @@ msprobe工具主要通过在训练脚本内添加dump接口并启动训练的方
12
12
 
13
13
  通过加载dump配置文件的方式来确定dump操作的详细配置。
14
14
 
15
- 可以在from msprobe.mindspore import PrecisionDebugger和模型初始化之间的任意位置添加该接口。
15
+ PrecisionDebugger可以在from msprobe.mindspore import PrecisionDebugger之后的位置添加。详细使用可参考“**示例代码**”。
16
16
 
17
17
  **原型**
18
18
 
@@ -24,7 +24,7 @@ PrecisionDebugger(config_path=None)
24
24
 
25
25
  | 参数名 | 说明 | 是否必选 |
26
26
  | ----------- | ------------------------------------------------------------ | -------- |
27
- | config_path | 指定dump配置文件路径,String类型。参数示例:"./config.json"。未配置该路径时,默认使用[config.json](../../config)文件的默认配置。config.json文件可以配置更多参数,若需要进行更多场景的精度数据dump,建议配置[config.json](../../config/config.json)文件。 | 否 |
27
+ | config_path | 指定dump配置文件路径,String类型。参数示例:"./config.json"。未配置该路径时,默认使用[config.json](../../config)文件的默认配置。config.json文件可以配置更多参数,若需要进行更多场景的精度数据dump,建议配置[config.json](../../config/config.json)文件。config.json文件的配置可参考《[配置文件说明](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/config/README.md)》。 | 否 |
28
28
 
29
29
  ### start函数
30
30
 
@@ -32,16 +32,64 @@ PrecisionDebugger(config_path=None)
32
32
 
33
33
  启动函数。
34
34
 
35
+ 在模型初始化之后的位置添加。需要与stop函数一起添加在for循环内。
36
+
35
37
  **原型**
36
38
 
37
39
  ```Python
38
- debugger.start()
40
+ debugger.start(model = None)
39
41
  ```
40
42
 
41
- 该函数为类函数,可以使用debugger.start()也可以使用PrecisionDebugger.start()
43
+ 该函数为类函数,可以使用debugger.start(model = None)也可以使用PrecisionDebugger.start(model = None)。
44
+
45
+
46
+ **参数说明**
47
+
48
+ | 参数名 | 说明 | 是否必选 |
49
+ | ----------- |---------------------------------------------------------------------------------------| -------- |
50
+ | model | 指具体的mindspore.nn.Cell,默认未配置,L1级别下传入model可以使能对primitive op的dump,否则无法dump primitive op。 | 否 |
51
+
52
+
53
+ ### stop函数
54
+
55
+ **功能说明**
56
+
57
+ dump停止函数。
58
+
59
+ 在**start**函数之后的任意位置添加。需要与start函数一起添加在for循环内。若需要dump反向数据,则需要添加在反向计算代码之后。
60
+
61
+ 仅MindSpore动态图场景支持。
62
+
63
+ **原型**
64
+
65
+ ```Python
66
+ debugger.stop()
67
+ ```
68
+
69
+ 该函数为类函数,可以使用debugger.stop()也可以使用PrecisionDebugger.stop()。
70
+
71
+ ### step函数
72
+
73
+ **功能说明**
74
+
75
+ 结束标识。
76
+
77
+ 在最后一个**stop**函数后或一个step结束的位置添加。
78
+
79
+ 仅MindSpore动态图场景支持。
80
+
81
+ **原型**
82
+
83
+ ```Python
84
+ debugger.step()
85
+ ```
86
+
87
+ 该函数为类函数,可以使用debugger.step()也可以使用PrecisionDebugger.step()。
42
88
 
43
89
  ## 示例代码
44
90
 
91
+ ### MindSpore静态图场景
92
+
45
93
  ```Python
46
94
  from msprobe.mindspore import PrecisionDebugger
47
95
  debugger = PrecisionDebugger(config_path="./config.json")
@@ -51,15 +99,119 @@ debugger.start()
51
99
  ...
52
100
  ```
53
101
 
102
+ ### MindSpore动态图场景
103
+
104
+ 当模型使用for循环进行训练时,在每个迭代的开始插入debugger.start(),在每个迭代的结束插入debugger.stop()与debugger.step():
105
+
106
+ ```Python
107
+ import mindspore as ms
108
+ from msprobe.mindspore import PrecisionDebugger
109
+
110
+ # 请勿将PrecisionDebugger的初始化插入到循环代码中
111
+ debugger = PrecisionDebugger(config_path="./config.json")
112
+
113
+ # 模型、损失函数的定义以及初始化等操作
114
+ # ...
115
+
116
+ # 数据集迭代的地方往往是模型开始训练的地方
117
+ for data, label in data_loader:
118
+ debugger.start() # 开启数据dump
119
+ net = Model()
120
+ # 如下是模型每个step执行的逻辑
121
+ grad_net = ms.grad(net)(data)
122
+ # ...
123
+ debugger.stop() # 关闭数据dump
124
+ debugger.step() # 结束一个step的dump
125
+ ```
126
+
127
+ 当使用模型的train方法而非for循环时,可以在callbacks参数中传入MsprobeStep(debugger):
128
+
129
+ ```Python
130
+ from msprobe.mindspore.common.utils import MsprobeStep
131
+ from msprobe.mindspore import PrecisionDebugger
132
+
133
+ # 初始化PrecisionDebugger
134
+ debugger = PrecisionDebugger(config_path="./config.json")
135
+
136
+ # 自动在每个step开始时调用start(),在每个step结束时调用stop()和step()。
137
+ # 这意味着您无需手动在循环内添加start、stop和step函数,框架会自动完成数据的dump操作。
138
+ trainer.train(1, dataset_train, callbacks=[loss_monitor, MsprobeStep(debugger)])
139
+
140
+ ```
141
+
54
142
  ## dump结果文件介绍
55
143
 
144
+ ### MindSpore静态图场景
145
+
56
146
  训练结束后,工具将dump的数据保存在dump_path参数指定的目录下。
57
147
 
58
- - levelL1
148
+ - jit_levelO0/O1
59
149
 
60
150
  dump结果目录请参见MindSpore官网中的《[同步Dump数据对象目录](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.0rc2/debug/dump.html#%E5%90%8C%E6%AD%A5dump%E6%95%B0%E6%8D%AE%E5%AF%B9%E8%B1%A1%E7%9B%AE%E5%BD%95)》。
61
151
 
62
- - levelL2
152
+ - jit_levelO2
63
153
 
64
154
  dump结果目录请参见MindSpore官网中的《[异步Dump数据对象目录](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.0rc2/debug/dump.html#%E5%BC%82%E6%AD%A5dump%E6%95%B0%E6%8D%AE%E5%AF%B9%E8%B1%A1%E7%9B%AE%E5%BD%95)》。
65
155
 
156
+ jit_level请参见[mindspore.JitConfig](https://www.mindspore.cn/docs/zh-CN/r2.3.0/api_python/mindspore/mindspore.JitConfig.html#mindspore-jitconfig)配置jit_config。
157
+
158
+ ### MindSpore动态图场景
159
+
160
+ 训练结束后,工具将dump的数据保存在dump_path参数指定的目录下。
161
+
162
+ dump结果目录结构示例如下:
163
+
164
+ ```bash
165
+ ├── dump_path
166
+ │ ├── step0
167
+ │ | ├── rank0
168
+ │ | │ ├── dump_tensor_data
169
+ | | | | ├── MintFunctional.relu.0.backward.input.0.npy
170
+ | | | | ├── Mint.abs.0.forward.input.0.npy
171
+ | | | | ├── Functional.split.0.forward.input.0.npy
172
+ | | | | ├── Tensor.__add__.0.forward.output.0.npy
173
+ | | | | ...
174
+ | | | | └── Jit.AlexNet.0.forward.input.0.npy
175
+ │ | | ├── dump.json # 保存前反向算子、算子的统计量信息或溢出算子信息。包含dump数据的API名称(命名格式为:`{api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号}`)、dtype、shape、各数据的max、min、mean、L2norm统计信息以及当配置summary_mode="md5"时的md5数据。其中,“参数序号”表示该API下的第n个参数,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该API的第1个参数的第1个子参数;L2norm表示L2范数(平方根)。
176
+ │ | | ├── stack.json # 算子调用栈信息
177
+ │ | | └── construct.json # 分层分级结构,level为L1时,construct.json内容为空
178
+ │ | ├── rank1
179
+ | | | ├── dump_tensor_data
180
+ | | | | └── ...
181
+ │ | | ├── dump.json
182
+ │ | | ├── stack.json
183
+ | | | └── construct.json
184
+ │ | ├── ...
185
+ │ | |
186
+ | | └── rank7
187
+ │ ├── step1
188
+ │ | ├── ...
189
+ │ ├── step2
190
+ ```
191
+
192
+ dump过程中,npy文件在对应算子或者模块被执行后就会落盘,而json文件则需要在正常执行PrecisionDebugger.stop()后才会写入完整数据,异常的程序终止会保存终止前被执行算子的相关npy文件,可能会导致json文件中数据丢失。
193
+
194
+ 其中rank为设备上各卡的ID,每张卡上dump的数据会生成对应dump目录。非分布式场景下没有rank ID,目录名称为rank。
195
+
196
+ 动态图场景下使能PSJit或PIJit,装饰特定Cell或function,被装饰的部分会全部/部分使能静态图流程。PSJit场景下config.json文件配置level为L1时,被PSJit装饰的部分也作为API被dump到对应目录;若配置level为L2时,则只会dump用户网络中静态图流程下的相关kernel。PIJit场景开启dump工具后,会被还原为动态图,按API粒度进行dump。
197
+
198
+ npy文件保存的前缀和MindSpore对应关系如下:
199
+
200
+ | 前缀 | MindSpore模块 |
201
+ | -------------- | ---------------------------- |
202
+ | Tensor | mindspore.Tensor |
203
+ | Functional | mindspore.ops |
204
+ | Mint | mindspore.mint |
205
+ | MintFunctional | mindspore.mint.nn.functional |
206
+ | Jit | mindspore.jit |
207
+
208
+ ## 工具支持的API列表
209
+
210
+ msprobe工具维护固定的API支持列表,若需要删除或增加dump的API,可以在msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml文件内手动修改,如下示例:
211
+
212
+ ```bash
213
+ ops: # ops为算子类别,找到对应的类别,在该类别下按照下列格式删除或添加API
214
+ - adaptive_avg_pool1d
215
+ - adaptive_avg_pool2d
216
+ - adaptive_avg_pool3d
217
+ ```
@@ -1,24 +1,25 @@
1
+ from msprobe.mindspore.common.const import Const
1
2
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
2
- from msprobe.mindspore.dump.api_kbk_dump import ApiKbkDump
3
+ from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump
3
4
  from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
4
5
 
5
6
 
6
7
  class DumpToolFactory:
7
8
  tools = {
8
- "cell": {
9
- "kbk": None,
10
- "graph": None,
11
- "pynative": None
9
+ Const.CELL: {
10
+ Const.GRAPH_KBYK_MODE: None,
11
+ Const.GRAPH_GE_MODE: None,
12
+ Const.PYNATIVE_MODE: None
12
13
  },
13
- "api": {
14
- "kbk": ApiKbkDump,
15
- "graph": None,
16
- "pynative": None
14
+ Const.API: {
15
+ Const.GRAPH_KBYK_MODE: None,
16
+ Const.GRAPH_GE_MODE: None,
17
+ Const.PYNATIVE_MODE: None
17
18
  },
18
- "kernel": {
19
- "kbk": None,
20
- "graph": KernelGraphDump,
21
- "pynative": None
19
+ Const.KERNEL: {
20
+ Const.GRAPH_KBYK_MODE: KernelKbykDump,
21
+ Const.GRAPH_GE_MODE: KernelGraphDump,
22
+ Const.PYNATIVE_MODE: KernelKbykDump
22
23
  }
23
24
  }
24
25
 
@@ -26,13 +27,9 @@ class DumpToolFactory:
26
27
  def create(config: DebuggerConfig):
27
28
  tool = DumpToolFactory.tools.get(config.level)
28
29
  if not tool:
29
- raise Exception("valid level is needed.")
30
- if config.level == "api":
31
- tool = tool.get("kbk")
32
- elif config.level == "kernel":
33
- tool = tool.get("graph")
34
- elif config.level == "cell":
35
- raise Exception("Cell dump in not supported now.")
30
+ raise Exception("Valid level is needed.")
31
+ tool = tool.get(config.execution_mode)
36
32
  if not tool:
37
- raise Exception("Data dump in not supported in this mode.")
38
- return tool(config)
33
+ raise Exception(f"Data dump is not supported in {config.execution_mode} mode "
34
+ f"when dump level is {config.level}.")
35
+ return tool(config)
@@ -0,0 +1,104 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ import mindspore as ms
17
+ from msprobe.mindspore.dump.hook_cell.wrap_functional import get_functional_ops, setup_hooks, \
18
+ HOOKFunctionalOP, HOOKMintOP, HOOKMintNNFunctionalOP
19
+ from msprobe.mindspore.dump.hook_cell.wrap_tensor import get_tensor_ops, wrap_tensor_ops_and_bind, HOOKTensor
20
+ from msprobe.core.common.utils import Const
21
+
22
+
23
+ class ApiRegistry:
24
+ def __init__(self):
25
+ self.tensor_ori_attr = {}
26
+ self.functional_ori_attr = {}
27
+ self.mint_ops_ori_attr = {}
28
+ self.mint_func_ops_ori_attr = {}
29
+ self.norm_inner_ops_ori_attr = {}
30
+
31
+ self.tensor_hook_attr = {}
32
+ self.functional_hook_attr = {}
33
+ self.mint_ops_hook_attr = {}
34
+ self.mint_func_ops_hook_attr = {}
35
+ self.norm_inner_ops_hook_attr = {}
36
+
37
+ self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
38
+
39
+ @staticmethod
40
+ def store_ori_attr(ori_api_group, api_list, api_ori_attr):
41
+ for api in api_list:
42
+ if Const.SEP in api:
43
+ sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
44
+ sub_module = getattr(ori_api_group, sub_module_name)
45
+ api_ori_attr[api] = getattr(sub_module, sub_op)
46
+ else:
47
+ api_ori_attr[api] = getattr(ori_api_group, api)
48
+
49
+ @staticmethod
50
+ def set_api_attr(api_group, attr_dict):
51
+ for api, api_attr in attr_dict.items():
52
+ if Const.SEP in api:
53
+ sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
54
+ sub_module = getattr(api_group, sub_module_name, None)
55
+ if sub_module is not None:
56
+ setattr(sub_module, sub_op, api_attr)
57
+ else:
58
+ setattr(api_group, api, api_attr)
59
+
60
+ def norm_inner_op_set_hook_func(self):
61
+ self.set_api_attr(ms.ops, self.norm_inner_ops_hook_attr)
62
+
63
+ def norm_inner_op_set_ori_func(self):
64
+ self.set_api_attr(ms.ops, self.norm_inner_ops_ori_attr)
65
+
66
+ def api_set_hook_func(self):
67
+ self.set_api_attr(ms.Tensor, self.tensor_hook_attr)
68
+ self.set_api_attr(ms.ops, self.functional_hook_attr)
69
+ self.set_api_attr(ms.mint, self.mint_ops_hook_attr)
70
+ self.set_api_attr(ms.mint.nn.functional, self.mint_func_ops_hook_attr)
71
+
72
+ def api_set_ori_func(self):
73
+ self.set_api_attr(ms.Tensor, self.tensor_ori_attr)
74
+ self.set_api_attr(ms.ops, self.functional_ori_attr)
75
+ self.set_api_attr(ms.mint, self.mint_ops_ori_attr)
76
+ self.set_api_attr(ms.mint.nn.functional, self.mint_func_ops_ori_attr)
77
+
78
+ def initialize_hook(self, hook):
79
+ self.store_ori_attr(ms.Tensor, get_tensor_ops(), self.tensor_ori_attr)
80
+ wrap_tensor_ops_and_bind(hook)
81
+ for attr_name in dir(HOOKTensor):
82
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
83
+ self.tensor_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKTensor, attr_name)
84
+
85
+ functional_ops, mint_ops, mint_func_ops = get_functional_ops()
86
+ self.store_ori_attr(ms.ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
87
+ self.store_ori_attr(ms.ops, functional_ops, self.functional_ori_attr)
88
+ self.store_ori_attr(ms.mint, mint_ops, self.mint_ops_ori_attr)
89
+ self.store_ori_attr(ms.mint.nn.functional, mint_func_ops, self.mint_func_ops_ori_attr)
90
+ setup_hooks(hook)
91
+ for attr_name in dir(HOOKFunctionalOP):
92
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
93
+ self.functional_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKFunctionalOP, attr_name)
94
+ if attr_name[Const.ATTR_NAME_PREFIX_LEN:] in self.norm_inner_ops:
95
+ self.norm_inner_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKFunctionalOP, attr_name)
96
+ for attr_name in dir(HOOKMintOP):
97
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
98
+ self.mint_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintOP, attr_name)
99
+ for attr_name in dir(HOOKMintNNFunctionalOP):
100
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
101
+ self.mint_func_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintNNFunctionalOP, attr_name)
102
+
103
+
104
+ api_register = ApiRegistry()
@@ -0,0 +1,53 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ from collections import defaultdict
16
+
17
+ from mindspore import nn
18
+ from msprobe.core.common.const import Const
19
+
20
+
21
+ class HOOKCell(nn.Cell):
22
+ cell_count = defaultdict(int)
23
+ g_stop_hook = False
24
+
25
+ def __init__(self, build_hook) -> None:
26
+ super(HOOKCell, self).__init__()
27
+ self.changed_status = False
28
+ self.input_kwargs = {}
29
+ self.prefix = ""
30
+ if not HOOKCell.g_stop_hook:
31
+ HOOKCell.g_stop_hook = True
32
+ self.changed_status = True
33
+ if hasattr(self, "prefix_op_name_"):
34
+ self.prefix = self.prefix_op_name_
35
+
36
+ HOOKCell.cell_count[self.prefix] += 1
37
+ self.prefix = self.prefix + str(HOOKCell.cell_count[self.prefix] - 1) + Const.SEP
38
+ forward_hook, backward_hook = build_hook(self.prefix)
39
+ self.register_forward_hook(forward_hook)
40
+ self.register_backward_hook(backward_hook)
41
+
42
+ # 重载call,加全局标志。
43
+ def __call__(self, *args, **kwargs):
44
+ try:
45
+ self.input_kwargs = kwargs
46
+ out = super(HOOKCell, self).__call__(*args, **kwargs)
47
+ except Exception as e:
48
+ raise e
49
+ finally:
50
+ if self.changed_status:
51
+ self.changed_status = False
52
+ HOOKCell.g_stop_hook = False
53
+ return out