mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
@@ -1,6 +1,12 @@
1
1
  import json
2
+
2
3
  from msprobe.core.common_config import CommonConfig, BaseConfig
3
4
  from msprobe.core.common.file_check import FileOpen
5
+ from msprobe.core.common.const import Const
6
+ from msprobe.mindspore.common.const import FreeBenchmarkConst
7
+ from msprobe.mindspore.common.log import logger
8
+ from msprobe.core.grad_probe.constant import level_adp
9
+ from msprobe.core.grad_probe.utils import check_numeral_list_ascend
4
10
 
5
11
 
6
12
  class TensorConfig(BaseConfig):
@@ -31,39 +37,81 @@ class StatisticsConfig(BaseConfig):
31
37
  if self.data_mode is not None and len(self.data_mode) > 0:
32
38
  if len(self.data_mode) > 1 or self.data_mode[0] not in ["all", "input", "output"]:
33
39
  raise Exception("data_mode must be all, input or output")
40
+ if self.summary_mode and self.summary_mode not in ["statistics", "md5"]:
41
+ raise Exception("summary_mode is invalid")
34
42
 
35
43
 
36
- class OverflowCheck(BaseConfig):
44
+ class OverflowCheckConfig(BaseConfig):
37
45
  def __init__(self, json_config):
38
46
  super().__init__(json_config)
39
- self.file_format = None
40
- self.check_mode = json_config.get("check_mode")
47
+ self.data_mode = ["all"]
41
48
  self._check_config()
42
49
 
43
50
  def _check_config(self):
44
- if self.data_mode is not None and len(self.data_mode) > 0:
45
- if len(self.data_mode) > 1 or self.data_mode[0] not in ["all", "input", "output"]:
46
- raise Exception("data_mode must be all, input or output")
51
+ if self.overflow_nums is not None and not isinstance(self.overflow_nums, int):
52
+ raise Exception("overflow_nums is invalid, it should be an integer")
53
+ if self.overflow_nums is not None and self.overflow_nums != -1 and self.overflow_nums <= 0:
54
+ raise Exception("overflow_nums should be -1 or positive integer")
47
55
  if self.check_mode and self.check_mode not in ["all", "aicore", "atomic"]:
48
56
  raise Exception("check_mode is invalid")
49
57
 
50
58
 
59
+ class FreeBenchmarkConfig(BaseConfig):
60
+ def __init__(self, task_config):
61
+ super().__init__(task_config)
62
+ self._check_config()
63
+
64
+ def _check_config(self):
65
+ if self.fuzz_device and self.fuzz_device not in FreeBenchmarkConst.DEVICE_LIST:
66
+ raise Exception("fuzz_device must be npu or empty")
67
+ if self.pert_mode and self.pert_mode not in FreeBenchmarkConst.PERT_TYPE_LIST:
68
+ raise Exception("pert_mode must be improve_precision, add_noise, bit_noise, no_change or empty")
69
+ if self.handler_type and self.handler_type not in FreeBenchmarkConst.HANDLER_TYPE_LIST:
70
+ raise Exception("handler_type must be check, fix or empty")
71
+ if self.fuzz_level and self.fuzz_level not in FreeBenchmarkConst.DUMP_LEVEL_LIST:
72
+ raise Exception("fuzz_level must be L1 or empty")
73
+ if self.fuzz_stage and self.fuzz_stage not in FreeBenchmarkConst.STAGE_LIST:
74
+ raise Exception("fuzz_stage must be forward or empty")
75
+ if self.if_preheat or self.preheat_step or self.max_sample:
76
+ logger.warning("'if_preheat', 'preheat_step' and 'max_sample' settings "
77
+ "are not supported for mindspore free benchmark task.")
78
+
79
+
80
+ class GradProbeConfig(BaseConfig):
81
+ def __init__(self, json_config):
82
+ super().__init__(json_config)
83
+ self.grad_level = json_config.get("grad_level", "L1")
84
+ self.param_list = json_config.get("param_list", [])
85
+ self.bounds = json_config.get("bounds", [])
86
+
87
+ def _check_config(self):
88
+ if self.grad_level not in level_adp.keys():
89
+ raise Exception(f"grad_level must be one of {level_adp.keys()}")
90
+ if not isinstance(self.param_list, list):
91
+ raise Exception(f"param_list must be a list")
92
+ check_numeral_list_ascend(self.bounds)
93
+
94
+
95
+ TaskDict = {
96
+ Const.TENSOR: TensorConfig,
97
+ Const.STATISTICS: StatisticsConfig,
98
+ Const.OVERFLOW_CHECK: OverflowCheckConfig,
99
+ Const.FREE_BENCHMARK: FreeBenchmarkConfig,
100
+ Const.GRAD_PROBE: GradProbeConfig,
101
+ }
102
+
103
+
51
104
  def parse_common_config(json_config):
