mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -0,0 +1,38 @@
1
+ import os
2
+ from msprobe.core.common.file_utils import load_yaml
3
+
4
+
5
class InplaceOpChecker:
    """Answer whether an API name is registered as an in-place operation.

    Operator names are read from ``inplace_ops.yaml`` (located next to this
    module) and cached on the class in ``INPLACE_OPS_DICT``, keyed by the
    category constants below.
    """

    # Category keys used in INPLACE_OPS_DICT.
    OP_FUNCTIONAL = 'functional'
    OP_TENSOR = 'tensor'
    OP_TORCH = 'torch'
    OP_DISTRIBUTED = 'distributed'

    # category -> collection of op names; stays None until a load succeeds.
    INPLACE_OPS_DICT = None

    @classmethod
    def load_ops(cls):
        """Fill ``INPLACE_OPS_DICT`` from ``inplace_ops.yaml`` if not loaded.

        A no-op when the cache is already set. If reading the YAML raises,
        the exception propagates and the cache is left as ``None`` so that
        a later call can retry.
        """
        if cls.INPLACE_OPS_DICT is not None:
            return

        cur_path = os.path.dirname(os.path.realpath(__file__))
        yaml_path = os.path.join(cur_path, "inplace_ops.yaml")
        # NOTE(review): load_yaml is presumed to return a mapping of
        # section name -> list of op names; an empty result is tolerated.
        all_ops = load_yaml(yaml_path) or {}

        yaml_sections = (
            (cls.OP_FUNCTIONAL, 'inplace_functional_op'),
            (cls.OP_TENSOR, 'inplace_tensor_op'),
            (cls.OP_TORCH, 'inplace_torch_op'),
            (cls.OP_DISTRIBUTED, 'inplace_distributed_op'),
        )
        # Build the mapping locally and publish it in one assignment, so a
        # failure above never leaves a half-initialised (empty) cache behind.
        # A section missing from the YAML becomes an empty list, not None,
        # which keeps membership tests in check() from raising.
        cls.INPLACE_OPS_DICT = {
            category: all_ops.get(section) or []
            for category, section in yaml_sections
        }

    @classmethod
    def check(cls, api, category='distributed'):
        """
        Given an API name and a category, check whether it is an in-place
        operation.

        Returns False when the category is unknown or has no registered ops.
        """
        if cls.INPLACE_OPS_DICT is None:
            cls.load_ops()

        ops = cls.INPLACE_OPS_DICT.get(category)
        if not ops:
            return False
        return api in ops
