mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -13,18 +13,20 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from functools import wraps
16
+ from collections import OrderedDict
17
17
 
18
18
  import torch
19
+ from torch.utils.hooks import BackwardHook, RemovableHandle
20
+
19
21
  from msprobe.core.common.const import Const
20
22
  from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
21
23
  from msprobe.pytorch.common.log import logger
22
- from msprobe.pytorch.common.utils import replace_last_occurrence
23
- from torch.utils.checkpoint import checkpoint as origin_checkpoint
24
- from torch.utils.checkpoint import set_checkpoint_early_stop
25
- from torch.utils.hooks import BackwardHook
24
+ from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
25
+ from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_output_hook
26
26
 
27
27
  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
28
+ if torch_version_above_or_equal_2:
29
+ from torch.utils.checkpoint import checkpoint as origin_checkpoint, set_checkpoint_early_stop
28
30
 
29
31
 
30
32
  def checkpoint_without_early_stop(*args, **kwargs):
@@ -33,7 +35,18 @@ def checkpoint_without_early_stop(*args, **kwargs):
33
35
 
34
36
 
35
37
  def replace_checkpoint():
36
- torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop
38
+ if torch_version_above_or_equal_2:
39
+ torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop
40
+
41
+
42
+ def wrap_megatron_deallocate(func):
43
+ def wrapper_func(out, deallocate_pipeline_outputs=False):
44
+ if deallocate_pipeline_outputs and isinstance(out, torch.Tensor) and getattr(out, "_base") is not None:
45
+ out_clone = out.clone()
46
+ out.data = torch.empty((1,), device=out.device, dtype=out.dtype, )
47
+ return func(out_clone, deallocate_pipeline_outputs)
48
+ return func(out, deallocate_pipeline_outputs)
49
+ return wrapper_func
37
50
 
38
51
 
39
52
  class ModuleProcesser:
@@ -41,37 +54,25 @@ class ModuleProcesser:
41
54
  module_stack = []
42
55
  api_parent_node = ""
43
56
  module_node = {}
57
+ module_bw_hook_kernels = {}
58
+ module_with_backward_hook = {}
59
+ enable_module_dump = False
44
60
 
45
61
  def __init__(self, scope):
46
62
  self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
47
- BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
48
- BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
63
+ wrap_setup_input_output_hook()
49
64
  replace_checkpoint()
65
+ try:
66
+ from megatron.core.pipeline_parallel import schedules
67
+ schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor)
68
+ logger.info_on_rank_0("Patch megatron method success.")
69
+ except ImportError:
70
+ logger.info_on_rank_0("No megatron find.")
71
+ except Exception as e:
72
+ logger.info_on_rank_0(f"Patch megatron method failed, detail:{str(e)}")
50
73
 
51
74
  @staticmethod
52
- def clone_return_value(func):
53
- @wraps(func)
54
- def clone_return_value_func(*args, **kwargs):
55
- result = func(*args, **kwargs)
56
- return ModuleProcesser.clone_if_tensor(result)
57
-
58
- return clone_return_value_func
59
-
60
- @staticmethod
61
- def clone_if_tensor(result):
62
- if isinstance(result, torch.Tensor):
63
- return result.clone()
64
- elif type(result) is tuple:
65
- return tuple(ModuleProcesser.clone_if_tensor(x) for x in result)
66
- elif type(result) is list:
67
- return list(ModuleProcesser.clone_if_tensor(x) for x in result)
68
- elif type(result) is dict:
69
- return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()}
70
- else:
71
- return result
72
-
73
- @staticmethod
74
- def module_count_func(module_name):
75
+ def set_and_get_calls_number(module_name):
75
76
  if module_name not in ModuleProcesser.module_count:
76
77
  ModuleProcesser.module_count[module_name] = 0
77
78
  else:
@@ -85,13 +86,19 @@ class ModuleProcesser:
85
86
  module._is_full_backward_hook is False
86
87
 
87
88
  @staticmethod
88
- def get_modules_and_names(models):
89
+ def get_modules_and_names(models, recursive, module_names):
89
90
  modules_and_names_with_index = {}