52
105
  return CommonConfig(json_config)
53
106
 
54
107
 
55
108
  def parse_task_config(task, json_config):
56
- task_map = json_config[task]
109
+ task_map = json_config.get(task)
57
110
  if not task_map:
58
111
  task_map = dict()
59
- if task == "tensor":
60
- return TensorConfig(task_map)
61
- elif task == "statistics":
62
- return StatisticsConfig(task_map)
63
- elif task == "overflow_check":
64
- return OverflowCheck(task_map)
65
- else:
112
+ if task not in TaskDict:
66
113
  raise Exception("task is invalid.")
114
+ return TaskDict.get(task)(task_map)
67
115
 
68
116
 
69
117
  def parse_json_config(json_file_path):
@@ -73,6 +121,6 @@ def parse_json_config(json_file_path):
73
121
  json_config = json.load(file)
74
122
  common_config = parse_common_config(json_config)
75
123
  if not common_config.task:
76
- common_config.task = "statistics"
124
+ common_config.task = Const.STATISTICS
77
125
  task_config = parse_task_config(common_config.task, json_config)
78
126
  return common_config, task_config
@@ -1,23 +1,24 @@
1
+ from msprobe.mindspore.common.const import Const
1
2
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
2
3
  from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck
3
4
 
4
5
 
5
6
  class OverflowCheckToolFactory:
