mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -1,207 +0,0 @@
1
- # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from mindspore import Tensor, ops, mint
17
- from mindspore.mint.nn import functional
18
- from mindspore.common._stub_tensor import StubTensor
19
- from mindspore.communication import comm_func
20
-
21
- from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP,
22
- HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP,
23
- HOOKTorchOP, HOOKTorchTensor, HOOKTorchFunctionalOP,
24
- HOOKTorchDistributedOP, HOOKTorchNpuOP,
25
- get_wrap_api_list, get_wrap_torch_api_list, setup_hooks)
26
- from msprobe.core.common.utils import Const
27
- from msprobe.mindspore.common.utils import is_mindtorch
28
-
29
- if is_mindtorch():
30
- import torch
31
- import torch_npu
32
-
33
-
34
- def stub_method(method):
35
- def wrapped_method(*args, **kwargs):
36
- return method(*args, **kwargs)
37
- return wrapped_method
38
-
39
-
40
- class ApiRegistry:
41
- def __init__(self):
42
- self.tensor_ori_attr = {}
43
- self.stub_tensor_ori_attr = {}
44
- self.functional_ori_attr = {}
45
- self.mint_ops_ori_attr = {}
46
- self.mint_func_ops_ori_attr = {}
47
- self.distributed_ori_attr = {}
48
- self.norm_inner_ops_ori_attr = {}
49
-
50
- self.torch_ori_attr = {}
51
- self.torch_tensor_ori_attr = {}
52
- self.torch_functional_ori_attr = {}
53
- self.torch_distributed_ori_attr = {}
54
- self.torch_npu_ori_attr = {}
55
-
56
- self.tensor_hook_attr = {}
57
- self.stub_tensor_hook_attr = {}
58
- self.functional_hook_attr = {}
59
- self.mint_ops_hook_attr = {}
60
- self.mint_func_ops_hook_attr = {}
61
- self.distibuted_hook_attr = {}
62
- self.norm_inner_ops_hook_attr = {}
63
-
64
- self.torch_hook_attr = {}
65
- self.torch_tensor_hook_attr = {}
66
- self.torch_functional_hook_attr = {}
67
- self.torch_distributed_hook_attr = {}
68
- self.torch_npu_hook_attr = {}
69
-
70
- self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
71
-
72
- @staticmethod
73
- def store_ori_attr(ori_api_group, api_list, api_ori_attr):
74
- for api in api_list:
75
- if Const.SEP in api:
76
- sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
77
- sub_module = getattr(ori_api_group, sub_module_name)
78
- ori_api_func = getattr(sub_module, sub_op)
79
- else:
80
- ori_api_func = getattr(ori_api_group, api)
81
- if ori_api_group == StubTensor:
82
- api_ori_attr[api] = stub_method(ori_api_func)
83
- continue
84
- api_ori_attr[api] = ori_api_func
85
-
86
- @staticmethod
87
- def set_api_attr(api_group, attr_dict):
88
- for api, api_attr in attr_dict.items():
89
- if Const.SEP in api:
90
- sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
91
- sub_module = getattr(api_group, sub_module_name, None)
92
- if sub_module is not None:
93
- setattr(sub_module, sub_op, api_attr)
94
- else:
95
- setattr(api_group, api, api_attr)
96
-
97
- def norm_inner_op_set_hook_func(self):
98
- self.set_api_attr(ops, self.norm_inner_ops_hook_attr)
99
-
100
- def norm_inner_op_set_ori_func(self):
101
- self.set_api_attr(ops, self.norm_inner_ops_ori_attr)
102
-
103
- def api_set_hook_func(self):
104
- if is_mindtorch():
105
- self.set_api_attr(torch, self.torch_hook_attr)
106
- self.set_api_attr(torch.Tensor, self.torch_tensor_hook_attr)
107
- self.set_api_attr(torch.nn.functional, self.torch_functional_hook_attr)
108
- self.set_api_attr(torch.distributed, self.torch_distributed_hook_attr)
109
- self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_hook_attr)
110
- self.set_api_attr(torch_npu, self.torch_npu_hook_attr)
111
- else:
112
- self.set_api_attr(Tensor, self.tensor_hook_attr)
113
- self.set_api_attr(StubTensor, self.stub_tensor_hook_attr)
114
- self.set_api_attr(ops, self.functional_hook_attr)
115
- self.set_api_attr(mint, self.mint_ops_hook_attr)
116
- self.set_api_attr(functional, self.mint_func_ops_hook_attr)
117
- self.set_api_attr(comm_func, self.distibuted_hook_attr)
118
-
119
- def api_set_ori_func(self):
120
- if is_mindtorch():
121
- self.set_api_attr(torch, self.torch_ori_attr)
122
- self.set_api_attr(torch.Tensor, self.torch_tensor_ori_attr)
123
- self.set_api_attr(torch.nn.functional, self.torch_functional_ori_attr)
124
- self.set_api_attr(torch.distributed, self.torch_distributed_ori_attr)
125
- self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_ori_attr)
126
- self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
127
- else:
128
- self.set_api_attr(Tensor, self.tensor_ori_attr)
129
- self.set_api_attr(StubTensor, self.stub_tensor_ori_attr)
130
- self.set_api_attr(ops, self.functional_ori_attr)
131
- self.set_api_attr(mint, self.mint_ops_ori_attr)
132
- self.set_api_attr(functional, self.mint_func_ops_ori_attr)
133
- self.set_api_attr(comm_func, self.distributed_ori_attr)
134
-
135
- def initialize_hook(self, hook):
136
- setup_hooks(hook)
137
- if is_mindtorch():
138
- wrap_torch_api_name = get_wrap_torch_api_list()
139
- self.store_ori_attr(torch,
140
- wrap_torch_api_name.torch_api_names, self.torch_ori_attr)
141
- self.store_ori_attr(torch.Tensor,
142
- wrap_torch_api_name.tensor_api_names, self.torch_tensor_ori_attr)
143
- self.store_ori_attr(torch.nn.functional,
144
- wrap_torch_api_name.functional_api_names, self.torch_functional_ori_attr)
145
- self.store_ori_attr(torch.distributed,
146
- wrap_torch_api_name.distributed_api_names, self.torch_distributed_ori_attr)
147
- self.store_ori_attr(torch_npu,
148
- wrap_torch_api_name.npu_api_names, self.torch_npu_ori_attr)
149
- for attr_name in dir(HOOKTorchOP):
150
- if attr_name.startswith(Const.ATTR_NAME_PREFIX):
151
- api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
152
- self.torch_hook_attr[api_name] = getattr(HOOKTorchOP, attr_name)
153
- for attr_name in dir(HOOKTorchTensor):
154
- if attr_name.startswith(Const.ATTR_NAME_PREFIX):
155
- api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
156
- self.torch_tensor_hook_attr[api_name] = getattr(HOOKTorchTensor, attr_name)
157
- for attr_name in dir(HOOKTorchFunctionalOP):
158
- if attr_name.startswith(Const.ATTR_NAME_PREFIX):
159
- api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
160
- self.torch_functional_hook_attr[api_name] = getattr(HOOKTorchFunctionalOP, attr_name)
161
- for attr_name in dir(HOOKTorchDistributedOP):
162
- if attr_name.startswith(Const.ATTR_NAME_PREFIX):
163
- api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
164
- self.torch_distributed_hook_attr[api_name] = getattr(HOOKTorchDistributedOP, attr_name)
165
- for attr_name in dir(HOOKTorchNpuOP):
166
- if attr_name.startswith(Const.ATTR_NAME_PREFIX):
167
- api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
168
- self.torch_npu_hook_attr[api_name] = getattr(HOOKTorchNpuOP, attr_name)
169
- return
170
-
171
- wrap_api_name = get_wrap_api_list()
172
- self.store_ori_attr(Tensor, wrap_api_name.tensor_api_names, self.tensor_ori_attr)
173
- self.store_ori_attr(StubTensor, wrap_api_name.stub_tensor_api_names, self.stub_tensor_ori_attr)
174
- self.store_ori_attr(ops, wrap_api_name.ops_api_names, self.functional_ori_attr)
175
- self.store_ori_attr(mint, wrap_api_name.mint_api_names, self.mint_ops_ori_attr)
176
- self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr)
177
- self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr)
178
- self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
179
- for attr_name in dir(HOOKTensor):
180
- if attr_name.startswith(Const.ATTR_NAME_PREFIX):
181
- api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
182
- self.tensor_hook_attr[api_name] = getattr(HOOKTensor, attr_name)
183
- for attr_name in dir(HOOKStubTensor):
184
- if attr_name.startswith(Const.ATTR_NAME_PREFIX):
185
- api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
186
- self.stub_tensor_hook_attr[api_name] = getattr(HOOKStubTensor, attr_name)
187
- for attr_name in dir(HOOKFunctionalOP):
188
- if attr_name.startswith(Const.ATTR_NAME_PREFIX):
189
- api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
190
- self.functional_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name)
191
- if api_name in self.norm_inner_ops:
192
- self.norm_inner_ops_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name)
193
- for attr_name in dir(HOOKMintOP):
194
- if attr_name.startswith(Const.ATTR_NAME_PREFIX):
195
- api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
196
- self.mint_ops_hook_attr[api_name] = getattr(HOOKMintOP, attr_name)
197
- for attr_name in dir(HOOKMintNNFunctionalOP):
198
- if attr_name.startswith(Const.ATTR_NAME_PREFIX):
199
- api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
200
- self.mint_func_ops_hook_attr[api_name] = getattr(HOOKMintNNFunctionalOP, attr_name)
201
- for attr_name in dir(HOOKDistributedOP):
202
- if attr_name.startswith(Const.ATTR_NAME_PREFIX):
203
- api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
204
- self.distibuted_hook_attr[api_name] = getattr(HOOKDistributedOP, attr_name)
205
-
206
-
207
- api_register = ApiRegistry()
@@ -1,212 +0,0 @@
1
- # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import os
17
-
18
- from mindspore import Tensor, mint, ops
19
- from mindspore.common._stub_tensor import StubTensor
20
- from mindspore.communication import comm_func
21
- from mindspore.mint.nn import functional
22
-
23
- from msprobe.core.common.const import Const
24
- from msprobe.core.common.file_utils import load_yaml
25
- from msprobe.mindspore.common.const import Const as MsConst
26
- from msprobe.mindspore.common.utils import is_mindtorch
27
- from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
28
-
29
- if is_mindtorch():
30
- import torch
31
- import torch_npu
32
-
33
- cur_path = os.path.dirname(os.path.realpath(__file__))
34
- yaml_path = os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE)
35
- torch_yaml_path = os.path.join(cur_path, "../../../pytorch/hook_module", MsConst.SUPPORTED_API_LIST_FILE)
36
-
37
-
38
- class HOOKTensor(object):
39
- pass
40
-
41
-
42
- class HOOKStubTensor(object):
43
- pass
44
-
45
-
46
- class HOOKFunctionalOP(object):
47
- pass
48
-
49
-
50
- class HOOKMintOP(object):
51
- pass
52
-
53
-
54
- class HOOKMintNNFunctionalOP(object):
55
- pass
56
-
57
-
58
- class HOOKDistributedOP(object):
59
- pass
60
-
61
-
62
- class HOOKTorchOP(object):
63
- pass
64
-
65
-
66
- class HOOKTorchTensor(object):
67
- pass
68
-
69
-
70
- class HOOKTorchFunctionalOP(object):
71
- pass
72
-
73
-
74
- class HOOKTorchDistributedOP(object):
75
- pass
76
-
77
-
78
- class HOOKTorchNpuOP(object):
79
- pass
80
-
81
-
82
- class ApiTemplate(HOOKCell):
83
- def __init__(self, api_name, api_dict, prefix, hook):
84
- self.api_name = api_name
85
- self.api_func = api_dict[api_name]
86
- self.prefix_api_name = prefix + str(api_name.split(Const.SEP)[-1]) + Const.SEP
87
- super().__init__(hook)
88
-
89
- @staticmethod
90
- def async_to_sync(output):
91
- # Fake handle, used to return after the CommHandle executes the wait method
92
- fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()
93
- if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"):
94
- output[1].wait()
95
- output = (output[0], fake_handle)
96
- elif hasattr(output, "wait"):
97
- output.wait()
98
- output = fake_handle
99
- return output
100
-
101
- def construct(self, *args, **kwargs):
102
- if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
103
- return args[0] if args else kwargs.get(Const.INPUT)
104
-
105
- output = self.api_func(*args, **kwargs)
106
-
107
- if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX):
108
- if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]:
109
- output = self.async_to_sync(output)
110
- return output
111
-
112
- def forward(self, *args, **kwargs):
113
- if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
114
- return args[0] if args else kwargs.get(Const.INPUT)
115
- return self.api_func(*args, **kwargs)
116
-
117
-
118
- class WrapApiName:
119
- def __init__(self, tensor_api_names, stub_tensor_api_names, ops_api_names, mint_api_names, mint_nn_func_api_names,
120
- distributed_api_names):
121
- self.tensor_api_names = tensor_api_names
122
- self.stub_tensor_api_names = stub_tensor_api_names
123
- self.ops_api_names = ops_api_names
124
- self.mint_api_names = mint_api_names
125
- self.mint_nn_func_api_names = mint_nn_func_api_names
126
- self.distributed_api_names = distributed_api_names
127
-
128
-
129
- class WrapTorchApiName:
130
- def __init__(self, torch_api_names, tensor_api_names, functional_api_names, distributed_api_names, npu_api_names):
131
- self.torch_api_names = torch_api_names
132
- self.tensor_api_names = tensor_api_names
133
- self.functional_api_names = functional_api_names
134
- self.distributed_api_names = distributed_api_names
135
- self.npu_api_names = npu_api_names
136
-
137
-
138
- def get_wrap_api_list():
139
- api_list = load_yaml(yaml_path)
140
- tensor_api = api_list.get(MsConst.SUPPORTED_TENSOR_LIST_KEY)
141
- ops_api = api_list.get(MsConst.SUPPORTED_OPS_LIST_KEY)
142
- mint_api = api_list.get(MsConst.SUPPORTED_MINT_LIST_KEY)
143
- mint_nn_func_api = api_list.get(MsConst.SUPPORTED__MINT_NN_FUNC_LIST_KEY)
144
- distributed_api = api_list.get(MsConst.SUPPORTED_COMM_LIST_KEY)
145
- wrap_api_name = WrapApiName(set(tensor_api) & set(dir(Tensor)),
146
- set(tensor_api) & set(dir(StubTensor)),
147
- set(ops_api) & set(dir(ops)),
148
- set(mint_api) & set(dir(mint)),
149
- set(mint_nn_func_api) & set(dir(functional)),
150
- set(distributed_api) & set(dir(comm_func)))
151
- return wrap_api_name
152
-
153
-
154
- def get_wrap_torch_api_list():
155
- api_list = load_yaml(torch_yaml_path)
156
- torch_api = api_list.get("torch")
157
- tensor_api = api_list.get("tensor")
158
- functional_api = api_list.get("functional")
159
- distributed_api = api_list.get("distributed")
160
- npu_api = api_list.get("torch_npu")
161
- wrap_api_name = WrapTorchApiName(set(torch_api) & set(dir(torch)),
162
- set(tensor_api) & set(dir(torch.Tensor)),
163
- set(functional_api) & set(dir(torch.nn.functional)),
164
- set(distributed_api) & set(dir(torch.distributed)),
165
- set(npu_api) & set(dir(torch_npu)))
166
- return wrap_api_name
167
-
168
-
169
- def wrap_api_func(api_name, api_dict, prefix, hook):
170
- def api_function(*args, **kwargs):
171
- return ApiTemplate(api_name, api_dict, prefix, hook)(*args, **kwargs)
172
- return api_function
173
-
174
-
175
- def wrap_api_func_and_bind(api_list, api_dict, prefix, hook, hook_class):
176
- for api_name in api_list:
177
- if callable(api_dict[api_name]):
178
- setattr(hook_class, Const.ATTR_NAME_PREFIX + api_name, wrap_api_func(api_name, api_dict, prefix, hook))
179
-
180
-
181
- def setup_hooks(hook):
182
- if is_mindtorch():
183
- torch_wrap_api_name = get_wrap_torch_api_list()
184
- wrap_api_func_and_bind(torch_wrap_api_name.torch_api_names,
185
- {f: getattr(torch, f) for f in dir(torch)},
186
- MsConst.TORCH_DATA_PREFIX, hook, HOOKTorchOP)
187
- wrap_api_func_and_bind(torch_wrap_api_name.tensor_api_names,
188
- {f: getattr(torch.Tensor, f) for f in dir(torch.Tensor)},
189
- MsConst.TENSOR_DATA_PREFIX, hook, HOOKTorchTensor)
190
- wrap_api_func_and_bind(torch_wrap_api_name.functional_api_names,
191
- {f: getattr(torch.nn.functional, f) for f in dir(torch.nn.functional)},
192
- MsConst.OPS_DATA_PREFIX, hook, HOOKTorchFunctionalOP)
193
- wrap_api_func_and_bind(torch_wrap_api_name.distributed_api_names,
194
- {f: getattr(torch.distributed, f) for f in dir(torch.distributed)},
195
- MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKTorchDistributedOP)
196
- wrap_api_func_and_bind(torch_wrap_api_name.npu_api_names, {f: getattr(torch_npu, f) for f in dir(torch_npu)},
197
- MsConst.TORCH_NPU_DATA_PREFIX, hook, HOOKTorchNpuOP)
198
- return
199
-
200
- wrap_api_name = get_wrap_api_list()
201
- wrap_api_func_and_bind(wrap_api_name.tensor_api_names, {f: getattr(Tensor, f) for f in dir(Tensor)},
202
- MsConst.TENSOR_DATA_PREFIX, hook, HOOKTensor)
203
- wrap_api_func_and_bind(wrap_api_name.stub_tensor_api_names, {f: getattr(StubTensor, f) for f in dir(StubTensor)},
204
- MsConst.STUB_TENSOR_DATA_PREFIX, hook, HOOKStubTensor)
205
- wrap_api_func_and_bind(wrap_api_name.ops_api_names, {f: getattr(ops, f) for f in dir(ops)},
206
- MsConst.OPS_DATA_PREFIX, hook, HOOKFunctionalOP)
207
- wrap_api_func_and_bind(wrap_api_name.mint_api_names, {f: getattr(mint, f) for f in dir(mint)},
208
- MsConst.MINT_DATA_PREFIX, hook, HOOKMintOP)
209
- wrap_api_func_and_bind(wrap_api_name.mint_nn_func_api_names, {f: getattr(functional, f) for f in dir(functional)},
210
- MsConst.MINT_NN_FUNC_DATA_PREFIX, hook, HOOKMintNNFunctionalOP)
211
- wrap_api_func_and_bind(wrap_api_name.distributed_api_names, {f: getattr(comm_func, f) for f in dir(comm_func)},
212
- MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKDistributedOP)
@@ -1,140 +0,0 @@
1
- /**
2
- * Copyright 2024 Huawei Technologies Co., Ltd
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #include "hook_dynamic_loader.h"
18
- #include <sys/stat.h>
19
- #include <cstdlib>
20
- #include <cstring>
21
- #include "utils/log_adapter.h"
22
-
23
- namespace {
24
-
25
- // Utility function to check if a file path is valid
26
- bool IsValidPath(const std::string &path) {
27
- struct stat fileStat;
28
- if (stat(path.c_str(), &fileStat) != 0) {
29
- MS_LOG(ERROR) << "File does not exist or cannot be accessed: " << path;
30
- return false;
31
- }
32
-
33
- if (S_ISLNK(fileStat.st_mode)) {
34
- MS_LOG(ERROR) << "File is a symbolic link, which is not allowed: " << path;
35
- return false;
36
- }
37
-
38
- if (!S_ISREG(fileStat.st_mode)) {
39
- MS_LOG(ERROR) << "File is not a regular file: " << path;
40
- return false;
41
- }
42
-
43
- if (path.substr(path.find_last_of(".")) != ".so") {
44
- MS_LOG(ERROR) << "File is not a .so file: " << path;
45
- return false;
46
- }
47
-
48
- return true;
49
- }
50
-
51
- } // namespace
52
-
53
- HookDynamicLoader &HookDynamicLoader::GetInstance() {
54
- static HookDynamicLoader instance;
55
- return instance;
56
- }
57
-
58
- bool HookDynamicLoader::loadFunction(void *handle, const std::string &functionName) {
59
- void *func = dlsym(handle, functionName.c_str());
60
- if (!func) {
61
- MS_LOG(WARNING) << "Could not load function: " << functionName << ", error: " << dlerror();
62
- return false;
63
- }
64
- funcMap_[functionName] = func;
65
- return true;
66
- }
67
-
68
- bool HookDynamicLoader::validateLibraryPath(const std::string &libPath) {
69
- char *realPath = realpath(libPath.c_str(), nullptr);
70
- if (!realPath) {
71
- MS_LOG(WARNING) << "Failed to resolve realpath for the library: " << libPath;
72
- return false;
73
- }
74
-
75
- bool isValid = IsValidPath(realPath);
76
- free(realPath); // Free memory allocated by realpath
77
- return isValid;
78
- }
79
-
80
- bool HookDynamicLoader::LoadLibrary() {
81
- const char *libPath = std::getenv("HOOK_TOOL_PATH");
82
- if (!libPath) {
83
- MS_LOG(WARNING) << "HOOK_TOOL_PATH is not set!";
84
- return false;
85
- }
86
-
87
- std::string resolvedLibPath(libPath);
88
- if (!validateLibraryPath(resolvedLibPath)) {
89
- MS_LOG(WARNING) << "Library path validation failed.";
90
- return false;
91
- }
92
-
93
- std::lock_guard<std::mutex> lock(mutex_);
94
- if (handle_) {
95
- MS_LOG(WARNING) << "Hook library already loaded!";
96
- return false;
97
- }
98
-
99
- handle_ = dlopen(resolvedLibPath.c_str(), RTLD_LAZY | RTLD_LOCAL);
100
- if (!handle_) {
101
- MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror();
102
- return false;
103
- }
104
-
105
- for (const auto &functionName : functionList_) {
106
- if (!loadFunction(handle_, functionName)) {
107
- MS_LOG(WARNING) << "Failed to load function: " << functionName;
108
- dlclose(handle_);
109
- handle_ = nullptr;
110
- return false;
111
- }
112
- }
113
-
114
- MS_LOG(INFO) << "Hook library loaded successfully.";
115
- return true;
116
- }
117
-
118
- bool HookDynamicLoader::UnloadLibrary() {
119
- std::lock_guard<std::mutex> lock(mutex_);
120
- if (!handle_) {
121
- MS_LOG(WARNING) << "Hook library hasn't been loaded.";
122
- return false;
123
- }
124
-
125
- dlclose(handle_);
126
- handle_ = nullptr;
127
- funcMap_.clear();
128
- MS_LOG(INFO) << "Library unloaded successfully.";
129
- return true;
130
- }
131
-
132
- void *HookDynamicLoader::GetHooker(const std::string &funcName) {
133
- std::lock_guard<std::mutex> lock(mutex_);
134
- auto iter = funcMap_.find(funcName);
135
- if (iter == funcMap_.end()) {
136
- MS_LOG(WARNING) << "Function not found: " << funcName;
137
- return nullptr;
138
- }
139
- return iter->second;
140
- }