90
91
  if isinstance(models, (list, tuple)):
92
+ if not recursive and len(module_names) != len(models):
93
+ return modules_and_names_with_index
91
94
  for index, model in enumerate(models):
92
- modules_and_names_with_index[str(index)] = model.named_modules()
95
+ modules_and_names_with_index[str(index)] = model.named_modules() if recursive else \
96
+ [(module_names[index], model)]
93
97
  else:
94
- modules_and_names_with_index["-1"] = models.named_modules()
98
+ if not recursive and len(module_names) != 1:
99
+ return modules_and_names_with_index
100
+ modules_and_names_with_index["-1"] = models.named_modules() if recursive else \
101
+ [(module_names[0], models)]
95
102
  return modules_and_names_with_index
96
103
 
97
104
  @classmethod
@@ -100,105 +107,134 @@ class ModuleProcesser:
100
107
  cls.module_stack = []
101
108
  cls.api_parent_node = ""
102
109
  cls.module_node = {}
110
+ cls.module_bw_hook_kernels = {}
111
+ cls.enable_module_dump = False
112
+
113
+ def register_module_hook(self, models, build_hook, recursive=True, module_names=None):
114
+ if module_names is None:
115
+ module_names = []
103
116
 
104
- def register_module_hook(self, models, build_hook):
105
- logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.")
106
- modules_and_names_with_index = self.get_modules_and_names(models)
117
+ modules_and_names_with_index = self.get_modules_and_names(models, recursive, module_names)
107
118
  for index, modules_and_names in modules_and_names_with_index.items():
108
119
  model = models if index == "-1" else models[int(index)]
109
120
  for name, module in modules_and_names:
110
- if module == model:
121
+ if recursive and module == model:
111
122
  continue
123
+ if not is_torch_nn_module(module):
124
+ logger.warning(
125
+ f"The module dump does not support {type(module)} type. "
126
+ f"The data dump for this module will be skipped."
127
+ )
128
+ continue
129
+ if module.__class__.__name__ == "FullyShardedDataParallel":
130
+ continue
131
+ setattr(module, 'msprobe_hook', True)
112
132
  module_index = (index + Const.SEP) if index != "-1" else ""
113
- prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
114
- name + Const.SEP + module.__class__.__name__ + Const.SEP)
115
- pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook(
116
- BaseScope.Module_Type_Module,
117
- prefix_name
118
- )
133
+ prefix_name = f'{BaseScope.Module_Type_Module}{Const.SEP}{module_index}{name}{Const.SEP}' + \
134
+ f'{module.__class__.__name__}{Const.SEP}'
135
+
136
+ forward_pre_hook = self.build_module_hook(prefix_name, build_hook)
119
137
 
120
138
  if self.has_register_backward_hook(module):
121
139
  logger.warning(
122
140
  f"The {prefix_name[:-1]} has registered deprecated register_backward_hook,"
123
141
  f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
124
142
  )