6
7
  tools = {
7
- "cell": {
8
- "kbk": None,
9
- "graph": None,
10
- "pynative": None
8
+ Const.CELL: {
9
+ Const.GRAPH_KBYK_MODE: None,
10
+ Const.GRAPH_GE_MODE: None,
11
+ Const.PYNATIVE_MODE: None
11
12
  },
12
- "api": {
13
- "kbk": None,
14
- "graph": None,
15
- "pynative": None
13
+ Const.API: {
14
+ Const.GRAPH_KBYK_MODE: None,
15
+ Const.GRAPH_GE_MODE: None,
16
+ Const.PYNATIVE_MODE: None
16
17
  },
17
- "kernel": {
18
- "kbk": None,
19
- "graph": KernelGraphOverflowCheck,
20
- "pynative": None
18
+ Const.KERNEL: {
19
+ Const.GRAPH_KBYK_MODE: None,
20
+ Const.GRAPH_GE_MODE: KernelGraphOverflowCheck,
21
+ Const.PYNATIVE_MODE: None
21
22
  }
22
23
  }
23
24
 
@@ -25,8 +26,9 @@ class OverflowCheckToolFactory:
25
26
  def create(config: DebuggerConfig):
26
27
  tool = OverflowCheckToolFactory.tools.get(config.level)
27
28
  if not tool:
28
- raise Exception("valid level is needed.")
29
- tool = tool.get("graph")
29
+ raise Exception("Valid level is needed.")
30
+ tool = tool.get(config.execution_mode)
30
31
  if not tool:
31
- raise Exception("Overflow check in not supported in this mode.")
32
+ raise Exception(f"Overflow check is not supported in {config.execution_mode} mode "
33
+ f"when level is {config.level}.")
32
34
  return tool(config)
@@ -0,0 +1,4 @@
1
+ class Runtime:
2
+ step_count: int = 0
3
+ rank_id: int = -1
4
+ is_running: bool = False
@@ -0,0 +1,354 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ import os
17
+ import copy
18
+ from pathlib import Path
19
+ import functools
20
+ from collections import defaultdict
21
+
22
+ import mindspore as ms
23
+ from mindspore.common.tensor import Tensor
24
+ from mindspore import ops
25
+ from mindspore import nn
26
+ try:
27
+ from mindspore.common._pijit_context import PIJitCaptureContext
28
+ pijit_label = True
29
+ except ImportError:
30
+ pijit_label = False
31
+
32
+
33
+ from msprobe.core.data_dump.data_collector import build_data_collector
34
+ from msprobe.core.data_dump.scope import BaseScope
35
+ from msprobe.mindspore.common.utils import get_rank_if_initialized
36
+ from msprobe.core.common.file_check import FileChecker, FileCheckConst, check_path_before_create
37
+ from msprobe.mindspore.common.log import logger
38
+ from msprobe.core.common.utils import Const
39
+ from msprobe.core.common.exceptions import DistributedNotInitializedError
40
+ from msprobe.mindspore.dump.hook_cell.api_registry import api_register
41
+ from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
42
+ ModuleBackwardInputs, ModuleBackwardOutputs
43
+ from msprobe.core.common.exceptions import MsprobeException
44
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
45
+ from msprobe.mindspore.cell_processor import CellProcessor
46
+ from msprobe.mindspore.dump.jit_dump import JitDump
47
+
48
+
49
+ class Service:
50
+ def __init__(self, config):
51
+ self.model = None
52
+ self.config = copy.deepcopy(config)
53
+ self.config.level = self.config.level_ori
54
+ self.data_collector = build_data_collector(self.config)
55
+ self.cell_processor = CellProcessor(self.data_collector.scope)
56
+ self.switch = False
57
+ self.current_iter = 0
58
+ self.first_start = True
59
+ self.current_rank = None
60
+ self.primitive_counters = {}
61
+ self.dump_iter_dir = None
62
+ self.start_call = False
63
+ self.check_level_valid()
64
+
65
+ @staticmethod
66
+ def check_model_valid(model):
67
+ if not model or isinstance(model, nn.Cell):
68
+ return model
69
+ raise MsprobeException(
70
+ MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。"
71
+ )
72
+
73
+ def check_level_valid(self):
74
+ if self.config.level == "L2":
75
+ raise MsprobeException(
76
+ MsprobeException.INVALID_PARAM_ERROR, "L2 level dump function is currently not supported."
77
+ )
78
+
79
+ def build_hook(self, target_type, name):
80
+ def forward_hook(api_or_cell_name, cell, input, output):
81
+ if target_type == BaseScope.Module_Type_Module:
82
+ api_or_cell_name = cell.mindstudio_reserved_name
83
+ self.data_collector.visit_and_clear_overflow_status(api_or_cell_name)
84
+ if not self.switch:
85
+ return None
86
+ if self.data_collector:
87
+ if target_type == BaseScope.Module_Type_Module:
88
+ module_input_output = ModuleForwardInputsOutputs(args=input, kwargs={}, output=output)
89
+ else:
90
+ module_input_output = ModuleForwardInputsOutputs(args=input, kwargs=cell.input_kwargs, output=output)
91
+ self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
92
+ if self.data_collector.if_return_forward_new_output():
93
+ return self.data_collector.get_forward_new_output()
94
+ if target_type == BaseScope.Module_Type_API:
95
+ del cell.input_kwargs
96
+ return output
97
+
98
+ def backward_hook(api_or_cell_name, cell, grad_input, grad_output):
99
+ if target_type == BaseScope.Module_Type_Module:
100
+ api_or_cell_name = cell.mindstudio_reserved_name
101
+ self.data_collector.visit_and_clear_overflow_status(api_or_cell_name)
102
+ if not self.switch:
103
+ return
104
+ if self.data_collector:
105
+ # 框架最新接口变更,grad_input和grad_output的含义发生了变化,与torch含义保持一致,因此此处调换顺序传入
106
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
107
+ self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
108
+
109
+ pid = os.getpid()
110
+ forward_name_template = name + Const.FORWARD
111
+ backward_name_template = name + Const.BACKWARD
112
+ forward_hook = functools.partial(forward_hook, forward_name_template)
113
+ backward_hook = functools.partial(backward_hook, backward_name_template)
114
+
115
+ def wrap_forward_hook(cell, input, output):
116
+ return forward_hook(cell, input, output)
117
+
118
+ def wrap_backward_hook(cell, grad_input, grad_output):
119
+ return backward_hook(cell, grad_input, grad_output)
120
+
121
+ return wrap_forward_hook, wrap_backward_hook
122
+
123
+ def wrap_primitive(self, origin_func, primitive_name):
124
+ service_instance = self
125
+
126
+ def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
127
+ def backward_hook(grad):
128
+ captured_grads.append(grad)
129
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
130
+ try:
131
+ if len(captured_grads) == num_tensors and hook_type == Const.INPUT:
132
+ service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
133
+ new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
134
+ service_instance.data_collector.backward_output_data_collect(
135
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
136
+ )
137
+ captured_grads.clear()
138
+ elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
139
+ service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
140
+ new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
141
+ service_instance.data_collector.backward_input_data_collect(
142
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
143
+ )
144
+ captured_grads.clear()
145
+
146
+ except Exception as exception:
147
+ raise Exception(f"This is a primitive op {hook_type}_backward dump error: {exception},"
148
+ f" updated_primitive_name: {updated_primitive_name}") from exception
149
+
150
+ return backward_hook
151
+
152
+ def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
153
+ hooked_inputs = []
154
+ num_tensors = sum(isinstance(arg, Tensor) for arg in args)
155
+ input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name,
156
+ Const.INPUT)
157
+ for _, arg in enumerate(args):
158
+ if isinstance(arg, Tensor):
159
+ arg_hooked = ops.HookBackward(input_backward_hook)(arg)
160
+ hooked_inputs.append(arg_hooked)
161
+ else:
162
+ hooked_inputs.append(arg)
163
+ return hooked_inputs
164
+
165
+ def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
166
+ if isinstance(out, tuple):
167
+ num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out)
168
+ else:
169
+ num_output_tensors = 1
170
+ output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors,
171
+ updated_primitive_name, Const.OUTPUT)
172
+
173
+ if isinstance(out, Tensor):
174
+ return ops.HookBackward(output_backward_hook)(out)
175
+ elif isinstance(out, tuple):
176
+ hooked_outputs = []
177
+ for tensor in out:
178
+ if isinstance(tensor, Tensor):
179
+ hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
180
+ else:
181
+ hooked_outputs.append(tensor)
182
+ return tuple(hooked_outputs)
183
+ return out
184
+
185
+ def wrapped_primitive_call(instance_self, *args, **kwargs):
186
+ service_instance.update_primitive_counters(primitive_name)
187
+ current_count = service_instance.primitive_counters.get(primitive_name, 0)
188
+ updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
189
+
190
+ if not service_instance.switch:
191
+ return origin_func(*args, **kwargs)
192
+
193
+ captured_grads_input, captured_grads_output = [], []
194
+
195
+ try:
196
+ hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
197
+ except Exception as exception:
198
+ raise Exception("This is a primitive op dump error during input hooking: {},"
199
+ " primitive_name: {}".format(exception, primitive_name)) from exception
200
+
201
+ try:
202
+ out = origin_func(*hooked_inputs, **kwargs)
203
+ except Exception as exception:
204
+ raise Exception("This is a primitive op dump error during function call: {},"
205
+ " primitive_name: {}".format(exception, primitive_name)) from exception
206
+
207
+ forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
208
+ service_instance.data_collector.visit_and_clear_overflow_status(forward_primitive_name)
209
+ if service_instance.data_collector:
210
+ module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
211
+ try:
212
+ service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
213
+ os.getpid(), module_input_output)
214
+ except Exception as exception:
215
+ raise Exception("This is a primitive op dump error during forward data collection: {},"
216
+ " primitive_name: {}".format(exception, primitive_name)) from exception
217
+
218
+ if service_instance.data_collector.if_return_forward_new_output():
219
+ out = service_instance.data_collector.get_forward_new_output()
220
+
221
+ try:
222
+ out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
223
+ except Exception as exception:
224
+ raise Exception("This is a primitive op dump error during output hooking: {},"
225
+ " primitive_name: {}".format(exception, primitive_name)) from exception
226
+
227
+ return out
228
+
229
+ return wrapped_primitive_call
230
+
231
+ def update_primitive_counters(self, primitive_name):
232
+ if primitive_name not in self.primitive_counters:
233
+ self.primitive_counters[primitive_name] = 0
234
+ else:
235
+ self.primitive_counters[primitive_name] += 1
236
+
237
+ def register_hooks(self):
238
+ primitive_set = set()
239
+ for _, cell in self.model.cells_and_names():
240
+ for pname, primitive in cell._primitives.items():
241
+ primitive_set.add((pname, primitive))
242
+
243
+ for pname, primitive in primitive_set:
244
+ NewPrimitive = type('NewPrimitive', (primitive.__class__,),
245
+ {'__call__': self.wrap_primitive(primitive.__call__, pname)})
246
+ primitive.__class__ = NewPrimitive
247
+
248
    def step(self):
        """Advance to the next iteration and reset all per-step dump state."""
        self.current_iter += 1
        # Keep the collector's notion of the current iteration in sync.
        self.data_collector.update_iter(self.current_iter)
        # Reset invocation counters so dumped names restart from index 0.
        HOOKCell.cell_count = defaultdict(int)
        CellProcessor.cell_count = {}
        self.primitive_counters.clear()
