mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (249)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/精度工具数据dump标准性能基线报告.md +0 -0
@@ -0,0 +1,90 @@
+ import os
+ from collections import defaultdict
+
+ import torch
+ if int(torch.__version__.split('.')[0]) >= 2:
+     from torch.optim.optimizer import register_optimizer_step_pre_hook
+ from msprobe.pytorch.grad_probe.grad_stat_csv import GradStatCsv
+ from msprobe.core.grad_probe.utils import check_numeral_list_ascend, data_in_list_target
+ from msprobe.core.grad_probe.constant import GradConst, level_adp
+ from msprobe.core.common.file_check import create_directory
+ from msprobe.core.common.log import logger
+ from msprobe.core.common.utils import remove_path, write_csv, save_npy
+ from msprobe.pytorch.common.utils import get_rank_id, print_rank_0, save_pt
+
+
+ class GradientMonitor:
+
+     def __init__(self, common_config, task_config):
+         level = task_config.grad_level
+         if level not in level_adp:
+             raise Exception(f"level is invalid, not in {level_adp.keys()}")
+         self._level_adp = level_adp[level]
+         self._param_list = task_config.param_list
+         self._target_ranks = common_config.rank
+         logger.info(f"target rank {self._target_ranks}")
+         self._target_step = common_config.step
+         logger.info(f"target step {self._target_step}")
+         self._bounds = task_config.bounds
+         check_numeral_list_ascend(self._bounds)
+         self._output_path = common_config.dump_path
+         if not os.path.exists(self._output_path):
+             create_directory(self._output_path)
+         else:
+             logger.warning(f"the files in {self._output_path} will be overwritten")
+         self._step = -1
+         self._param2name = defaultdict(str)
+
+     @property
+     def output_path(self):
+         return self._output_path
+
+     @staticmethod
+     def save_grad_direction(param_name, grad, save_path):
+         if not os.path.exists(save_path):
+             create_directory(save_path)
+         param_grad = grad.clone().detach()
+         is_positive = param_grad > 0
+         save_filepath = os.path.join(save_path, f"{param_name}.npy")
+         save_npy(is_positive.numpy(), save_filepath)
+
+     def monitor(self, model):
+         print_rank_0("> parameter names:")
+         for name, param in model.named_parameters():
+             self._param2name[param] = name
+             print_rank_0(f"\t{name}")
+         setattr(self, "_rank", get_rank_id())
+         if torch.distributed.is_initialized() and not data_in_list_target(getattr(self, "_rank"), self._target_ranks):
+             return
+         self._hook_optimizer()
+
+     def _hook_optimizer(self):
+         def optimizer_pre_step_hook(optimizer, args, kargs):
+             self._step += 1
+             if not data_in_list_target(self._step, self._target_step):
+                 return
+             output_lines = []
+             for param, param_name in self._param2name.items():
+                 if not data_in_list_target(param_name, self._param_list):
+                     continue
+                 grad = param.main_grad if hasattr(param, "main_grad") else param.grad
+                 if grad is None:
+                     logger.info(f"grad is None: {param_name}")
+                     continue
+                 grad_info = GradStatCsv.generate_csv_line(param_name, self._level_adp, grad, self._bounds)
+                 output_lines.append(grad_info)
+                 if self._level_adp["have_grad_direction"]:
+                     GradientMonitor.save_grad_direction(param_name, grad,
+                                                         f'{self._output_path}/rank{self._rank}/step{self._step}')
+             output_dirpath = os.path.join(self._output_path, f"rank{getattr(self, '_rank')}")
+             if not os.path.isdir(output_dirpath):
+                 create_directory(output_dirpath)
+             output_path = os.path.join(output_dirpath, f"grad_summary_{self._step}.csv")
+             if os.path.exists(output_path):
+                 logger.warning(f"{output_path} will be overwritten")
+                 remove_path(output_path)
+             header_result = GradStatCsv.generate_csv_header(self._level_adp, self._bounds)
+             output_lines.insert(0, header_result)
+             write_csv(output_lines, output_path)
+         if int(torch.__version__.split('.')[0]) >= 2:
+             register_optimizer_step_pre_hook(optimizer_pre_step_hook)
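The new grad_monitor.py above wires gradient statistics collection into the optimizer step. A minimal usage sketch, assuming a torch >= 2 environment and config objects exposing the attributes read in __init__ (grad_level, param_list, bounds, rank, step, dump_path); the concrete values below are illustrative, and in msprobe they come from the parsed config.json:

from types import SimpleNamespace
import torch.nn as nn
from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor

common_config = SimpleNamespace(rank=[0], step=[0, 1, 2], dump_path="./grad_dump")
task_config = SimpleNamespace(grad_level="L1",  # assumed to be a valid level_adp key
                              param_list=[], bounds=[-1.0, 0.0, 1.0])

model = nn.Linear(4, 4)  # stand-in for the real training model
monitor = GradientMonitor(common_config, task_config)
monitor.monitor(model)  # registers optimizer_pre_step_hook (torch >= 2 only)
# During training, each optimizer.step() on a targeted rank/step then writes
# <dump_path>/rank<rank>/grad_summary_<step>.csv before the parameters update.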
@@ -0,0 +1,129 @@
+ from abc import ABC, abstractmethod
+ from collections import namedtuple
+ import hashlib
+ import torch
+ from msprobe.core.grad_probe.constant import GradConst
+
+ CSV_header_input = namedtuple("CSV_header_input", ["bounds"])
+ CSV_content_input = namedtuple("CSV_content_input", ["grad", "bounds"])
+
+
+ class GradStatCsv:
+     csv = {}
+
+     @staticmethod
+     def generate_csv_header(level, bounds):
+         header = ["param_name"]
+         for key in level["header"]:
+             csv_header_input = CSV_header_input(bounds=bounds)
+             header.extend(GradStatCsv.csv[key].generate_csv_header(csv_header_input))
+         return header
+
+     @staticmethod
+     def generate_csv_line(param_name, level, grad, bounds):
+         line = [param_name]
+         for key in level["header"]:
+             csv_content_input = CSV_content_input(grad=grad, bounds=bounds)
+             line.extend(GradStatCsv.csv[key].generate_csv_content(csv_content_input))
+         return line
+
+
+ def register_csv_item(key, cls=None):
+     if cls is None:
+         # called without a class: return a decorator that registers it
+         return lambda cls: register_csv_item(key, cls)
+     GradStatCsv.csv[key] = cls
+     return cls
+
+
+ class CsvItem(ABC):
+     @abstractmethod
+     def generate_csv_header(csv_header_input):
+         pass
+
+     @abstractmethod
+     def generate_csv_content(csv_content_input):
+         pass
+
+
+ @register_csv_item(GradConst.MD5)
+ class CSV_md5(CsvItem):
+     def generate_csv_header(csv_header_input):
+         return ["MD5"]
+
+     def generate_csv_content(csv_content_input):
+         grad = csv_content_input.grad
+         tensor_bytes = grad.cpu().detach().float().numpy().tobytes()
+         md5_hash = hashlib.md5(tensor_bytes)
+         return [md5_hash.hexdigest()]
+
+
+ @register_csv_item(GradConst.DISTRIBUTION)
+ class CSV_distribution(CsvItem):
+     def generate_csv_header(csv_header_input):
+         bounds = csv_header_input.bounds
+         intervals = []
+         if bounds:
+             intervals.append(f"(-inf, {bounds[0]}]")
+             for i in range(1, len(bounds)):
+                 intervals.append(f"({bounds[i-1]}, {bounds[i]}]")
+         if intervals:
+             intervals.append(f"({bounds[-1]}, inf)")
+         intervals.append("=0")
+
+         return intervals
+
+     def generate_csv_content(csv_content_input):
+         grad = csv_content_input.grad
+         bounds = csv_content_input.bounds
+         grad = grad.cpu().detach()
+         if grad.dtype == torch.bfloat16:
+             grad = grad.to(torch.float32)
+         element_num = grad.numel()
+         grad_equal_0_num = (grad == 0).sum().item()
+         bound = torch.Tensor(bounds)
+         bucketsize_result = torch.bucketize(grad, bound)
+         interval_nums = [(bucketsize_result == i).sum().item() for i in range(len(bound) + 1)]
+         interval_nums.append(grad_equal_0_num)
+         return_list = [x / element_num if element_num != 0 else 0 for x in interval_nums]
+         return return_list
+
+
+ @register_csv_item(GradConst.MAX)
+ class CSV_max(CsvItem):
+     def generate_csv_header(csv_header_input):
+         return ["max"]
+
+     def generate_csv_content(csv_content_input):
+         grad = csv_content_input.grad
+         return [torch.max(grad).cpu().detach().float().numpy().tolist()]
+
+
+ @register_csv_item(GradConst.MIN)
+ class CSV_min(CsvItem):
+     def generate_csv_header(csv_header_input):
+         return ["min"]
+
+     def generate_csv_content(csv_content_input):
+         grad = csv_content_input.grad
+         return [torch.min(grad).cpu().detach().float().numpy().tolist()]
+
+
+ @register_csv_item(GradConst.NORM)
+ class CSV_norm(CsvItem):
+     def generate_csv_header(csv_header_input):
+         return ["norm"]
+
+     def generate_csv_content(csv_content_input):
+         grad = csv_content_input.grad
+         return [torch.norm(grad).cpu().detach().float().numpy().tolist()]
+
+
+ @register_csv_item(GradConst.SHAPE)
+ class CSV_shape(CsvItem):
+     def generate_csv_header(csv_header_input):
+         return ["shape"]
+
+     def generate_csv_content(csv_content_input):
+         grad = csv_content_input.grad
+         return [list(grad.shape)]
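GradStatCsv above is an open registry: each statistic is a CsvItem subclass keyed by register_csv_item, and a level's "header" list selects which columns are emitted. A short sketch of adding a custom column under that contract; the key "mean" is hypothetical and not part of GradConst:

import torch
from msprobe.pytorch.grad_probe.grad_stat_csv import CsvItem, GradStatCsv, register_csv_item

@register_csv_item("mean")  # hypothetical key; the real keys are GradConst members
class CSV_mean(CsvItem):
    def generate_csv_header(csv_header_input):  # no self: called on the class, matching the file's style
        return ["mean"]

    def generate_csv_content(csv_content_input):
        grad = csv_content_input.grad
        return [torch.mean(grad.cpu().detach().float()).item()]

# Any level whose "header" lists the key then picks the new column up:
line = GradStatCsv.generate_csv_line("weight", {"header": ["mean"]}, torch.ones(2, 2), [])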
@@ -17,10 +17,13 @@
 
  import functools
  import threading
+
  import torch
  import torch.nn as nn
  import torch.utils.hooks as full_hooks
+
  from msprobe.core.common.const import Const
+ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
 
  class HOOKModule(nn.Module):
@@ -46,9 +49,13 @@ class HOOKModule(nn.Module):
          else:
              HOOKModule.module_count[self.prefix] += 1
              self.prefix = self.prefix + str(HOOKModule.module_count[self.prefix] - 1) + Const.SEP
-         forward_pre_hook, forward_hook, backward_hook = build_hook(self.prefix)
-         self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
-         self.register_forward_hook(forward_hook, with_kwargs=True)
+         forward_pre_hook, forward_hook, backward_hook, _ = build_hook(self.prefix)
+         if torch_version_above_or_equal_2:
+             self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
+             self.register_forward_hook(forward_hook, with_kwargs=True)
+         else:
+             self.register_forward_pre_hook(forward_pre_hook)
+             self.register_forward_hook(forward_hook)
          self.register_backward_hook(backward_hook)
 
      def __call__(self, *input, **kwargs):
@@ -61,6 +68,10 @@
          HOOKModule.inner_stop_hook[self.current_thread] = False
          return result
 
+     @classmethod
+     def reset_module_stats(cls):
+         cls.module_count = {}
+
      def _call_func(self, *input, **kwargs):
          full_backward_hooks, non_full_backward_hooks = [], []
          if len(self._backward_hooks) > 0:
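Because HOOKModule.module_count is class-level state, instance prefixes would keep incrementing across dump sessions; the new reset_module_stats classmethod gives callers a way to clear it. A one-line sketch of the intended call (exactly where the service invokes it is not shown in this hunk):

from msprobe.pytorch.hook_module.hook_module import HOOKModule

HOOKModule.reset_module_stats()  # the next wrapped call of each API is numbered from 0 again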
@@ -1873,4 +1873,5 @@ distributed:
   - reduce_scatter
   - _reduce_scatter_base
   - _all_gather_base
- - all_to_all_single
+ - all_to_all_single
+ - all_to_all
@@ -16,14 +16,14 @@
  """
 
  import os
- import yaml
+ from msprobe.core.common.utils import load_yaml
 
- from msprobe.core.common.file_check import FileOpen
 
- cur_path = os.path.dirname(os.path.realpath(__file__))
- yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
- with FileOpen(yaml_path, 'r') as f:
-     Ops = yaml.safe_load(f)
- WrapFunctionalOps = Ops.get('functional')
- WrapTensorOps = Ops.get('tensor')
- WrapTorchOps = Ops.get('torch')
+ def get_ops():
+     cur_path = os.path.dirname(os.path.realpath(__file__))
+     yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
+     ops = load_yaml(yaml_path)
+     wrap_functional = ops.get('functional')
+     wrap_tensor = ops.get('tensor')
+     wrap_torch = ops.get('torch')
+     return set(wrap_functional) | set(wrap_tensor) | set(wrap_torch)
@@ -18,18 +18,17 @@
  import os
  import torch
 
- import yaml
-
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
  from msprobe.pytorch.common.utils import torch_device_guard
  from msprobe.core.common.const import Const
- from msprobe.core.common.file_check import FileOpen
-
+ from msprobe.core.common.utils import load_yaml
+ from msprobe.pytorch.function_factory import npu_custom_grad_functions
 
  cur_path = os.path.dirname(os.path.realpath(__file__))
  yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
- with FileOpen(yaml_path, 'r') as f:
-     WrapAtenOps = yaml.safe_load(f).get('aten')
+ ops = load_yaml(yaml_path)
+ wrap_aten_ops = ops.get('aten')
+ white_aten_ops = ops.get('white_aten_ops', [])
 
 
  aten_func = {}
@@ -38,9 +37,9 @@ for f in dir(torch.ops.aten):
 
 
  def get_aten_ops():
-     global WrapAtenOps
+     global wrap_aten_ops
      _all_aten_ops = dir(torch.ops.aten)
-     return set(WrapAtenOps) & set(_all_aten_ops)
+     return set(wrap_aten_ops) & set(_all_aten_ops)
 
 
  class HOOKAtenOP(object):
@@ -48,7 +47,7 @@ class HOOKAtenOP(object):
 
 
  class AtenOPTemplate(HOOKModule):
-     def __init__(self, op, hook):
+     def __init__(self, op, hook, need_hook=True):
          if isinstance(op, torch._ops.OpOverloadPacket):
              op_name_ = op._qualified_op_name.split("::")[-1]
          else:
@@ -58,10 +57,21 @@
              op_name_ = op_name_ + '.' + overload_name
          self.op = op
          self.prefix_op_name_ = "Aten" + Const.SEP + str(op_name_) + Const.SEP
-         super().__init__(hook)
+         self.need_hook = need_hook
+         if self.need_hook:
+             super().__init__(hook)
 
      @torch_device_guard
      def forward(self, *args, **kwargs):
+         if isinstance(self.op, str):
+             if self.op in npu_custom_grad_functions:
+                 return npu_custom_grad_functions[self.op](*args, **kwargs)
+             if self.op in white_aten_ops:
+                 return eval(f"torch.ops.aten.{self.op}")(*args, **kwargs)
+             if self.op not in aten_func:
+                 raise Exception(f"Skip op[{self.op}] accuracy check, because the op is neither "
+                                 f"in dir(torch.ops.aten) nor in the support yaml.")
+             return aten_func[self.op](*args, **kwargs)
          return self.op(*args, **kwargs)
 
 
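The need_hook=False path above turns AtenOPTemplate into a plain dispatcher: given an op recorded by name only, it replays it through npu_custom_grad_functions, the white_aten_ops allowlist, or the aten_func table, without registering any dump hooks. A hedged sketch, assuming the constructor accepts an op-name string (as the string branch in forward implies) and that the name resolves through one of those tables:

import torch
from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate

op = AtenOPTemplate("relu", hook=None, need_hook=False)  # "relu" is illustrative
out = op.forward(torch.randn(4))  # resolved via npu_custom_grad_functions / white_aten_ops / aten_func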
 
@@ -18,18 +18,15 @@
  import os
  from functools import wraps
  import torch.distributed as dist
- import yaml
 
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
  from msprobe.pytorch.common.utils import torch_device_guard
  from msprobe.core.common.const import Const
- from msprobe.core.common.file_check import FileOpen
+ from msprobe.core.common.utils import load_yaml
 
 
  cur_path = os.path.dirname(os.path.realpath(__file__))
  yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
- with FileOpen(yaml_path, 'r') as f:
-     WrapDistributedOps = yaml.safe_load(f).get('distributed')
 
 
  distributed_func = {}
@@ -38,9 +35,10 @@ for f in dir(dist):
 
 
  def get_distributed_ops():
-     global WrapDistributedOps
      _all_distributed_ops = dir(dist)
-     return set(WrapDistributedOps) & set(_all_distributed_ops)
+     yaml_data = load_yaml(yaml_path)
+     wrap_distributed_ops = yaml_data.get('distributed')
+     return set(wrap_distributed_ops) & set(_all_distributed_ops)
 
 
  class HOOKDistributedOP(object):
@@ -57,7 +55,12 @@ class DistributedOPTemplate(HOOKModule):
 
      @torch_device_guard
      def forward(self, *args, **kwargs):
-         return distributed_func.get(self.op_name_)(*args, **kwargs)
+         if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]:
+             handle = distributed_func.get(self.op_name_)(*args, **kwargs)
+             handle.wait()
+             return handle
+         else:
+             return distributed_func.get(self.op_name_)(*args, **kwargs)
 
 
  def wrap_distributed_op(op_name, hook):
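The forward change above addresses asynchronous collectives: with async_op=True (and for isend/irecv, which are always asynchronous), the wrapped call now blocks on the returned work handle before handing it back, so dump hooks that fire afterwards observe completed tensors rather than buffers still in flight. The same contract in plain torch.distributed terms (a sketch; assumes an already-initialized process group):

import torch
import torch.distributed as dist

tensor = torch.ones(4)
handle = dist.all_reduce(tensor, async_op=True)  # returns a Work handle immediately
handle.wait()  # block until the collective finishes
# only now is `tensor` guaranteed to hold the reduced result, safe to dump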
@@ -16,15 +16,13 @@
  """
 
  import os
-
  import torch
- import yaml
 
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
  from msprobe.pytorch.common.utils import torch_device_guard
  from msprobe.core.common.const import Const
  from msprobe.pytorch.common.log import logger
- from msprobe.core.common.file_check import FileOpen
+ from msprobe.core.common.utils import load_yaml
 
 
  def remove_dropout():
@@ -66,14 +64,13 @@ def remove_dropout():
 
  cur_path = os.path.dirname(os.path.realpath(__file__))
  yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
- with FileOpen(yaml_path, 'r') as f:
-     WrapFunctionalOps = yaml.safe_load(f).get('functional')
 
 
  def get_functional_ops():
-     global WrapFunctionalOps
+     yaml_data = load_yaml(yaml_path)
+     wrap_functional_ops = yaml_data.get('functional')
      _all_functional_ops = dir(torch.nn.functional)
-     return set(WrapFunctionalOps) & set(_all_functional_ops)
+     return set(wrap_functional_ops) & set(_all_functional_ops)
 
 
  TorchFunctions = {func: getattr(torch.nn.functional, func) for func in get_functional_ops()}
@@ -17,27 +17,33 @@
 
  import os
  import torch
- import torch_npu
- import yaml
 
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
  from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version
  from msprobe.core.common.const import Const
- from msprobe.core.common.file_check import FileOpen
+ from msprobe.core.common.utils import load_yaml
+ from msprobe.pytorch.function_factory import npu_custom_functions
 
  cur_path = os.path.dirname(os.path.realpath(__file__))
  yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
- with FileOpen(yaml_path, 'r') as f:
-     WrapNpuOps = yaml.safe_load(f).get('torch_npu')
+
+
+ try:
+     import torch_npu
+ except ImportError:
+     is_gpu = True
+ else:
+     is_gpu = False
 
 
  def get_npu_ops():
-     global WrapNpuOps
      if torch_without_guard_version:
          _npu_ops = dir(torch.ops.npu)
      else:
          _npu_ops = dir(torch_npu._C._VariableFunctionsClass)
-     return set(WrapNpuOps) & set(_npu_ops)
+     yaml_data = load_yaml(yaml_path)
+     wrap_npu_ops = yaml_data.get('torch_npu')
+     return set(wrap_npu_ops) & set(_npu_ops)
 
 
  class HOOKNpuOP(object):
@@ -46,13 +52,19 @@ class HOOKNpuOP(object):
 
 
  class NpuOPTemplate(HOOKModule):
-     def __init__(self, op_name, hook):
+     def __init__(self, op_name, hook, need_hook=True):
          self.op_name_ = op_name
          self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP
-         super().__init__(hook)
+         self.need_hook = need_hook
+         if need_hook:
+             super().__init__(hook)
 
      @torch_device_guard
      def forward(self, *args, **kwargs):
+         if not self.need_hook:
+             if self.op_name_ not in npu_custom_functions:
+                 raise Exception(f'There is no bench function {self.op_name_}')
+             return npu_custom_functions[self.op_name_](*args, **kwargs)
          if torch_without_guard_version:
              return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs)
          else:
@@ -60,7 +72,6 @@ class NpuOPTemplate(HOOKModule):
 
 
  def wrap_npu_op(op_name, hook):
-
      def npu_op_template(*args, **kwargs):
          return NpuOPTemplate(op_name, hook)(*args, **kwargs)
 
@@ -18,23 +18,22 @@
  import os
 
  import torch
- import yaml
 
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
  from msprobe.pytorch.common.utils import torch_device_guard, parameter_adapter
  from msprobe.core.common.const import Const
- from msprobe.core.common.file_check import FileOpen
+ from msprobe.core.common.utils import load_yaml
+
 
  cur_path = os.path.dirname(os.path.realpath(__file__))
  yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
- with FileOpen(yaml_path, 'r') as f:
-     WrapTensorOps = yaml.safe_load(f).get('tensor')
 
 
  def get_tensor_ops():
-     global WrapTensorOps
      _tensor_ops = dir(torch.Tensor)
-     return set(WrapTensorOps) & set(_tensor_ops)
+     yaml_data = load_yaml(yaml_path)
+     wrap_tensor_ops = yaml_data.get('tensor')
+     return set(wrap_tensor_ops) & set(_tensor_ops)
 
 
  TensorOps = {op: getattr(torch.Tensor, op) for op in get_tensor_ops()}
@@ -16,25 +16,23 @@
  """
 
  import os
-
  import torch
- import yaml
 
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
  from msprobe.pytorch.common.utils import torch_device_guard
  from msprobe.core.common.const import Const
- from msprobe.core.common.file_check import FileOpen
+ from msprobe.core.common.utils import load_yaml
+
 
  cur_path = os.path.dirname(os.path.realpath(__file__))
  yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
- with FileOpen(yaml_path, 'r') as f:
-     WrapTorchOps = yaml.safe_load(f).get('torch')
 
 
  def get_torch_ops():
-     global WrapTorchOps
      _torch_ops = []
-     for operation in WrapTorchOps:
+     yaml_data = load_yaml(yaml_path)
+     wrap_torch_ops = yaml_data.get('torch')
+     for operation in wrap_torch_ops:
          if '.' in operation:
              operation_sub_module_name, operation_sub_op = operation.rsplit('.', 1)
              operation_sub_module = getattr(torch, operation_sub_module_name)
@@ -16,24 +16,22 @@
  """
 
  import os
-
  import torch
- import yaml
 
+ from msprobe.core.common.const import Const
+ from msprobe.core.common.utils import load_yaml
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
- from msprobe.core.common.file_check import FileOpen
  from msprobe.pytorch.common.utils import torch_device_guard
- from msprobe.core.common.const import Const
+
 
  cur_path = os.path.dirname(os.path.realpath(__file__))
  yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
- with FileOpen(yaml_path, 'r') as f:
-     WrapVfOps = yaml.safe_load(f).get('_VF')
 
 
  def get_vf_ops():
-     global WrapVfOps
-     return WrapVfOps
+     yaml_data = load_yaml(yaml_path)
+     wrap_vf_ops = yaml_data.get('_VF')
+     return wrap_vf_ops
 
 
  class HOOKVfOP(object):