143
+ ModuleProcesser.module_with_backward_hook[prefix_name] = True
144
+ register_forward_pre_hook(module, forward_pre_hook)
145
+
146
+ def build_module_hook(self, module_name, build_data_hook):
147
+ def forward_pre_hook(module, args, kwargs=None):
148
+ if kwargs is None:
149
+ kwargs = {}
150
+
151
+ if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump:
152
+ return (args, kwargs) if torch_version_above_or_equal_2 else args
153
+
154
+ index = ModuleProcesser.set_and_get_calls_number(module_name)
155
+ full_forward_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}'
156
+ full_backward_name = f'{module_name}{Const.BACKWARD}{Const.SEP}{index}'
157
+
158
+ self.set_construct_info_in_pre_hook(full_forward_name)
159
+
160
+ if not hasattr(module, 'msprobe_forward_hook'):
161
+ forward_hooks_dict = getattr(module, '_forward_hooks', OrderedDict())
162
+ handle = RemovableHandle(forward_hooks_dict)
163
+ forward_hooks_dict[handle.id] = forward_hook
164
+ forward_hooks_dict.move_to_end(handle.id, last=False)
165
+ if torch_version_above_or_equal_2:
166
+ forward_hooks_with_kwargs_dict = getattr(module, '_forward_hooks_with_kwargs', OrderedDict())
167
+ forward_hooks_with_kwargs_dict[handle.id] = True
168
+
169
+ setattr(module, 'msprobe_forward_hook', True)
170
+
171
+ hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
172
+
173
+ def get_backward_pre_hook(full_backward_name):
174
+ def backward_pre_hook_fn(module, grad_output):
175
+ self.set_construct_info_in_pre_hook(full_backward_name)
176
+ return backward_pre_hook_fn
177
+
178
+ def get_backward_hook(backward_data_hook, full_backward_name):
179
+ def backward_hook_fn(module, grad_input, grad_output):
180
+ new_output = backward_data_hook(module, grad_input, grad_output)
181
+ self.set_construct_info_in_hook(full_backward_name, is_forward=False)
182
+ return new_output
183
+ return backward_hook_fn
184
+
185
+ if not ModuleProcesser.module_with_backward_hook.get(module_name):
186
+ backward_pre_hook = get_backward_pre_hook(full_backward_name)
187
+ backward_hook = get_backward_hook(hook_set.backward_hook, full_backward_name)
125
188
  if torch_version_above_or_equal_2:
126
- module.register_forward_hook(forward_hook, with_kwargs=True)
189
+ bw_hook = BackwardHook(module, [backward_hook], [backward_pre_hook])
127
190
  else:
128
- if not self.has_register_backward_hook(module):
129
- module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
130
- module.register_forward_hook(forward_hook_torch_version_below_2)
131
- if not self.has_register_backward_hook(module):
132
- module.register_full_backward_hook(backward_hook)
133
-
134
- module.register_forward_pre_hook(self.node_hook(prefix_name + Const.FORWARD, Const.START))
135
- module.register_forward_hook(self.node_hook(prefix_name + Const.FORWARD, Const.STOP))
136
- if torch_version_above_or_equal_2 and not self.has_register_backward_hook(module):
137
- module.register_full_backward_pre_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.START))
138
- module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
139
-
140
- def node_hook(self, name_prefix, start_or_stop, **kwargs):
141
-
142
- def pre_hook(module, input, output=None):
143
- try:
144
- index = ModuleProcesser.module_count_func(name_prefix)
145
- except IndexError as e:
146
- index = None
147
- pass
148
- full_name = name_prefix + Const.SEP + str(index)
149
- if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
150
- module.mindstudio_reserved_name = []
151
- module.mindstudio_reserved_name.append(full_name)
152
- if self.module_stack:
153
- ModuleProcesser.module_node[full_name] = self.module_stack[-1]
191
+ bw_hook = BackwardHook(module, [backward_hook])
192
+ ModuleProcesser.module_bw_hook_kernels[full_forward_name] = bw_hook
193
+ args = bw_hook.setup_input_hook(args)
194
+ return (args, kwargs) if torch_version_above_or_equal_2 else args
195
+
196
+ def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
197
+ if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump:
198
+ return output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
199
+
200
+ index = ModuleProcesser.module_count.get(module_name)
201
+ full_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}'
202
+
203
+ hook_set = build_data_hook(BaseScope.Module_Type_Module, full_name)
204
+ hook_result = hook_set.forward_hook(module, args, kwargs_or_output, output_or_kwargs)
205
+ self.set_construct_info_in_hook(full_name)
206
+
207
+ if hook_result is not None:
208
+ result = hook_result
154
209
  else:
155
- ModuleProcesser.module_node[full_name] = None
210
+ result = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
156
211
 
157
- ModuleProcesser.module_stack.append(full_name)
158
- if self.module_stack:
159
- ModuleProcesser.api_parent_node = self.module_stack[-1]
160
- if self.scope:
161
- self.scope.begin_module(full_name)
212
+ bw_hook = ModuleProcesser.module_bw_hook_kernels.get(full_name)
213
+ if bw_hook:
214
+ result = bw_hook.setup_output_hook(result)
162
215
 
