mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
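--- msprobe/mindspore/dump/hook_cell/api_registry.py (1.0.3)
+++ msprobe/mindspore/dump/hook_cell/api_registry.py (1.1.0)
@@ -1,104 +1,128 @@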
1
- # Copyright 2024 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- import mindspore as ms
17
- from msprobe.mindspore.dump.hook_cell.wrap_functional import get_functional_ops, setup_hooks, \
18
- HOOKFunctionalOP, HOOKMintOP, HOOKMintNNFunctionalOP
19
- from msprobe.mindspore.dump.hook_cell.wrap_tensor import get_tensor_ops, wrap_tensor_ops_and_bind, HOOKTensor
20
- from msprobe.core.common.utils import Const
21
-
22
-
23
- class ApiRegistry:
24
- def __init__(self):
25
- self.tensor_ori_attr = {}
26
- self.functional_ori_attr = {}
27
- self.mint_ops_ori_attr = {}
28
- self.mint_func_ops_ori_attr = {}
29
- self.norm_inner_ops_ori_attr = {}
30
-
31
- self.tensor_hook_attr = {}
32
- self.functional_hook_attr = {}
33
- self.mint_ops_hook_attr = {}
34
- self.mint_func_ops_hook_attr = {}
35
- self.norm_inner_ops_hook_attr = {}
36
-
37
- self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
38
-
39
- @staticmethod
40
- def store_ori_attr(ori_api_group, api_list, api_ori_attr):
41
- for api in api_list:
42
- if Const.SEP in api:
43
- sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
44
- sub_module = getattr(ori_api_group, sub_module_name)
45
- api_ori_attr[api] = getattr(sub_module, sub_op)
46
- else:
47
- api_ori_attr[api] = getattr(ori_api_group, api)
48
-
49
- @staticmethod
50
- def set_api_attr(api_group, attr_dict):
51
- for api, api_attr in attr_dict.items():
52
- if Const.SEP in api:
53
- sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
54
- sub_module = getattr(api_group, sub_module_name, None)
55
- if sub_module is not None:
56
- setattr(sub_module, sub_op, api_attr)
57
- else:
58
- setattr(api_group, api, api_attr)
59
-
60
- def norm_inner_op_set_hook_func(self):
61
- self.set_api_attr(ms.ops, self.norm_inner_ops_hook_attr)
62
-
63
- def norm_inner_op_set_ori_func(self):
64
- self.set_api_attr(ms.ops, self.norm_inner_ops_ori_attr)
65
-
66
- def api_set_hook_func(self):
67
- self.set_api_attr(ms.Tensor, self.tensor_hook_attr)
68
- self.set_api_attr(ms.ops, self.functional_hook_attr)
69
- self.set_api_attr(ms.mint, self.mint_ops_hook_attr)
70
- self.set_api_attr(ms.mint.nn.functional, self.mint_func_ops_hook_attr)
71
-
72
- def api_set_ori_func(self):
73
- self.set_api_attr(ms.Tensor, self.tensor_ori_attr)
74
- self.set_api_attr(ms.ops, self.functional_ori_attr)
75
- self.set_api_attr(ms.mint, self.mint_ops_ori_attr)
76
- self.set_api_attr(ms.mint.nn.functional, self.mint_func_ops_ori_attr)
77
-
78
- def initialize_hook(self, hook):
79
- self.store_ori_attr(ms.Tensor, get_tensor_ops(), self.tensor_ori_attr)
80
- wrap_tensor_ops_and_bind(hook)
81
- for attr_name in dir(HOOKTensor):
82
- if attr_name.startswith(Const.ATTR_NAME_PREFIX):
83
- self.tensor_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKTensor, attr_name)
84
-
85
- functional_ops, mint_ops, mint_func_ops = get_functional_ops()
86
- self.store_ori_attr(ms.ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
87
- self.store_ori_attr(ms.ops, functional_ops, self.functional_ori_attr)
88
- self.store_ori_attr(ms.mint, mint_ops, self.mint_ops_ori_attr)
89
- self.store_ori_attr(ms.mint.nn.functional, mint_func_ops, self.mint_func_ops_ori_attr)
90
- setup_hooks(hook)
91
- for attr_name in dir(HOOKFunctionalOP):
92
- if attr_name.startswith(Const.ATTR_NAME_PREFIX):
93
- self.functional_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKFunctionalOP, attr_name)
94
- if attr_name[Const.ATTR_NAME_PREFIX_LEN:] in self.norm_inner_ops:
95
- self.norm_inner_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKFunctionalOP, attr_name)
96
- for attr_name in dir(HOOKMintOP):
97
- if attr_name.startswith(Const.ATTR_NAME_PREFIX):
98
- self.mint_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintOP, attr_name)
99
- for attr_name in dir(HOOKMintNNFunctionalOP):
100
- if attr_name.startswith(Const.ATTR_NAME_PREFIX):
101
- self.mint_func_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintNNFunctionalOP, attr_name)
102
-
103
-
104
- api_register = ApiRegistry()
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ from mindspore import Tensor, ops, mint
17
+ from mindspore.mint.nn import functional
18
+ from mindspore.common._stub_tensor import StubTensor
19
+ from mindspore.communication import comm_func
20
+
21
+ from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP,
22
+ HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP,
23
+ get_wrap_api_list, setup_hooks)
24
+ from msprobe.core.common.utils import Const
25
+
26
+
27
+ class ApiRegistry:
28
+ def __init__(self):
29
+ self.tensor_ori_attr = {}
30
+ self.stub_tensor_ori_attr = {}
31
+ self.functional_ori_attr = {}
32
+ self.mint_ops_ori_attr = {}
33
+ self.mint_func_ops_ori_attr = {}
34
+ self.distributed_ori_attr = {}
35
+ self.norm_inner_ops_ori_attr = {}
36
+
37
+ self.tensor_hook_attr = {}
38
+ self.stub_tensor_hook_attr = {}
39
+ self.functional_hook_attr = {}
40
+ self.mint_ops_hook_attr = {}
41
+ self.mint_func_ops_hook_attr = {}
42
+ self.distibuted_hook_attr = {}
43
+ self.norm_inner_ops_hook_attr = {}
44
+
45
+ self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
46
+
47
+ @staticmethod
48
+ def store_ori_attr(ori_api_group, api_list, api_ori_attr):
49
+ for api in api_list:
50
+ if Const.SEP in api:
51
+ sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
52
+ sub_module = getattr(ori_api_group, sub_module_name)
53
+ api_ori_attr[api] = getattr(sub_module, sub_op)
54
+ else:
55
+ api_ori_attr[api] = getattr(ori_api_group, api)
56
+
57
+ @staticmethod
58
+ def set_api_attr(api_group, attr_dict):
59
+ for api, api_attr in attr_dict.items():
60
+ if Const.SEP in api:
61
+ sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
62
+ sub_module = getattr(api_group, sub_module_name, None)
63
+ if sub_module is not None:
64
+ setattr(sub_module, sub_op, api_attr)
65
+ else:
66
+ setattr(api_group, api, api_attr)
67
+
68
+ def norm_inner_op_set_hook_func(self):
69
+ self.set_api_attr(ops, self.norm_inner_ops_hook_attr)
70
+
71
+ def norm_inner_op_set_ori_func(self):
72
+ self.set_api_attr(ops, self.norm_inner_ops_ori_attr)
73
+
74
+ def api_set_hook_func(self):
75
+ self.set_api_attr(Tensor, self.tensor_hook_attr)
76
+ self.set_api_attr(StubTensor, self.stub_tensor_hook_attr)
77
+ self.set_api_attr(ops, self.functional_hook_attr)
78
+ self.set_api_attr(mint, self.mint_ops_hook_attr)
79
+ self.set_api_attr(functional, self.mint_func_ops_hook_attr)
80
+ self.set_api_attr(comm_func, self.distibuted_hook_attr)
81
+
82
+ def api_set_ori_func(self):
83
+ self.set_api_attr(Tensor, self.tensor_ori_attr)
84
+ self.set_api_attr(StubTensor, self.stub_tensor_ori_attr)
85
+ self.set_api_attr(ops, self.functional_ori_attr)
86
+ self.set_api_attr(mint, self.mint_ops_ori_attr)
87
+ self.set_api_attr(functional, self.mint_func_ops_ori_attr)
88
+ self.set_api_attr(comm_func, self.distributed_ori_attr)
89
+
90
+ def initialize_hook(self, hook):
91
+ wrap_api_name = get_wrap_api_list()
92
+ self.store_ori_attr(Tensor, wrap_api_name.tensor_api_names, self.tensor_ori_attr)
93
+ self.store_ori_attr(StubTensor, wrap_api_name.stub_tensor_api_names, self.stub_tensor_ori_attr)
94
+ self.store_ori_attr(ops, wrap_api_name.ops_api_names, self.functional_ori_attr)
95
+ self.store_ori_attr(mint, wrap_api_name.mint_api_names, self.mint_ops_ori_attr)
96
+ self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr)
97
+ self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr)
98
+ self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
99
+ setup_hooks(hook)
100
+ for attr_name in dir(HOOKTensor):
101
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
102
+ api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
103
+ self.tensor_hook_attr[api_name] = getattr(HOOKTensor, attr_name)
104
+ for attr_name in dir(HOOKStubTensor):
105
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
106
+ api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
107
+ self.stub_tensor_hook_attr[api_name] = getattr(HOOKStubTensor, attr_name)
108
+ for attr_name in dir(HOOKFunctionalOP):
109
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
110
+ api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
111
+ self.functional_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name)
112
+ if api_name in self.norm_inner_ops:
113
+ self.norm_inner_ops_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name)
114
+ for attr_name in dir(HOOKMintOP):
115
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
116
+ api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
117
+ self.mint_ops_hook_attr[api_name] = getattr(HOOKMintOP, attr_name)
118
+ for attr_name in dir(HOOKMintNNFunctionalOP):
119
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
120
+ api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
121
+ self.mint_func_ops_hook_attr[api_name] = getattr(HOOKMintNNFunctionalOP, attr_name)
122
+ for attr_name in dir(HOOKDistributedOP):
123
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
124
+ api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
125
+ self.distibuted_hook_attr[api_name] = getattr(HOOKDistributedOP, attr_name)
126
+
127
+
128
+ api_register = ApiRegistry()
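--- msprobe/mindspore/dump/hook_cell/hook_cell.py (1.0.3)
+++ msprobe/mindspore/dump/hook_cell/hook_cell.py (1.1.0)
@@ -1,53 +1,55 @@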
1
- # Copyright 2024 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- from collections import defaultdict
16
-
17
- from mindspore import nn
18
- from msprobe.core.common.const import Const
19
-
20
-
21
- class HOOKCell(nn.Cell):
22
- cell_count = defaultdict(int)
23
- g_stop_hook = False
24
-
25
- def __init__(self, build_hook) -> None:
26
- super(HOOKCell, self).__init__()
27
- self.changed_status = False
28
- self.input_kwargs = {}
29
- self.prefix = ""
30
- if not HOOKCell.g_stop_hook:
31
- HOOKCell.g_stop_hook = True
32
- self.changed_status = True
33
- if hasattr(self, "prefix_op_name_"):
34
- self.prefix = self.prefix_op_name_
35
-
36
- HOOKCell.cell_count[self.prefix] += 1
37
- self.prefix = self.prefix + str(HOOKCell.cell_count[self.prefix] - 1) + Const.SEP
38
- forward_hook, backward_hook = build_hook(self.prefix)
39
- self.register_forward_hook(forward_hook)
40
- self.register_backward_hook(backward_hook)
41
-
42
- # 重载call,加全局标志。
43
- def __call__(self, *args, **kwargs):
44
- try:
45
- self.input_kwargs = kwargs
46
- out = super(HOOKCell, self).__call__(*args, **kwargs)
47
- except Exception as e:
48
- raise e
49
- finally:
50
- if self.changed_status:
51
- self.changed_status = False
52
- HOOKCell.g_stop_hook = False
53
- return out
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ from collections import defaultdict
17
+
18
+ from mindspore import nn
19
+
20
+ from msprobe.core.common.const import Const
21
+
22
+
23
+ class HOOKCell(nn.Cell):
24
+ cell_count = defaultdict(int)
25
+ g_stop_hook = False
26
+
27
+ def __init__(self, build_hook) -> None:
28
+ super(HOOKCell, self).__init__()
29
+ self.changed_status = False
30
+ self.input_kwargs = {}
31
+ self.prefix = ""
32
+ if not HOOKCell.g_stop_hook:
33
+ HOOKCell.g_stop_hook = True
34
+ self.changed_status = True
35
+ if hasattr(self, "prefix_api_name"):
36
+ self.prefix = self.prefix_api_name
37
+
38
+ HOOKCell.cell_count[self.prefix] += 1
39
+ self.prefix = self.prefix + str(HOOKCell.cell_count[self.prefix] - 1) + Const.SEP
40
+ forward_hook, backward_hook = build_hook(self.prefix)
41
+ self.register_forward_hook(forward_hook)
42
+ self.register_backward_hook(backward_hook)
43
+
44
+ # 重载call,加全局标志。
45
+ def __call__(self, *args, **kwargs):
46
+ try:
47
+ self.input_kwargs = kwargs
48
+ out = super(HOOKCell, self).__call__(*args, **kwargs)
49
+ except Exception as e:
50
+ raise e
51
+ finally:
52
+ if self.changed_status:
53
+ self.changed_status = False
54
+ HOOKCell.g_stop_hook = False
55
+ return out
@@ -0,0 +1,206 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ import os
17
+
18
+ import mindspore as ms
19
+ from mindspore.common.tensor import Tensor
20
+ from mindspore import ops
21
+
22
+ from msprobe.mindspore.common.log import logger
23
+ from msprobe.core.common.utils import Const, DumpException
24
+ from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
25
+ ModuleBackwardInputs, ModuleBackwardOutputs
26
+
27
+
28
class PrimitiveHookService:
    """Wrap primitive calls so their forward data and backward gradients are dumped.

    Attributes:
        primitive_counters (dict): maps a primitive name to the zero-based
            index of its most recent wrapped call.
        service_instance: owning service; this class reads its
            ``primitive_switch`` and ``data_collector`` attributes.
    """

    def __init__(self, service_instance):
        """Store the owning service and start with empty call counters."""
        self.primitive_counters = {}
        self.service_instance = service_instance

    def wrap_primitive(self, origin_func, primitive_name):
        """
        Wrap the original primitive function, adding input and output hooks
        that capture forward and backward data.

        Args:
            origin_func (callable): the original primitive function.
            primitive_name (str): the original primitive name.

        Returns:
            callable: the wrapped primitive function.
        """

        def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
            """
            Create a backward hook function used to capture gradients.

            Args:
                captured_grads (list): accumulator for the captured gradients.
                num_tensors (int): number of tensors whose gradients are expected.
                updated_primitive_name (str): primitive name including its call index.
                hook_type (str): hook type (``Const.INPUT`` or ``Const.OUTPUT``).

            Returns:
                callable: the backward hook function.
            """

            def backward_hook(grad):
                captured_grads.append(grad)
                backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}"

                try:
                    # Data is collected only once every expected gradient has arrived.
                    if len(captured_grads) == num_tensors and hook_type == Const.INPUT:
                        # Gradients of the forward inputs are reported as backward outputs.
                        self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
                        new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
                        # NOTE(review): `self` here is the PrimitiveHookService, not the
                        # primitive instance that the forward path passes -- confirm intended.
                        self.service_instance.data_collector.backward_output_data_collect(
                            backward_primitive_name, self, os.getpid(), new_module_input_output
                        )
                        captured_grads.clear()
                    elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
                        # Gradients of the forward outputs are reported as backward inputs.
                        self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
                        new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
                        self.service_instance.data_collector.backward_input_data_collect(
                            backward_primitive_name, self, os.getpid(), new_module_input_output
                        )
                        captured_grads.clear()

                except Exception as exception:
                    logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, "
                                 f"updated_primitive_name: {updated_primitive_name}")
                    raise DumpException(DumpException.BACKWARD_DATA_COLLECTION_ERROR) from exception

            return backward_hook

        def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
            """
            Attach a backward hook to every Tensor among the forward inputs.

            Args:
                args (tuple): primitive input arguments.
                captured_grads_input (list): accumulator for the input gradients.
                updated_primitive_name (str): primitive name including its call index.

            Returns:
                list: the inputs, with each Tensor replaced by its hooked version.
            """
            hooked_inputs = []
            num_tensors = sum(isinstance(arg, Tensor) for arg in args)
            input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name,
                                                       Const.INPUT)
            for arg in args:
                if isinstance(arg, Tensor):
                    arg_hooked = ops.HookBackward(input_backward_hook)(arg)
                    hooked_inputs.append(arg_hooked)
                else:
                    # Non-tensor arguments are passed through untouched.
                    hooked_inputs.append(arg)
            return hooked_inputs

        def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
            """
            Attach a backward hook to the forward output(s).

            Args:
                out (Tensor/tuple): primitive output.
                captured_grads_output (list): accumulator for the output gradients.
                updated_primitive_name (str): primitive name including its call index.

            Returns:
                Tensor/tuple: the output with hooks attached; any other type is
                returned unchanged.
            """
            if isinstance(out, tuple):
                num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out)
            else:
                num_output_tensors = 1
            output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors,
                                                        updated_primitive_name, Const.OUTPUT)

            if isinstance(out, Tensor):
                return ops.HookBackward(output_backward_hook)(out)
            elif isinstance(out, tuple):
                hooked_outputs = []
                for tensor in out:
                    if isinstance(tensor, Tensor):
                        hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
                    else:
                        hooked_outputs.append(tensor)
                return tuple(hooked_outputs)
            return out

        def wrapped_primitive_call(instance_self, *args, **kwargs):
            """
            Wrapped primitive call that adds input and output hooks.

            Args:
                instance_self (object): the primitive instance.
                *args: primitive positional arguments.
                **kwargs: primitive keyword arguments.

            Returns:
                Tensor/tuple: the primitive's return value.

            Raises:
                DumpException: if hooking the inputs, calling the primitive,
                    collecting forward data or hooking the outputs fails.
            """
            # The counter advances on every call, even while dumping is switched off.
            self.update_primitive_counters(primitive_name)
            current_count = self.primitive_counters.get(primitive_name, 0)
            updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{primitive_name}{Const.SEP}{current_count}"

            if not self.service_instance.primitive_switch:
                return origin_func(*args, **kwargs)

            captured_grads_input, captured_grads_output = [], []

            try:
                hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
            except Exception as exception:
                logger.error(f"This is a primitive op dump error during input hooking: {exception}, "
                             f"primitive_name: {primitive_name}")
                raise DumpException(DumpException.INPUT_HOOK_ERROR) from exception

            try:
                out = origin_func(*hooked_inputs, **kwargs)
            except Exception as exception:
                logger.error(f"This is a primitive op dump error during function call: {exception}, "
                             f"primitive_name: {primitive_name}")
                raise DumpException(DumpException.FUNCTION_CALL_ERROR) from exception

            data_collector = self.service_instance.data_collector
            # Every use of the collector sits behind this guard; previously the
            # name update ran before it, so the guard could never take effect.
            if data_collector:
                forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
                data_collector.update_api_or_module_name(forward_primitive_name)
                module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
                try:
                    data_collector.forward_data_collect(forward_primitive_name, instance_self,
                                                        os.getpid(), module_input_output)
                except Exception as exception:
                    logger.error(f"This is a primitive op dump error during forward data collection: {exception}, "
                                 f"primitive_name: {primitive_name}")
                    raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception

                # The collector may supply a replacement for the forward output.
                if data_collector.if_return_forward_new_output():
                    out = data_collector.get_forward_new_output()

            try:
                out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
            except Exception as exception:
                logger.error(f"This is a primitive op dump error during output hooking: {exception}, "
                             f"primitive_name: {primitive_name}")
                raise DumpException(DumpException.OUTPUT_HOOK_ERROR) from exception

            return out

        return wrapped_primitive_call

    def update_primitive_counters(self, primitive_name):
        """Advance the call index of ``primitive_name``; the first call yields 0."""
        self.primitive_counters[primitive_name] = self.primitive_counters.get(primitive_name, -1) + 1
206
+