mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
  3. msprobe/README.md +27 -22
  4. msprobe/core/common/const.py +129 -60
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/inplace_ops.yaml +1 -0
  9. msprobe/core/common/utils.py +43 -33
  10. msprobe/core/compare/acc_compare.py +43 -74
  11. msprobe/core/compare/check.py +2 -6
  12. msprobe/core/compare/highlight.py +2 -0
  13. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  14. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  15. msprobe/core/compare/merge_result/merge_result.py +16 -9
  16. msprobe/core/compare/merge_result/utils.py +81 -0
  17. msprobe/core/compare/multiprocessing_compute.py +19 -12
  18. msprobe/core/compare/npy_compare.py +30 -12
  19. msprobe/core/compare/utils.py +30 -10
  20. msprobe/core/data_dump/api_registry.py +176 -0
  21. msprobe/core/data_dump/data_collector.py +58 -13
  22. msprobe/core/data_dump/data_processor/base.py +94 -10
  23. msprobe/core/data_dump/data_processor/factory.py +3 -0
  24. msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
  25. msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
  26. msprobe/core/data_dump/json_writer.py +61 -40
  27. msprobe/core/grad_probe/constant.py +1 -0
  28. msprobe/core/grad_probe/grad_compare.py +1 -1
  29. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  30. msprobe/docs/01.installation.md +27 -1
  31. msprobe/docs/02.config_introduction.md +27 -23
  32. msprobe/docs/03.config_examples.md +24 -0
  33. msprobe/docs/05.data_dump_PyTorch.md +103 -16
  34. msprobe/docs/06.data_dump_MindSpore.md +76 -32
  35. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  36. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  37. msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
  38. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  39. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  40. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  41. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  42. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  43. msprobe/docs/18.online_dispatch.md +1 -1
  44. msprobe/docs/19.monitor.md +332 -273
  45. msprobe/docs/21.visualization_PyTorch.md +42 -13
  46. msprobe/docs/22.visualization_MindSpore.md +43 -13
  47. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  48. msprobe/docs/27.dump_json_instruction.md +301 -27
  49. msprobe/docs/28.debugger_save_instruction.md +94 -0
  50. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  51. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  52. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  53. msprobe/docs/FAQ.md +3 -11
  54. msprobe/docs/img/compare_result.png +0 -0
  55. msprobe/docs/img/merge_result.png +0 -0
  56. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  57. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  58. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  59. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  60. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  61. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  63. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  64. msprobe/mindspore/__init__.py +4 -2
  65. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
  66. msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
  67. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  68. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  69. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  70. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  71. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  72. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  73. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
  74. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  75. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  76. msprobe/mindspore/common/const.py +61 -0
  77. msprobe/mindspore/common/utils.py +48 -18
  78. msprobe/mindspore/compare/ms_compare.py +27 -19
  79. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  80. msprobe/mindspore/debugger/debugger_config.py +31 -6
  81. msprobe/mindspore/debugger/precision_debugger.py +45 -14
  82. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  83. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  86. msprobe/mindspore/dump/jit_dump.py +21 -15
  87. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  88. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  89. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  90. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  91. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  92. msprobe/mindspore/grad_probe/global_context.py +2 -0
  93. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  94. msprobe/mindspore/grad_probe/hook.py +2 -4
  95. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  96. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  97. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  98. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  99. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  100. msprobe/mindspore/monitor/features.py +63 -0
  101. msprobe/mindspore/monitor/module_hook.py +873 -0
  102. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  103. msprobe/mindspore/monitor/utils.py +309 -0
  104. msprobe/mindspore/ms_config.py +8 -2
  105. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  106. msprobe/mindspore/service.py +114 -34
  107. msprobe/pytorch/__init__.py +0 -1
  108. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  109. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
  110. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  111. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  112. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  116. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  117. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  118. msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
  119. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
  120. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  121. msprobe/pytorch/common/utils.py +97 -4
  122. msprobe/pytorch/debugger/debugger_config.py +19 -9
  123. msprobe/pytorch/debugger/precision_debugger.py +24 -1
  124. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  125. msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
  126. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  127. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  132. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  133. msprobe/pytorch/function_factory.py +8 -2
  134. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  135. msprobe/pytorch/hook_module/api_register.py +131 -0
  136. msprobe/pytorch/hook_module/hook_module.py +19 -14
  137. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  138. msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
  139. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  140. msprobe/pytorch/monitor/csv2tb.py +18 -14
  141. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  142. msprobe/pytorch/monitor/module_hook.py +238 -193
  143. msprobe/pytorch/monitor/module_metric.py +9 -6
  144. msprobe/pytorch/monitor/optimizer_collect.py +100 -67
  145. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  146. msprobe/pytorch/monitor/utils.py +76 -44
  147. msprobe/pytorch/online_dispatch/compare.py +0 -2
  148. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  149. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  150. msprobe/pytorch/online_dispatch/utils.py +3 -0
  151. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  152. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  153. msprobe/pytorch/pt_config.py +30 -29
  154. msprobe/pytorch/service.py +114 -32
  155. msprobe/visualization/builder/graph_builder.py +75 -10
  156. msprobe/visualization/builder/msprobe_adapter.py +7 -6
  157. msprobe/visualization/compare/graph_comparator.py +42 -38
  158. msprobe/visualization/compare/mode_adapter.py +0 -19
  159. msprobe/visualization/graph/base_node.py +11 -3
  160. msprobe/visualization/graph/distributed_analyzer.py +71 -3
  161. msprobe/visualization/graph/graph.py +0 -11
  162. msprobe/visualization/graph/node_op.py +4 -3
  163. msprobe/visualization/graph_service.py +4 -5
  164. msprobe/visualization/utils.py +12 -35
  165. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
  166. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  167. msprobe/pytorch/hook_module/api_registry.py +0 -166
  168. msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
  169. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  171. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  172. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  173. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  174. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  175. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  176. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  177. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/mindspore/dump/hook_cell/api_registry.py (deleted)
@@ -1,205 +0,0 @@
-# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from mindspore import Tensor, ops, mint
-from mindspore.mint.nn import functional
-from mindspore.common._stub_tensor import StubTensor
-from mindspore.communication import comm_func
-
-from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP,
-                                                        HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP,
-                                                        HOOKTorchOP, HOOKTorchTensor, HOOKTorchFunctionalOP,
-                                                        HOOKTorchDistributedOP, HOOKTorchNpuOP,
-                                                        get_wrap_api_list, get_wrap_torch_api_list, setup_hooks)
-from msprobe.core.common.utils import Const
-from msprobe.mindspore.common.utils import is_mindtorch
-
-if is_mindtorch():
-    import torch
-    import torch_npu
-
-
-def stub_method(method):
-    def wrapped_method(*args, **kwargs):
-        return method(*args, **kwargs)
-    return wrapped_method
-
-
-class ApiRegistry:
-    def __init__(self):
-        self.tensor_ori_attr = {}
-        self.stub_tensor_ori_attr = {}
-        self.functional_ori_attr = {}
-        self.mint_ops_ori_attr = {}
-        self.mint_func_ops_ori_attr = {}
-        self.distributed_ori_attr = {}
-        self.norm_inner_ops_ori_attr = {}
-
-        self.torch_ori_attr = {}
-        self.torch_tensor_ori_attr = {}
-        self.torch_functional_ori_attr = {}
-        self.torch_distributed_ori_attr = {}
-        self.torch_npu_ori_attr = {}
-
-        self.tensor_hook_attr = {}
-        self.stub_tensor_hook_attr = {}
-        self.functional_hook_attr = {}
-        self.mint_ops_hook_attr = {}
-        self.mint_func_ops_hook_attr = {}
-        self.distibuted_hook_attr = {}
-        self.norm_inner_ops_hook_attr = {}
-
-        self.torch_hook_attr = {}
-        self.torch_tensor_hook_attr = {}
-        self.torch_functional_hook_attr = {}
-        self.torch_distributed_hook_attr = {}
-        self.torch_npu_hook_attr = {}
-
-        self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
-
-    @staticmethod
-    def store_ori_attr(ori_api_group, api_list, api_ori_attr):
-        for api in api_list:
-            if Const.SEP in api:
-                sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
-                sub_module = getattr(ori_api_group, sub_module_name)
-                ori_api_func = getattr(sub_module, sub_op)
-            else:
-                ori_api_func = getattr(ori_api_group, api)
-            if ori_api_group == StubTensor:
-                api_ori_attr[api] = stub_method(ori_api_func)
-                continue
-            api_ori_attr[api] = ori_api_func
-
-    @staticmethod
-    def set_api_attr(api_group, attr_dict):
-        for api, api_attr in attr_dict.items():
-            if Const.SEP in api:
-                sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
-                sub_module = getattr(api_group, sub_module_name, None)
-                if sub_module is not None:
-                    setattr(sub_module, sub_op, api_attr)
-            else:
-                setattr(api_group, api, api_attr)
-
-    def norm_inner_op_set_hook_func(self):
-        self.set_api_attr(ops, self.norm_inner_ops_hook_attr)
-
-    def norm_inner_op_set_ori_func(self):
-        self.set_api_attr(ops, self.norm_inner_ops_ori_attr)
-
-    def api_set_hook_func(self):
-        if is_mindtorch():
-            self.set_api_attr(torch, self.torch_hook_attr)
-            self.set_api_attr(torch.Tensor, self.torch_tensor_hook_attr)
-            self.set_api_attr(torch.nn.functional, self.torch_functional_hook_attr)
-            self.set_api_attr(torch.distributed, self.torch_distributed_hook_attr)
-            self.set_api_attr(torch_npu, self.torch_npu_hook_attr)
-        else:
-            self.set_api_attr(Tensor, self.tensor_hook_attr)
-            self.set_api_attr(StubTensor, self.stub_tensor_hook_attr)
-            self.set_api_attr(ops, self.functional_hook_attr)
-            self.set_api_attr(mint, self.mint_ops_hook_attr)
-            self.set_api_attr(functional, self.mint_func_ops_hook_attr)
-            self.set_api_attr(comm_func, self.distibuted_hook_attr)
-
-    def api_set_ori_func(self):
-        if is_mindtorch():
-            self.set_api_attr(torch, self.torch_ori_attr)
-            self.set_api_attr(torch.Tensor, self.torch_tensor_ori_attr)
-            self.set_api_attr(torch.nn.functional, self.torch_functional_ori_attr)
-            self.set_api_attr(torch.distributed, self.torch_distributed_ori_attr)
-            self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
-        else:
-            self.set_api_attr(Tensor, self.tensor_ori_attr)
-            self.set_api_attr(StubTensor, self.stub_tensor_ori_attr)
-            self.set_api_attr(ops, self.functional_ori_attr)
-            self.set_api_attr(mint, self.mint_ops_ori_attr)
-            self.set_api_attr(functional, self.mint_func_ops_ori_attr)
-            self.set_api_attr(comm_func, self.distributed_ori_attr)
-
-    def initialize_hook(self, hook):
-        setup_hooks(hook)
-        if is_mindtorch():
-            wrap_torch_api_name = get_wrap_torch_api_list()
-            self.store_ori_attr(torch,
-                                wrap_torch_api_name.torch_api_names, self.torch_ori_attr)
-            self.store_ori_attr(torch.Tensor,
-                                wrap_torch_api_name.tensor_api_names, self.torch_tensor_ori_attr)
-            self.store_ori_attr(torch.nn.functional,
-                                wrap_torch_api_name.functional_api_names, self.torch_functional_ori_attr)
-            self.store_ori_attr(torch.distributed,
-                                wrap_torch_api_name.distributed_api_names, self.torch_distributed_ori_attr)
-            self.store_ori_attr(torch_npu,
-                                wrap_torch_api_name.npu_api_names, self.torch_npu_ori_attr)
-            for attr_name in dir(HOOKTorchOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_hook_attr[api_name] = getattr(HOOKTorchOP, attr_name)
-            for attr_name in dir(HOOKTorchTensor):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_tensor_hook_attr[api_name] = getattr(HOOKTorchTensor, attr_name)
-            for attr_name in dir(HOOKTorchFunctionalOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_functional_hook_attr[api_name] = getattr(HOOKTorchFunctionalOP, attr_name)
-            for attr_name in dir(HOOKTorchDistributedOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_distributed_hook_attr[api_name] = getattr(HOOKTorchDistributedOP, attr_name)
-            for attr_name in dir(HOOKTorchNpuOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_npu_hook_attr[api_name] = getattr(HOOKTorchNpuOP, attr_name)
-            return
-
-        wrap_api_name = get_wrap_api_list()
-        self.store_ori_attr(Tensor, wrap_api_name.tensor_api_names, self.tensor_ori_attr)
-        self.store_ori_attr(StubTensor, wrap_api_name.stub_tensor_api_names, self.stub_tensor_ori_attr)
-        self.store_ori_attr(ops, wrap_api_name.ops_api_names, self.functional_ori_attr)
-        self.store_ori_attr(mint, wrap_api_name.mint_api_names, self.mint_ops_ori_attr)
-        self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr)
-        self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr)
-        self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
-        for attr_name in dir(HOOKTensor):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.tensor_hook_attr[api_name] = getattr(HOOKTensor, attr_name)
-        for attr_name in dir(HOOKStubTensor):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.stub_tensor_hook_attr[api_name] = getattr(HOOKStubTensor, attr_name)
-        for attr_name in dir(HOOKFunctionalOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.functional_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name)
-                if api_name in self.norm_inner_ops:
-                    self.norm_inner_ops_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name)
-        for attr_name in dir(HOOKMintOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.mint_ops_hook_attr[api_name] = getattr(HOOKMintOP, attr_name)
-        for attr_name in dir(HOOKMintNNFunctionalOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.mint_func_ops_hook_attr[api_name] = getattr(HOOKMintNNFunctionalOP, attr_name)
-        for attr_name in dir(HOOKDistributedOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.distibuted_hook_attr[api_name] = getattr(HOOKDistributedOP, attr_name)
-
-
-api_register = ApiRegistry()
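The deleted registry above implements a save/swap/restore pattern: the original callables of each API namespace are cached in one dict, hooked replacements in another, and the two are swapped to enable or disable data dumping. A minimal, self-contained sketch of that pattern follows (hypothetical names, not code from the package):

```python
import math


class MiniRegistry:
    """Cache original callables, swap in wrapped ones, and restore on demand."""

    def __init__(self, namespace, api_names):
        self.namespace = namespace
        self.ori_attr = {}    # original callables, keyed by API name
        self.hook_attr = {}   # wrapped callables, keyed by API name
        for name in api_names:
            self.ori_attr[name] = getattr(namespace, name)
            self.hook_attr[name] = self._wrap(name, self.ori_attr[name])

    @staticmethod
    def _wrap(name, func):
        def wrapped(*args, **kwargs):
            print(f"calling {name}")  # a dump/monitor hook would run here
            return func(*args, **kwargs)
        return wrapped

    def set_hook(self):
        for name, attr in self.hook_attr.items():
            setattr(self.namespace, name, attr)

    def set_ori(self):
        for name, attr in self.ori_attr.items():
            setattr(self.namespace, name, attr)


registry = MiniRegistry(math, ["sqrt"])
registry.set_hook()
math.sqrt(4.0)      # prints "calling sqrt", then returns 2.0
registry.set_ori()  # math.sqrt is the original function again
```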
msprobe/mindspore/dump/hook_cell/wrap_api.py (deleted)
@@ -1,212 +0,0 @@
-# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-
-from mindspore import Tensor, mint, ops
-from mindspore.common._stub_tensor import StubTensor
-from mindspore.communication import comm_func
-from mindspore.mint.nn import functional
-
-from msprobe.core.common.const import Const
-from msprobe.core.common.file_utils import load_yaml
-from msprobe.mindspore.common.const import Const as MsConst
-from msprobe.mindspore.common.utils import is_mindtorch
-from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
-
-if is_mindtorch():
-    import torch
-    import torch_npu
-
-cur_path = os.path.dirname(os.path.realpath(__file__))
-yaml_path = os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE)
-torch_yaml_path = os.path.join(cur_path, "../../../pytorch/hook_module", MsConst.SUPPORTED_API_LIST_FILE)
-
-
-class HOOKTensor(object):
-    pass
-
-
-class HOOKStubTensor(object):
-    pass
-
-
-class HOOKFunctionalOP(object):
-    pass
-
-
-class HOOKMintOP(object):
-    pass
-
-
-class HOOKMintNNFunctionalOP(object):
-    pass
-
-
-class HOOKDistributedOP(object):
-    pass
-
-
-class HOOKTorchOP(object):
-    pass
-
-
-class HOOKTorchTensor(object):
-    pass
-
-
-class HOOKTorchFunctionalOP(object):
-    pass
-
-
-class HOOKTorchDistributedOP(object):
-    pass
-
-
-class HOOKTorchNpuOP(object):
-    pass
-
-
-class ApiTemplate(HOOKCell):
-    def __init__(self, api_name, api_dict, prefix, hook):
-        self.api_name = api_name
-        self.api_func = api_dict[api_name]
-        self.prefix_api_name = prefix + str(api_name.split(Const.SEP)[-1]) + Const.SEP
-        super().__init__(hook)
-
-    @staticmethod
-    def async_to_sync(output):
-        # Fake handle, used to return after the CommHandle executes the wait method
-        fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()
-        if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"):
-            output[1].wait()
-            output = (output[0], fake_handle)
-        elif hasattr(output, "wait"):
-            output.wait()
-            output = fake_handle
-        return output
-
-    def construct(self, *args, **kwargs):
-        if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
-            return args[0] if args else kwargs.get(Const.INPUT)
-
-        output = self.api_func(*args, **kwargs)
-
-        if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX):
-            if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]:
-                output = self.async_to_sync(output)
-        return output
-
-    def forward(self, *args, **kwargs):
-        if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
-            return args[0] if args else kwargs.get(Const.INPUT)
-        return self.api_func(*args, **kwargs)
-
-
-class WrapApiName:
-    def __init__(self, tensor_api_names, stub_tensor_api_names, ops_api_names, mint_api_names, mint_nn_func_api_names,
-                 distributed_api_names):
-        self.tensor_api_names = tensor_api_names
-        self.stub_tensor_api_names = stub_tensor_api_names
-        self.ops_api_names = ops_api_names
-        self.mint_api_names = mint_api_names
-        self.mint_nn_func_api_names = mint_nn_func_api_names
-        self.distributed_api_names = distributed_api_names
-
-
-class WrapTorchApiName:
-    def __init__(self, torch_api_names, tensor_api_names, functional_api_names, distributed_api_names, npu_api_names):
-        self.torch_api_names = torch_api_names
-        self.tensor_api_names = tensor_api_names
-        self.functional_api_names = functional_api_names
-        self.distributed_api_names = distributed_api_names
-        self.npu_api_names = npu_api_names
-
-
-def get_wrap_api_list():
-    api_list = load_yaml(yaml_path)
-    tensor_api = api_list.get(MsConst.SUPPORTED_TENSOR_LIST_KEY)
-    ops_api = api_list.get(MsConst.SUPPORTED_OPS_LIST_KEY)
-    mint_api = api_list.get(MsConst.SUPPORTED_MINT_LIST_KEY)
-    mint_nn_func_api = api_list.get(MsConst.SUPPORTED__MINT_NN_FUNC_LIST_KEY)
-    distributed_api = api_list.get(MsConst.SUPPORTED_COMM_LIST_KEY)
-    wrap_api_name = WrapApiName(set(tensor_api) & set(dir(Tensor)),
-                                set(tensor_api) & set(dir(StubTensor)),
-                                set(ops_api) & set(dir(ops)),
-                                set(mint_api) & set(dir(mint)),
-                                set(mint_nn_func_api) & set(dir(functional)),
-                                set(distributed_api) & set(dir(comm_func)))
-    return wrap_api_name
-
-
-def get_wrap_torch_api_list():
-    api_list = load_yaml(torch_yaml_path)
-    torch_api = api_list.get("torch")
-    tensor_api = api_list.get("tensor")
-    functional_api = api_list.get("functional")
-    distributed_api = api_list.get("distributed")
-    npu_api = api_list.get("torch_npu")
-    wrap_api_name = WrapTorchApiName(set(torch_api) & set(dir(torch)),
-                                     set(tensor_api) & set(dir(torch.Tensor)),
-                                     set(functional_api) & set(dir(torch.nn.functional)),
-                                     set(distributed_api) & set(dir(torch.distributed)),
-                                     set(npu_api) & set(dir(torch_npu)))
-    return wrap_api_name
-
-
-def wrap_api_func(api_name, api_dict, prefix, hook):
-    def api_function(*args, **kwargs):
-        return ApiTemplate(api_name, api_dict, prefix, hook)(*args, **kwargs)
-    return api_function
-
-
-def wrap_api_func_and_bind(api_list, api_dict, prefix, hook, hook_class):
-    for api_name in api_list:
-        if callable(api_dict[api_name]):
-            setattr(hook_class, Const.ATTR_NAME_PREFIX + api_name, wrap_api_func(api_name, api_dict, prefix, hook))
-
-
-def setup_hooks(hook):
-    if is_mindtorch():
-        torch_wrap_api_name = get_wrap_torch_api_list()
-        wrap_api_func_and_bind(torch_wrap_api_name.torch_api_names,
-                               {f: getattr(torch, f) for f in dir(torch)},
-                               MsConst.TORCH_DATA_PREFIX, hook, HOOKTorchOP)
-        wrap_api_func_and_bind(torch_wrap_api_name.tensor_api_names,
-                               {f: getattr(torch.Tensor, f) for f in dir(torch.Tensor)},
-                               MsConst.TENSOR_DATA_PREFIX, hook, HOOKTorchTensor)
-        wrap_api_func_and_bind(torch_wrap_api_name.functional_api_names,
-                               {f: getattr(torch.nn.functional, f) for f in dir(torch.nn.functional)},
-                               MsConst.OPS_DATA_PREFIX, hook, HOOKTorchFunctionalOP)
-        wrap_api_func_and_bind(torch_wrap_api_name.distributed_api_names,
-                               {f: getattr(torch.distributed, f) for f in dir(torch.distributed)},
-                               MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKTorchDistributedOP)
-        wrap_api_func_and_bind(torch_wrap_api_name.npu_api_names, {f: getattr(torch_npu, f) for f in dir(torch_npu)},
-                               MsConst.TORCH_NPU_DATA_PREFIX, hook, HOOKTorchNpuOP)
-        return
-
-    wrap_api_name = get_wrap_api_list()
-    wrap_api_func_and_bind(wrap_api_name.tensor_api_names, {f: getattr(Tensor, f) for f in dir(Tensor)},
-                           MsConst.TENSOR_DATA_PREFIX, hook, HOOKTensor)
-    wrap_api_func_and_bind(wrap_api_name.stub_tensor_api_names, {f: getattr(StubTensor, f) for f in dir(StubTensor)},
-                           MsConst.STUB_TENSOR_DATA_PREFIX, hook, HOOKStubTensor)
-    wrap_api_func_and_bind(wrap_api_name.ops_api_names, {f: getattr(ops, f) for f in dir(ops)},
-                           MsConst.OPS_DATA_PREFIX, hook, HOOKFunctionalOP)
-    wrap_api_func_and_bind(wrap_api_name.mint_api_names, {f: getattr(mint, f) for f in dir(mint)},
-                           MsConst.MINT_DATA_PREFIX, hook, HOOKMintOP)
-    wrap_api_func_and_bind(wrap_api_name.mint_nn_func_api_names, {f: getattr(functional, f) for f in dir(functional)},
-                           MsConst.MINT_NN_FUNC_DATA_PREFIX, hook, HOOKMintNNFunctionalOP)
-    wrap_api_func_and_bind(wrap_api_name.distributed_api_names, {f: getattr(comm_func, f) for f in dir(comm_func)},
-                           MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKDistributedOP)
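ApiTemplate.async_to_sync above forces asynchronous communication results to complete before they are recorded, while still returning an object with a wait() method so callers keep working unchanged. A standalone sketch of the same idea (the DummyAsyncHandle class is hypothetical, not part of the package):

```python
class DummyAsyncHandle:
    """Stand-in for a communication handle that completes on wait()."""

    def __init__(self):
        self.done = False

    def wait(self):
        self.done = True


def async_to_sync(output):
    # No-op handle returned to callers that still expect to call wait().
    fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()
    if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"):
        output[1].wait()                  # force the real handle to complete now
        output = (output[0], fake_handle)
    elif hasattr(output, "wait"):
        output.wait()
        output = fake_handle
    return output


result, handle = async_to_sync(([1, 2, 3], DummyAsyncHandle()))
handle.wait()  # safe: the fake handle's wait() does nothing
```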
msprobe/pytorch/hook_module/api_registry.py (deleted)
@@ -1,166 +0,0 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.distributed as dist
-
-from msprobe.pytorch.hook_module import wrap_torch, wrap_functional, wrap_tensor, wrap_vf, wrap_distributed, wrap_aten
-from msprobe.pytorch.hook_module.wrap_aten import get_aten_ops
-from msprobe.pytorch.hook_module.wrap_distributed import get_distributed_ops
-from msprobe.pytorch.hook_module.wrap_functional import get_functional_ops
-from msprobe.pytorch.hook_module.wrap_tensor import get_tensor_ops
-from msprobe.pytorch.hook_module.wrap_torch import get_torch_ops
-from msprobe.pytorch.hook_module.wrap_vf import get_vf_ops
-from msprobe.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu
-from msprobe.core.common.const import Const
-
-torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
-
-if not is_gpu:
-    import torch_npu
-    from . import wrap_npu_custom
-    from .wrap_npu_custom import get_npu_ops
-
-
-class ApiRegistry:
-    def __init__(self):
-        self.tensor_ori_attr = {}
-        self.torch_ori_attr = {}
-        self.functional_ori_attr = {}
-        self.distributed_ori_attr = {}
-        self.npu_distributed_ori_attr = {}
-        self.vf_ori_attr = {}
-        self.aten_ori_attr = {}
-        self.torch_npu_ori_attr = {}
-
-        self.tensor_hook_attr = {}
-        self.torch_hook_attr = {}
-        self.functional_hook_attr = {}
-        self.distributed_hook_attr = {}
-        self.npu_distributed_hook_attr = {}
-        self.vf_hook_attr = {}
-        self.aten_hook_attr = {}
-        self.torch_npu_hook_attr = {}
-
-    @staticmethod
-    def store_ori_attr(ori_api_group, api_list, api_ori_attr):
-        for api in api_list:
-            if '.' in api:
-                sub_module_name, sub_op = api.rsplit('.', 1)
-                sub_module = getattr(ori_api_group, sub_module_name)
-                api_ori_attr[api] = getattr(sub_module, sub_op)
-            else:
-                api_ori_attr[api] = getattr(ori_api_group, api)
-
-    @staticmethod
-    def set_api_attr(api_group, attr_dict):
-        for api, api_attr in attr_dict.items():
-            if '.' in api:
-                sub_module_name, sub_op = api.rsplit('.', 1)
-                sub_module = getattr(api_group, sub_module_name, None)
-                if sub_module is not None:
-                    setattr(sub_module, sub_op, api_attr)
-            else:
-                setattr(api_group, api, api_attr)
-
-    def api_modularity(self):
-        self.set_api_attr(torch.Tensor, self.tensor_hook_attr)
-        self.set_api_attr(torch, self.torch_hook_attr)
-        self.set_api_attr(torch.nn.functional, self.functional_hook_attr)
-        self.set_api_attr(dist, self.distributed_hook_attr)
-        self.set_api_attr(dist.distributed_c10d, self.distributed_hook_attr)
-        if not is_gpu and not torch_without_guard_version:
-            self.set_api_attr(torch_npu.distributed, self.npu_distributed_hook_attr)
-            self.set_api_attr(torch_npu.distributed.distributed_c10d, self.npu_distributed_hook_attr)
-        if torch_version_above_2:
-            self.set_api_attr(torch.ops.aten, self.aten_hook_attr)
-        self.set_api_attr(torch._VF, self.vf_hook_attr)
-        if not is_gpu:
-            self.set_api_attr(torch_npu, self.torch_npu_hook_attr)
-
-    def api_originality(self):
-        self.set_api_attr(torch.Tensor, self.tensor_ori_attr)
-        self.set_api_attr(torch, self.torch_ori_attr)
-        self.set_api_attr(torch.nn.functional, self.functional_ori_attr)
-        self.set_api_attr(dist, self.distributed_ori_attr)
-        self.set_api_attr(dist.distributed_c10d, self.distributed_ori_attr)
-        if not is_gpu and not torch_without_guard_version:
-            self.set_api_attr(torch_npu.distributed, self.npu_distributed_ori_attr)
-            self.set_api_attr(torch_npu.distributed.distributed_c10d, self.npu_distributed_ori_attr)
-        if torch_version_above_2:
-            self.set_api_attr(torch.ops.aten, self.aten_ori_attr)
-        self.set_api_attr(torch._VF, self.vf_ori_attr)
-        if not is_gpu:
-            self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
-
-    def initialize_hook(self, hook, online_run_ut=False):
-        """
-        initialize_hook
-        Args:
-            hook (_type_): initialize_hook
-            online_run_ut (bool): default False, whether online run_ut or not.
-                If online_run_ut is True, the hook will not wrap the aten ops.
-        """
-        self.store_ori_attr(torch.Tensor, get_tensor_ops(), self.tensor_ori_attr)
-        wrap_tensor.wrap_tensor_ops_and_bind(hook)
-        for attr_name in dir(wrap_tensor.HOOKTensor):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                self.tensor_hook_attr[attr_name[5:]] = getattr(wrap_tensor.HOOKTensor, attr_name)
-
-        self.store_ori_attr(torch, get_torch_ops(), self.torch_ori_attr)
-        wrap_torch.wrap_torch_ops_and_bind(hook)
-        for attr_name in dir(wrap_torch.HOOKTorchOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                self.torch_hook_attr[attr_name[5:]] = getattr(wrap_torch.HOOKTorchOP, attr_name)
-
-        self.store_ori_attr(torch.nn.functional, get_functional_ops(), self.functional_ori_attr)
-        wrap_functional.wrap_functional_ops_and_bind(hook)
-        for attr_name in dir(wrap_functional.HOOKFunctionalOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                self.functional_hook_attr[attr_name[5:]] = getattr(wrap_functional.HOOKFunctionalOP, attr_name)
-
-        self.store_ori_attr(dist, get_distributed_ops(), self.distributed_ori_attr)
-        wrap_distributed.wrap_distributed_ops_and_bind(hook)
-        if not is_gpu and not torch_without_guard_version:
-            self.store_ori_attr(torch_npu.distributed, npu_distributed_api, self.npu_distributed_ori_attr)
-        for attr_name in dir(wrap_distributed.HOOKDistributedOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                self.distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP, attr_name)
-                if not is_gpu and not torch_without_guard_version and attr_name[5:] in npu_distributed_api:
-                    self.npu_distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP,
-                                                                            attr_name)
-
-        if torch_version_above_2 and not online_run_ut:
-            self.store_ori_attr(torch.ops.aten, get_aten_ops(), self.aten_ori_attr)
-            wrap_aten.wrap_aten_ops_and_bind(hook)
-            for attr_name in dir(wrap_aten.HOOKAtenOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    self.aten_hook_attr[attr_name[5:]] = getattr(wrap_aten.HOOKAtenOP, attr_name)
-
-        self.store_ori_attr(torch._VF, get_vf_ops(), self.vf_ori_attr)
-        wrap_vf.wrap_vf_ops_and_bind(hook)
-        for attr_name in dir(wrap_vf.HOOKVfOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                self.vf_hook_attr[attr_name[5:]] = getattr(wrap_vf.HOOKVfOP, attr_name)
-
-        if not is_gpu:
-            self.store_ori_attr(torch_npu, get_npu_ops(), self.torch_npu_ori_attr)
-            wrap_npu_custom.wrap_npu_ops_and_bind(hook)
-            for attr_name in dir(wrap_npu_custom.HOOKNpuOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    self.torch_npu_hook_attr[attr_name[5:]] = getattr(wrap_npu_custom.HOOKNpuOP, attr_name)
-
-
-api_register = ApiRegistry()
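Both registries resolve dotted API names by splitting on the last separator and descending one level into the submodule before reading or writing the attribute (store_ori_attr / set_api_attr above). A small illustration of that lookup, using the standard library instead of torch (the helper names here are hypothetical):

```python
import os


def get_dotted(group, api):
    # Mirror of store_ori_attr: descend one level when the name contains a dot.
    if "." in api:
        sub_module_name, sub_op = api.rsplit(".", 1)
        sub_module = getattr(group, sub_module_name)
        return getattr(sub_module, sub_op)
    return getattr(group, api)


def set_dotted(group, api, value):
    # Mirror of set_api_attr: only write if the submodule actually exists.
    if "." in api:
        sub_module_name, sub_op = api.rsplit(".", 1)
        sub_module = getattr(group, sub_module_name, None)
        if sub_module is not None:
            setattr(sub_module, sub_op, value)
    else:
        setattr(group, api, value)


original = get_dotted(os, "path.join")  # resolves os.path.join
set_dotted(os, "path.join", original)   # writes it back unchanged
```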
msprobe/pytorch/hook_module/wrap_distributed.py (deleted)
@@ -1,75 +0,0 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from functools import wraps
-import torch.distributed as dist
-
-from msprobe.pytorch.hook_module.hook_module import HOOKModule
-from msprobe.pytorch.common.utils import torch_device_guard
-from msprobe.core.common.const import Const
-from msprobe.core.common.file_utils import load_yaml
-
-
-cur_path = os.path.dirname(os.path.realpath(__file__))
-yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
-
-
-distributed_func = {}
-for f in dir(dist):
-    distributed_func[f] = getattr(dist, f)
-
-
-def get_distributed_ops():
-    _all_distributed_ops = dir(dist)
-    yaml_data = load_yaml(yaml_path)
-    wrap_distributed_ops = yaml_data.get('distributed')
-    return set(wrap_distributed_ops) & set(_all_distributed_ops)
-
-
-class HOOKDistributedOP(object):
-    pass
-
-
-class DistributedOPTemplate(HOOKModule):
-    def __init__(self, op_name, build_hook):
-        self.op_name_ = op_name
-        self.prefix_op_name_ = "Distributed" + Const.SEP + str(op_name) + Const.SEP
-        super().__init__(build_hook)
-        if not self.stop_hook:
-            self.op_is_distributed = True
-
-    @torch_device_guard
-    def forward(self, *args, **kwargs):
-        handle = distributed_func.get(self.op_name_)(*args, **kwargs)
-        if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]:
-            if handle and hasattr(handle, 'wait'):
-                handle.wait()
-        return handle
-
-
-def wrap_distributed_op(op_name, hook):
-    @wraps(DistributedOPTemplate)
-    def distributed_op_template(*args, **kwargs):
-        return DistributedOPTemplate(op_name, hook)(*args, **kwargs)
-
-    distributed_op_template.__name__ = op_name
-    return distributed_op_template
-
-
-def wrap_distributed_ops_and_bind(hook):
-    _distributed_ops = get_distributed_ops()
-    for op_name in _distributed_ops:
-        setattr(HOOKDistributedOP, "wrap_" + str(op_name), wrap_distributed_op(op_name, hook))