254
+
255
    def start(self, model=None):
        """Turn the dump switch on for the current iteration.

        Args:
            model: optional network passed through check_model_valid and stored
                on the service; required later for L0-level cell hooking.

        Raises:
            Exception: once current_iter has passed the largest configured step,
                to deliberately end the run after the requested dumps.
        """
        self.model = self.check_model_valid(model)
        self.start_call = True
        logger.info("msprobe: debugger.start() is set successfully")
        # Past the last configured step: flush via stop() and abort on purpose.
        if self.config.step and self.current_iter > max(self.config.step):
            self.stop()
            raise Exception("msprobe: exit after iteration {}".format(max(self.config.step)))
        # Step filtering: do nothing on steps the user did not ask for.
        if self.config.step and self.current_iter not in self.config.step:
            return
        if self.first_start:
            try:
                self.current_rank = get_rank_if_initialized()
            except DistributedNotInitializedError:
                # Non-distributed run: rank stays None (empty rank dir suffix).
                self.current_rank = None

            # Rank filtering: hooks are only registered on requested ranks.
            if self.config.rank and self.current_rank not in self.config.rank:
                return
            self.register_hook_new()
            self.first_start = False
        self.switch = True
        logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
        self.create_dirs()
        logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
        if self.config.level == "L1":
            JitDump.set_config(self.config)
            JitDump.set_data_collector(self.data_collector)
            # NOTE(review): monkey-patches MindSpore private internals to route
            # jit execution through JitDump — tied to specific MS versions.
            ms.common.api._MindsporeFunctionExecutor = JitDump
            ms.common.api._PyNativeExecutor.grad = JitDump.grad
            if pijit_label:
                # Neutralize PIJit graph capture so dumping sees pynative calls.
                PIJitCaptureContext.__enter__ = self.empty
                PIJitCaptureContext.__exit__ = self.empty