163
- def end_hook(module, input, output=None):
216
+ return result
217
+
218
+ return forward_pre_hook
219
+
220
+ def set_construct_info_in_pre_hook(self, full_name):
221
+ if self.module_stack:
222
+ ModuleProcesser.module_node[full_name] = self.module_stack[-1]
223
+ else:
224
+ ModuleProcesser.module_node[full_name] = None
225
+ ModuleProcesser.module_stack.append(full_name)
226
+ ModuleProcesser.api_parent_node = full_name
227
+ if self.scope:
228
+ self.scope.begin_module(full_name)
229
+
230
+ def set_construct_info_in_hook(self, full_name, is_forward=True):
231
+ if torch_version_above_or_equal_2 or is_forward:
164
232
  if self.module_stack:
165
233
  ModuleProcesser.module_stack.pop()
166
- if self.module_stack:
167
- ModuleProcesser.api_parent_node = self.module_stack[-1]
168
- else:
169
- ModuleProcesser.api_parent_node = None
170
- if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
171
- raise RuntimeError(f"module reserve name is None when pop")
172
- current_name = module.mindstudio_reserved_name.pop()
234
+ ModuleProcesser.api_parent_node = ModuleProcesser.module_stack[-1] if self.module_stack else None
173
235
  if self.scope:
174
- self.scope.end_module(current_name)
175
-
176
- def backward_hook(module, input, output=None):
177
- try:
178
- index = ModuleProcesser.module_count_func(name_prefix)
179
- except IndexError as e:
180
- index = None
181
- pass
182
- full_name = name_prefix + Const.SEP + str(index)
183
- if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
184
- module.mindstudio_reserved_name = []
185
- module.mindstudio_reserved_name.append(full_name)
186
- forward_full_name = replace_last_occurrence(full_name, Const.BACKWARD, Const.FORWARD)
187
- ModuleProcesser.module_node[full_name] = replace_last_occurrence(
188
- ModuleProcesser.module_node.get(forward_full_name), Const.FORWARD, Const.BACKWARD)
189
- ModuleProcesser.api_parent_node = None
236
+ self.scope.end_module(full_name)
237
+ else:
190
238
  if self.scope:
191
239
  self.scope.begin_module(full_name)
192
-
193
- if torch_version_above_or_equal_2:
194
- if Const.START in start_or_stop:
195
- return pre_hook
196
- else:
197
- return end_hook
198
- else:
199
- if Const.FORWARD in name_prefix and Const.START in start_or_stop:
200
- return pre_hook
201
- elif Const.BACKWARD in name_prefix:
202
- return backward_hook
203
- else:
204
- return end_hook
240
+ ModuleProcesser.api_parent_node = full_name
@@ -16,7 +16,7 @@
16
16
 
17
17
  import torch
18
18
  from msprobe.core.common.exceptions import FreeBenchmarkException
19
- from msprobe.core.common.utils import recursion_depth_decorator
19
+ from msprobe.core.common.decorator import recursion_depth_decorator
20
20
  from msprobe.pytorch.free_benchmark.common.enums import DeviceType
21
21
 
22
22
 
@@ -16,7 +16,7 @@
16
16
  import math
17
17
 
18
18
  import torch
19
- from msprobe.core.common.utils import recursion_depth_decorator
19
+ from msprobe.core.common.decorator import recursion_depth_decorator
20
20
  from msprobe.pytorch.free_benchmark import logger
21
21
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
22
22
  from msprobe.pytorch.free_benchmark.common.utils import TorchC
@@ -14,7 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import torch
17
- from msprobe.core.common.utils import recursion_depth_decorator
17
+ from msprobe.core.common.decorator import recursion_depth_decorator
18
18
  from msprobe.pytorch.free_benchmark import logger
19
19
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
20
20
  from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -95,13 +95,13 @@ class AddNoiseLayer(NpuBaseLayer):
95
95
  except Exception:
96
96
  logger.warning_on_rank_0(
97
97
  f"[msprobe] Free Benchmark: For {self.api_name}, "
98
- f"when calculate maximun value, tensor is changed to float32."
98
+ f"when calculating the maximum value, the tensor is changed to float32."
99
99
  )