36
+
37
+
38
+ InplaceOpChecker.load_ops()
@@ -0,0 +1,251 @@
1
+ inplace_functional_op:
2
+ - threshold_
3
+ - relu_
4
+ - hardtanh_
5
+ - elu_
6
+ - selu_
7
+ - celu_
8
+ - leaky_relu_
9
+ - rrelu_
10
+
11
+ inplace_tensor_op:
12
+ - __iadd__
13
+ - __iand__
14
+ - __idiv__
15
+ - __ifloordiv__
16
+ - __ilshift__
17
+ - __imod__
18
+ - __imul__
19
+ - __ior__
20
+ - __irshift__
21
+ - __isub__
22
+ - __ixor__
23
+ - abs_
24
+ - absolute_
25
+ - acos_
26
+ - acosh_
27
+ - add_
28
+ - addbmm_
29
+ - addcdiv_
30
+ - addcmul_
31
+ - addmm_
32
+ - addmv_
33
+ - addr_
34
+ - arccos_
35
+ - arccosh_
36
+ - arcsin_
37
+ - arcsinh_
38
+ - arctan_
39
+ - arctanh_
40
+ - asin_
41
+ - asinh_
42
+ - atan2_
43
+ - atan_
44
+ - atanh_
45
+ - baddbmm_
46
+ - bernoulli_
47
+ - bitwise_and_
48
+ - bitwise_not_
49
+ - bitwise_or_
50
+ - bitwise_xor_
51
+ - cauchy_
52
+ - ceil_
53
+ - clamp_
54
+ - clamp_max_
55
+ - clamp_min_
56
+ - clip_
57
+ - copysign_
58
+ - cos_
59
+ - cosh_
60
+ - cumprod_
61
+ - cumsum_
62
+ - deg2rad_
63
+ - digamma_
64
+ - div_
65
+ - divide_
66
+ - eq_
67
+ - erf_
68
+ - erfc_
69
+ - erfinv_
70
+ - exp2_
71
+ - exp_
72
+ - expm1_
73
+ - exponential_
74
+ - fill_
75
+ - fill_diagonal_
76
+ - fix_
77
+ - float_power_
78
+ - floor_
79
+ - floor_divide_
80
+ - fmod_
81
+ - frac_
82
+ - gcd_
83
+ - ge_
84
+ - geometric_
85
+ - greater_
86
+ - gt_
87
+ - greater_equal_
88
+ - heaviside_
89
+ - hypot_
90
+ - igamma_
91
+ - igammac_
92
+ - index_add_
93
+ - index_copy_
94
+ - index_fill_
95
+ - index_put_
96
+ - lcm_
97
+ - ldexp_
98
+ - le_
99
+ - lerp_
100
+ - less_
101
+ - less_equal_
102
+ - lgamma_
103
+ - log10_
104
+ - log1p_
105
+ - log2_
106
+ - log_
107
+ - log_normal_
108
+ - logical_and_
109
+ - logical_not_
110
+ - logical_or_
111
+ - logical_xor_
112
+ - logit_
113
+ - lt_
114
+ - map2_
115
+ - map_
116
+ - masked_fill_
117
+ - masked_scatter_
118
+ - mul_
119
+ - multiply_
120
+ - mvlgamma_
121
+ - ne_
122
+ - neg_
123
+ - negative_
124
+ - normal_
125
+ - not_equal_
126
+ - pow_
127
+ - polygamma_
128
+ - put_
129
+ - rad2deg_
130
+ - reciprocal_
131
+ - relu_
132
+ - remainder_
133
+ - renorm_
134
+ - resize_
135
+ - resize_as_
136
+ - round_
137
+ - rsqrt_
138
+ - scatter_
139
+ - scatter_add_
140
+ - sgn_
141
+ - sigmoid_
142
+ - sign_
143
+ - sin_
144
+ - sinc_
145
+ - sinh_
146
+ - sqrt_
147
+ - square_
148
+ - squeeze_
149
+ - sub_
150
+ - t_
151
+ - tan_
152
+ - tanh_
153
+ - transpose_
154
+ - tril_
155
+ - triu_
156
+ - true_divide_
157
+ - trunc_
158
+ - unsqueeze_
159
+ - xlogy_
160
+
161
+ inplace_torch_op:
162
+ - _add_relu_
163
+ - abs_
164
+ - acos_
165
+ - acosh_
166
+ - addmv_
167
+ - alpha_dropout_
168
+ - arccos_
169
+ - arccosh_
170
+ - arcsin_
171
+ - arcsinh_
172
+ - arctan_
173
+ - arctanh_
174
+ - asin_
175
+ - asinh_
176
+ - atan_
177
+ - atanh_
178
+ - ceil_
179
+ - celu_
180
+ - clamp_
181
+ - clamp_max_
182
+ - clamp_min_
183
+ - clip_
184
+ - cos_
185
+ - cosh_
186
+ - deg2rad_
187
+ - dropout_
188
+ - embedding_renorm_
189
+ - erf_
190
+ - erfc_
191
+ - exp2_
192
+ - exp_
193
+ - expm1_
194
+ - feature_alpha_dropout_
195
+ - feature_dropout_
196
+ - fill_
197
+ - fix_
198
+ - floor_
199
+ - frac_
200
+ - gcd_
201
+ - index_put_
202
+ - lcm_
203
+ - ldexp_
204
+ - log10_
205
+ - log1p_
206
+ - log2_
207
+ - log_
208
+ - logit_
209
+ - nan_to_num_
210
+ - neg_
211
+ - negative_
212
+ - rad2deg_
213
+ - reciprocal_
214
+ - relu_
215
+ - resize_as_
216
+ - round_
217
+ - rrelu_
218
+ - rsqrt_
219
+ - selu_
220
+ - sigmoid_
221
+ - sin_
222
+ - sinc_
223
+ - sinh_
224
+ - sqrt_
225
+ - square_
226
+ - tan_
227
+ - tanh_
228
+ - threshold_
229
+ - trunc_
230
+ - xlogy_
231
+
232
+ inplace_distributed_op:
233
+ - broadcast
234
+ - all_reduce
235
+ - reduce
236
+ - all_gather
237
+ - gather
238
+ - scatter
239
+ - reduce_scatter
240
+ - _reduce_scatter_base
241
+ - _all_gather_base
242
+ - send
243
+ - recv
244
+ - irecv
245
+ - isend
246
+ - all_to_all_single
247
+ - all_to_all
248
+ - all_gather_into_tensor
249
+ - reduce_scatter_tensor
250
+
251
+
@@ -1,69 +1,86 @@
1
- import os
2
- import time
3
- import sys
4
- from functools import wraps
5
- from msprobe.core.common.const import MsgConst
6
-
7
-
8
- class BaseLogger:
9
- def __init__(self):
10
- self.warning_level = "WARNING"
11
- self.error_level = "ERROR"
12
- self.info_level = "INFO"
13
- self.rank = None
14
-
15
- @staticmethod
16
- def _print_log(level, msg, end='\n'):
17
- current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
18
- pid = os.getpid()
19
- full_msg = f"{current_time} ({pid}) [{level}] {msg}"
20
- print(full_msg, end=end)
21
- sys.stdout.flush()
22
-
23
- def get_rank(self):
24
- return self.rank
25
-
26
- def filter_special_chars(func):
27
- @wraps(func)
28
- def func_level(self, msg):
29
- for char in MsgConst.SPECIAL_CHAR:
30
- msg = msg.replace(char, '_')
31
- return func(self, msg)
32
- return func_level
33
-
34
- @filter_special_chars
35
- def info(self, msg):
36
- self._print_log(self.info_level, msg)
37
-
38
- @filter_special_chars
39
- def error(self, msg):
40
- self._print_log(self.error_level, msg)
41
-
42
- @filter_special_chars
43
- def warning(self, msg):
44
- self._print_log(self.warning_level, msg)
45
-
46
- def on_rank_0(self, func):
47
- def func_rank_0(*args, **kwargs):
48
- current_rank = self.get_rank()
49
- if current_rank is None or current_rank == 0:
50
- return func(*args, **kwargs)
51
- else:
52
- return None
53
- return func_rank_0
54
-
55
- def info_on_rank_0(self, msg):
56
- return self.on_rank_0(self.info)(msg)
57
-
58
- def error_on_rank_0(self, msg):
59
- return self.on_rank_0(self.error)(msg)
60
-
61
- def warning_on_rank_0(self, msg):
62
- return self.on_rank_0(self.warning)(msg)
63
-
64
- def error_log_with_exp(self, msg, exception):
65
- self.error(msg)
66
- raise exception
67
-
68
-
69
- logger = BaseLogger()
1
+ import os
2
+ import time
3
+ import sys
4
+ from functools import wraps
5
+ from msprobe.core.common.const import MsgConst
6
+
7
+
8
class BaseLogger:
    """Small stdout logger with a level threshold read from the environment.

    Each message is printed as ``<time> (<pid>) [rank N] [<LEVEL>] <msg>``;
    the rank tag is omitted while ``self.rank`` is ``None``.
    """

    def __init__(self):
        # NOTE(review): rank is presumably assigned by the caller once the
        # process rank is known; nothing in this class sets it.
        self.rank = None
        self.level = self.get_level()

    @staticmethod
    def get_level():
        """Return the numeric threshold configured through the environment.

        Reads the variable named by ``MsgConst.MSPROBE_LOG_LEVEL``. A value
        that is not in ``MsgConst.LOG_LEVEL_ENUM`` (including an unset
        variable) falls back to ``MsgConst.LogLevel.INFO.value``.
        """
        input_level = os.environ.get(MsgConst.MSPROBE_LOG_LEVEL)
        if input_level in MsgConst.LOG_LEVEL_ENUM:
            return int(input_level)
        return MsgConst.LogLevel.INFO.value

    def get_rank(self):
        """Return the rank currently stored on this logger (may be None)."""
        return self.rank

    def filter_special_chars(func):
        """Decorate a log method so the message is sanitised first.

        Every entry of ``MsgConst.SPECIAL_CHAR`` found in the message is
        replaced with ``_`` before ``func`` is called.
        """
        @wraps(func)
        def func_level(self, msg, **kwargs):
            # Coerce non-string messages (e.g. an exception object) so the
            # .replace() calls below cannot fail with AttributeError.
            if not isinstance(msg, str):
                msg = str(msg)
            for char in MsgConst.SPECIAL_CHAR:
                msg = msg.replace(char, '_')
            return func(self, msg, **kwargs)
        return func_level

    @filter_special_chars
    def error(self, msg):
        """Print ``msg`` with the error label if the threshold allows it."""
        if self.level <= MsgConst.LogLevel.ERROR.value:
            self._print_log(MsgConst.LOG_LEVEL[3], msg)

    @filter_special_chars
    def warning(self, msg):
        """Print ``msg`` with the warning label if the threshold allows it."""
        if self.level <= MsgConst.LogLevel.WARNING.value:
            self._print_log(MsgConst.LOG_LEVEL[2], msg)

    @filter_special_chars
    def info(self, msg):
        """Print ``msg`` with the info label if the threshold allows it."""
        if self.level <= MsgConst.LogLevel.INFO.value:
            self._print_log(MsgConst.LOG_LEVEL[1], msg)

    @filter_special_chars
    def debug(self, msg):
        """Print ``msg`` with the debug label if the threshold allows it."""
        if self.level <= MsgConst.LogLevel.DEBUG.value:
            self._print_log(MsgConst.LOG_LEVEL[0], msg)

    def on_rank_0(self, func):
        """Wrap ``func`` so it only runs when rank is ``None`` or ``0``.

        On any other rank the wrapper returns ``None`` without calling
        ``func``. The rank is read at call time, not at wrap time.
        """
        @wraps(func)
        def func_rank_0(*args, **kwargs):
            current_rank = self.get_rank()
            if current_rank is None or current_rank == 0:
                return func(*args, **kwargs)
            return None
        return func_rank_0

    def info_on_rank_0(self, msg):
        """Log ``msg`` at info level, only on rank 0 (or unset rank)."""
        return self.on_rank_0(self.info)(msg)

    def error_on_rank_0(self, msg):
        """Log ``msg`` at error level, only on rank 0 (or unset rank)."""
        return self.on_rank_0(self.error)(msg)

    def warning_on_rank_0(self, msg):
        """Log ``msg`` at warning level, only on rank 0 (or unset rank)."""
        return self.on_rank_0(self.warning)(msg)

    def error_log_with_exp(self, msg, exception):
        """Log ``msg`` as an error, then raise ``exception``."""
        self.error(msg)
        raise exception

    def _print_log(self, level, msg, end='\n'):
        """Write one formatted record to stdout and flush immediately.

        ``level`` is the label printed in brackets; ``end`` is forwarded to
        ``print``.
        """
        current_rank = self.get_rank()
        current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        pid = os.getpid()
        if current_rank is not None:
            full_msg = f"{current_time} ({pid}) [rank {current_rank}] [{level}] {msg}"
        else:
            full_msg = f"{current_time} ({pid}) [{level}] {msg}"
        # flush=True flushes sys.stdout right after writing, same as the
        # separate sys.stdout.flush() call it replaces.
        print(full_msg, end=end, flush=True)
84
+
85
+
86
+ logger = BaseLogger()