286
+
287
    def stop(self):
        """Turn the dump switch off and flush collected data to JSON.

        Raises:
            Exception: if start() was not called in the current scope.
        """
        # NOTE(review): this success message is logged before the start_call
        # check, so it is emitted even when stop() raises just below — confirm
        # whether that ordering is intentional.
        logger.info("msprobe: debugger.stop() is set successfully. "
                    "Please set debugger.start() to turn on the dump switch again. ")
        if not self.start_call:
            logger.error("msprobe: debugger.start() is not set in the current scope.")
            raise Exception("debugger.start() is not set in the current scope.")
        # Mirror start(): skip filtered-out steps/ranks without touching state.
        if self.config.step and self.current_iter not in self.config.step:
            return
        if self.config.rank and self.current_rank not in self.config.rank:
            return
        self.switch = False
        self.start_call = False
        self.data_collector.write_json()
300
+
301
+ def create_dirs(self):
302
+ check_path_before_create(self.config.dump_path)
303
+ if not os.path.exists(self.config.dump_path):
304
+ Path(self.config.dump_path).mkdir(mode=0o750, exist_ok=True)
305
+ file_check = FileChecker(self.config.dump_path, FileCheckConst.DIR)
306
+ file_check.common_check()
307
+ self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
308
+ cur_rank = self.current_rank if self.current_rank is not None else ''
309
+ dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
310
+ if not os.path.exists(dump_dir):
311
+ Path(dump_dir).mkdir(mode=0o750, parents=True, exist_ok=True)
312
+ if self.config.task in self.data_collector.tasks_need_tensor_data:
313
+ dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
314
+ Path(dump_data_dir).mkdir(mode=0o750, exist_ok=True)
315
+ else:
316
+ dump_data_dir = None
317
+
318
+ dump_file_path = os.path.join(dump_dir, "dump.json")
319
+ stack_file_path = os.path.join(dump_dir, "stack.json")
320
+ construct_file_path = os.path.join(dump_dir, "construct.json")
321
+ self.data_collector.update_dump_paths(
322
+ dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None)
323
+
324
    def empty(self, *args, **kwargs):
        """Deliberate no-op; assigned over PIJitCaptureContext.__enter__/__exit__
        in start() to disable PIJit graph capture while dumping."""
        pass