100
100
  max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
101
101
  if max_val < abs_tol:
102
102
  logger.warning_on_rank_0(
103
103
  f"[msprobe] Free Benchmark: For {self.api_name}, "
104
- f"Maximun value is less than the minimun threshold. Cancel add noise."
104
+ f"maximum value is less than the minimum threshold. Cancel adding noise."
105
105
  )
106
106
  return False
107
107
  return True
@@ -14,7 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import torch
17
- from msprobe.core.common.utils import recursion_depth_decorator
17
+ from msprobe.core.common.decorator import recursion_depth_decorator
18
18
  from msprobe.pytorch.free_benchmark import logger
19
19
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
20
20
  from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -100,13 +100,13 @@ class BitNoiseLayer(NpuBaseLayer):
100
100
  except Exception:
101
101
  logger.warning_on_rank_0(
102
102
  f"[msprobe] Free Benchmark: For {self.api_name}, "
103
- f"when calculate maximun value, tensor is changed to float32."
103
+ f"when calculate the maximum value, the tensor is changed to float32."
104
104
  )
105
105
  max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
106
106
  if max_val < abs_tol:
107
107
  logger.warning_on_rank_0(
108
108
  f"[msprobe] Free Benchmark: For {self.api_name}, "
109
- f"Maximun value is less than the minimun threshold. Cancel add noise."
109
+ f"maximum value is less than the minimum threshold. Cancel adding noise."
110
110
  )
111
111
  return False
112
112
  return True
@@ -14,7 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import torch
17
- from msprobe.core.common.utils import recursion_depth_decorator
17
+ from msprobe.core.common.decorator import recursion_depth_decorator
18
18
  from msprobe.pytorch.free_benchmark import logger
19
19
  from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
20
20
  from msprobe.pytorch.free_benchmark.common.params import DataParams
@@ -15,7 +15,7 @@
15
15
 
16
16
  import torch
17
17
  from msprobe.core.common.const import Const
18
- from msprobe.core.common.utils import recursion_depth_decorator
18
+ from msprobe.core.common.decorator import recursion_depth_decorator
19
19
  from msprobe.pytorch.free_benchmark import logger
20
20
  from msprobe.pytorch.free_benchmark.common.constant import CommonField
21
21
  from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -49,6 +49,6 @@ class CheckerHandler(FuzzHandler):
49
49
  except Exception as e:
50
50
  logger.warning_on_rank_0(
51
51
  f"[msprobe] Free Benchmark: For {self.params.api_name}, "
52
- f"when campare the result exception raise {e}"
52
+ f"when comparing the results, an exception is raised: {e}"
53
53
  )
54
54
  return data_params.original_result
@@ -70,7 +70,7 @@ class Register(dict):
70
70
 
71
71
  def add_register_item(key, value):
72
72
  if key in self._dict:
73
- logger.warning(f"{value.__name__} has been registered before, so we will overriden it.")
73
+ logger.warning(f"{value.__name__} has been registered before, so we will override it.")
74
74
  self[key] = value
75
75
  return value
76
76
 
@@ -46,7 +46,7 @@ class GradientMonitor:
46
46
  if not os.path.exists(self._output_path):
47
47
  create_directory(self._output_path)
48
48
  else:
49
- logger.warning(f"the file in {self._output_path} will be recoverd")
49
+ logger.warning(f"the file in {self._output_path} will be deleted")
50
50
  self._step = -1
51
51
  self._param2name = defaultdict(str)
52
52
 
@@ -97,7 +97,7 @@ class GradientMonitor:
97
97
  create_directory(output_dirpath)
98
98
  output_path = os.path.join(output_dirpath, f"grad_summary_{self._step}.csv")
99
99
  if os.path.exists(output_path):
100
- logger.warning(f"{output_path} will be recoverd")
100
+ logger.warning(f"{output_path} will be deleted")
101
101
  remove_path(output_path)
102
102
  header_result = GradStatCsv.generate_csv_header(self._level_adp, self._bounds)
103
103
  output_lines.insert(0, header_result)
@@ -17,6 +17,7 @@ from abc import ABC, abstractmethod
17
17
  from collections import namedtuple
18
18
  import hashlib
19
19
  from functools import wraps
20
+ import zlib
20
21
  import torch
21
22
  from msprobe.core.grad_probe.constant import GradConst
22
23
 
@@ -74,8 +75,8 @@ class CsvMd5(CsvItem):
74
75
  def generate_csv_content(csv_content_input):
75
76
  grad = csv_content_input.grad
76
77
  tensor_bytes = grad.cpu().detach().float().numpy().tobytes()
77
- md5_hash = hashlib.md5(tensor_bytes)
78
- return [md5_hash.hexdigest()]
78
+ md5_hash = f"{zlib.crc32(tensor_bytes):08x}"
79
+ return [md5_hash]
79
80
 
80
81
 
81
82
  @register_csv_item(GradConst.DISTRIBUTION)
@@ -0,0 +1,155 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ import os
18
+ import inspect
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+
23
+ from msprobe.core.common.const import Const
24
+ from msprobe.core.data_dump.api_registry import ApiRegistry
25
+ from msprobe.pytorch.common.log import logger
26
+ from msprobe.pytorch.common.utils import (
27
+ torch_without_guard_version, is_gpu, torch_device_guard, parameter_adapter
28
+ )
29
+ from msprobe.pytorch.function_factory import npu_custom_functions
30
+ from msprobe.pytorch.hook_module.hook_module import HOOKModule
31
+ from msprobe.pytorch.hook_module.utils import dynamic_import_op
32
+ from msprobe.core.common.file_utils import load_yaml
33
+
34
+ try:
35
+ import mindspeed.ops
36
+ except ImportError:
37
+ mindspeed_enable = False
38
+ else:
39
+ mindspeed_enable = True
40
+
41
+
42
# NOTE(review): this is a lexicographic *string* comparison, not a numeric
# version check: '2.0.0' > '2.0' and '2.1.0' > '2.0' are True, '1.13.1' > '2.0'
# is False, and e.g. '10.0' > '2.0' would also be False. The split on '+' drops
# a local-version suffix such as '+cpu'. The flag is not referenced in this part
# of the file; confirm the intended semantics wherever it is used.
torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'

# Passed unchanged to every ApiRegistry built by get_api_register; empty here.
_inner_used_api = {}
# One-element tuple holding the path of the supported-API YAML file
# (Const.SUPPORT_API_FILE_NAME) that lives next to this module. Element [0] is
# also loaded with load_yaml when MindSpeed is available.
_supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
# Read by npu_module_forward: on a CUDA device the NPU API name is translated to
# its GPU bench-function name before looking it up in npu_custom_functions.
_cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}

# Framework -> API type -> (object, tuple of modules).
# NOTE(review): presumably the first item is where the original API callables
# are looked up and the tuple lists the namespaces whose attributes get patched;
# confirm against ApiRegistry.
_api_types = {
    Const.PT_FRAMEWORK: {
        Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
        Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)),
        Const.PT_API_TYPE_TORCH: (torch, (torch,)),
        Const.PT_API_TYPE_VF: (torch._C._VariableFunctionsClass, (torch._VF,)),
        Const.PT_API_TYPE_DIST: (dist, (dist, dist.distributed_c10d))
    }
}
57
# Non-GPU (Ascend NPU) environment: additionally register torch_npu APIs.
if not is_gpu:
    import torch_npu
    # Where NPU APIs are taken from depends on torch_without_guard_version:
    # torch.ops.npu when it is truthy, torch_npu._C._VariableFunctionsClass otherwise.
    if torch_without_guard_version:
        _api_types.get(Const.PT_FRAMEWORK).update(
            {
                Const.PT_API_TYPE_NPU: (torch.ops.npu, (torch_npu, torch.ops.npu))
            }
        )
    else:
        _api_types.get(Const.PT_FRAMEWORK).update(
            {Const.PT_API_TYPE_NPU: (torch_npu._C._VariableFunctionsClass, (torch_npu,))}
        )
    # torch_npu's distributed APIs are registered in either case.
    _api_types.get(Const.PT_FRAMEWORK).update(
        {
            Const.PT_API_TYPE_NPU_DIST: (torch_npu.distributed, (torch_npu.distributed,
                                                                torch_npu.distributed.distributed_c10d))
        }
    )