326
+
327
    def register_hook_new(self):
        """Mount dump hooks onto the model according to the configured level.

        L1: patch the functional API surface via api_register, and patch
        primitives when a model is available.  L0: register forward/backward
        hooks on every sub-cell of the model (model required).

        Raises:
            MsprobeException: if level is L0 and no model was provided.
        """
        logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
        if self.config.level == "L1":
            api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
            api_register.api_set_hook_func()
            if self.model:
                self.register_hooks()

        if self.config.level == "L0":
            if not self.model:
                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, "The current level is L0, the model cannot be None")
            for name, cell in self.model.cells_and_names():
                # Skip the root cell itself; only sub-cells are hooked.
                if cell == self.model:
                    continue
                # Dump-name prefix: "Cell.<name>.<ClassName>."
                prefix = 'Cell' + Const.SEP + name + Const.SEP + \
                         cell.__class__.__name__ + Const.SEP
                forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix)
                cell.register_forward_hook(forward_hook)
                cell.register_backward_hook(backward_hook)

                # Additional start/stop node hooks feed the construct-graph
                # bookkeeping in cell_processor.
                cell.register_forward_pre_hook(
                    self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
                cell.register_forward_hook(
                    self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
                cell.register_backward_pre_hook(
                    self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
                cell.register_backward_hook(
                    self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
@@ -1,20 +1,23 @@
1
+ from msprobe.core.common.const import Const
1
2
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
2
3
  from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory
3
4
  from msprobe.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory
5
+ from msprobe.mindspore.free_benchmark.self_check_tool_factory import SelfCheckToolFactory
4
6
 
5
7
 
6
8
  class TaskHandlerFactory:
7
9
  tasks = {
8
- "tensor": DumpToolFactory,
9
- "statistics": DumpToolFactory,
10
- "overflow_check": OverflowCheckToolFactory
10
+ Const.TENSOR: DumpToolFactory,
11
+ Const.STATISTICS: DumpToolFactory,
12
+ Const.OVERFLOW_CHECK: OverflowCheckToolFactory,
13
+ Const.FREE_BENCHMARK: SelfCheckToolFactory
11
14
  }
12
15
 
13
16
  @staticmethod
14
17
  def create(config: DebuggerConfig):
15
18
  task = TaskHandlerFactory.tasks.get(config.task)
16
19
  if not task:
17
- raise Exception("valid task is needed.")
20
+ raise Exception("Valid task is needed.")
18
21
  handler = task.create(config)
19
22
  if not handler:
20
23
  raise Exception("Can not find task handler")