# mindspeed_enable is True only when `import mindspeed.ops` succeeded above.
if mindspeed_enable:
    _api_types.get(Const.PT_FRAMEWORK).update({Const.PT_API_TYPE_MINDSPEED: (mindspeed.ops, (mindspeed.ops,))})
    # MindSpeed op names come from the supported-API YAML. For each name, the part
    # before the first Const.SEP plus Const.PY_SUFFIX is used as a file name.
    # NOTE(review): if the YAML has no Const.PT_API_TYPE_MINDSPEED entry, .get()
    # returns None and the comprehension below raises TypeError at import time.
    mindspeed_op_list = load_yaml(_supported_api_list_path[0]).get(Const.PT_API_TYPE_MINDSPEED)
    mindspeed_op_file_list = [op.split(Const.SEP)[0] + Const.PY_SUFFIX for op in mindspeed_op_list]
    # Presumably imports those mindspeed.ops submodules so their ops become
    # reachable for hooking; confirm against dynamic_import_op.
    dynamic_import_op(mindspeed.ops, mindspeed_op_file_list)
80
+
81
+
82
@parameter_adapter
def tensor_module_forward(module, *args, **kwargs):
    """Call the original ``torch.Tensor`` API held by the hook wrapper.

    Args:
        module: the wrapper object; its ``api_func`` attribute is the callable
            that is actually executed.
        *args, **kwargs: forwarded unchanged to ``module.api_func``.

    Returns:
        Whatever ``module.api_func`` returns. Any argument adaptation performed
        by the ``parameter_adapter`` decorator (defined elsewhere) wraps this call.
    """
    target = module.api_func
    outcome = target(*args, **kwargs)
    return outcome
85
+
86
+
87
def dist_module_forward(module, *args, **kwargs):
    """Run a hooked ``torch.distributed`` API and wait on asynchronous handles.

    ``module.api_func`` is called first. Its returned handle is waited on before
    returning in two cases:

    * the call was asynchronous, i.e. ``async_op`` is truthy after binding the
      arguments to the API's signature (defaults applied);
    * the API is ``isend`` or ``irecv``.

    For ``batch_isend_irecv``, every request in the returned list is waited on.

    Args:
        module: the wrapper object; its ``api_func`` and ``api_name`` attributes
            are read.
        *args, **kwargs: forwarded unchanged to ``module.api_func``.

    Returns:
        Exactly what ``module.api_func`` returned (the handle itself is returned
        even after it has been waited on).
    """
    handle = module.api_func(*args, **kwargs)
    try:
        bound = inspect.signature(module.api_func).bind(*args, **kwargs)
        bound.apply_defaults()
        # torch.distributed APIs name this keyword ``async_op``. Looking up a
        # misspelled key would always yield the default, so asynchronous
        # handles would never be waited on.
        use_async_op_flag = bound.arguments.get("async_op", False)
    except Exception as e:
        # Signature introspection is best-effort: fall back to not waiting
        # instead of failing the user's distributed call.
        use_async_op_flag = False
        logger.warning(f"fail to get dist api's func signature because {e}, no wait")

    if use_async_op_flag or module.api_name in ["isend", "irecv"]:
        if handle and hasattr(handle, 'wait'):
            handle.wait()
    if module.api_name == "batch_isend_irecv":
        if isinstance(handle, list):
            for req in handle:
                req.wait()
    return handle
105
+
106
+
107
def npu_module_forward(module, *args, **kwargs):
    """Execute an NPU API, optionally redirecting it to a bench implementation.

    Behaviour depends on ``module.need_hook`` and ``module.device``:

    * ``need_hook`` true: the original ``module.api_func`` is called.
    * ``need_hook`` false: ``module.api_name`` must be a key of
      ``npu_custom_functions``, otherwise an ``Exception`` is raised.
      - On a CUDA device, ``module.api_name`` is first rewritten *in place*
        through ``_cuda_func_mapping``.
      - On CUDA or CPU devices, the ``npu_custom_functions`` entry for that
        name is called.
      - On any other device, the original ``module.api_func`` is still used.

    Args:
        module: the wrapper object (reads ``need_hook``, ``api_name``,
            ``device``, ``api_func``; may write ``api_name``).
        *args, **kwargs: forwarded unchanged to the selected callable.
    """
    if module.need_hook:
        return module.api_func(*args, **kwargs)

    if module.api_name not in npu_custom_functions:
        raise Exception(f'There is not bench function {module.api_name}')

    on_cuda = module.device == Const.CUDA_LOWERCASE
    if on_cuda:
        module.api_name = _cuda_func_mapping.get(module.api_name, module.api_name)

    if module.device not in (Const.CUDA_LOWERCASE, Const.CPU_LOWERCASE):
        return module.api_func(*args, **kwargs)

    bench_func = npu_custom_functions[module.api_name]
    return bench_func(*args, **kwargs)
116
+
117
+
118
# Maps an ApiTemplate ``prefix`` to the executor that ApiTemplate.forward calls
# with the wrapper instance as its first argument. Prefixes that are not listed
# here fall back to calling ``api_func`` directly.
# NOTE(review): presumably these literals equal the corresponding
# Const.*_API_TYPE_PREFIX values (only Const.DIST_API_TYPE_PREFIX is referenced
# in this file); confirm.
forward_methods = {
    "Tensor": tensor_module_forward,
    "Distributed": dist_module_forward,
    "NPU": npu_module_forward
}
123
+
124
+
125
class ApiTemplate(HOOKModule):
    """Hookable wrapper around one framework API.

    Args:
        api_name: full name of the wrapped API. Only the segment after the last
            ``Const.SEP`` is used to build ``prefix_api_name``.
        api_func: the original callable that is ultimately executed.
        prefix: API category. It selects the executor in ``forward_methods``;
            ``Const.DIST_API_TYPE_PREFIX`` also marks the op as distributed.
        hook_build_func: forwarded to ``HOOKModule.__init__`` when hooking is on.
        need_hook: when false, ``HOOKModule.__init__`` is not called at all.
            NOTE(review): nn.Module state is then uninitialised, so such
            instances presumably are only used through ``forward`` directly;
            confirm with callers.
        device: device string read by the NPU executor
            (default ``Const.CPU_LOWERCASE``).
    """

    def __init__(self, api_name, api_func, prefix, hook_build_func, need_hook=True, device=Const.CPU_LOWERCASE):
        # Plain attributes are assigned before HOOKModule.__init__ runs,
        # presumably because the base initialiser reads some of them; keep
        # this order.
        self.api_name = api_name
        self.api_func = api_func
        self.prefix = prefix
        short_name = str(api_name.split(Const.SEP)[-1])
        self.prefix_api_name = Const.SEP.join((prefix, short_name)) + Const.SEP
        self.need_hook = need_hook
        self.device = device
        if need_hook:
            super().__init__(hook_build_func)
        # Set after the base initialiser so it is not overwritten there.
        if prefix == Const.DIST_API_TYPE_PREFIX:
            self.op_is_distributed = True

    @torch_device_guard
    def forward(self, *args, **kwargs):
        """Run the wrapped API through its category executor, if one exists."""
        executor = forward_methods.get(self.prefix)
        if not executor:
            return self.api_func(*args, **kwargs)
        return executor(self, *args, **kwargs)
143
+
144
+
145
# Shared registry instance, created lazily by get_api_register().
api_register = None


def get_api_register(return_new=False):
    """Return the PyTorch ``ApiRegistry``.

    Args:
        return_new: when true, build and return a fresh ``ApiRegistry`` without
            touching the shared instance.

    Returns:
        A new registry if ``return_new`` is true; otherwise the module-level
        registry, which is created on the first call and reused afterwards.
    """
    global api_register

    def _build_registry():
        return ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)

    if return_new:
        return _build_registry()
    if api_register is None:
        api_register = _build_registry()
    return api_register