mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
  3. msprobe/README.md +25 -20
  4. msprobe/core/common/const.py +110 -66
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/utils.py +30 -34
  9. msprobe/core/compare/acc_compare.py +43 -74
  10. msprobe/core/compare/check.py +2 -6
  11. msprobe/core/compare/highlight.py +2 -0
  12. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  13. msprobe/core/compare/merge_result/merge_result.py +8 -2
  14. msprobe/core/compare/multiprocessing_compute.py +19 -12
  15. msprobe/core/compare/npy_compare.py +30 -12
  16. msprobe/core/compare/utils.py +20 -10
  17. msprobe/core/data_dump/api_registry.py +176 -0
  18. msprobe/core/data_dump/data_processor/base.py +2 -2
  19. msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
  20. msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
  21. msprobe/core/data_dump/json_writer.py +38 -35
  22. msprobe/core/grad_probe/constant.py +1 -0
  23. msprobe/core/grad_probe/grad_compare.py +1 -1
  24. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  25. msprobe/docs/01.installation.md +2 -1
  26. msprobe/docs/02.config_introduction.md +17 -15
  27. msprobe/docs/05.data_dump_PyTorch.md +70 -2
  28. msprobe/docs/06.data_dump_MindSpore.md +33 -12
  29. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  30. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  31. msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
  32. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  33. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  34. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  35. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  36. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  37. msprobe/docs/18.online_dispatch.md +1 -1
  38. msprobe/docs/19.monitor.md +124 -62
  39. msprobe/docs/21.visualization_PyTorch.md +32 -13
  40. msprobe/docs/22.visualization_MindSpore.md +32 -13
  41. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  42. msprobe/docs/27.dump_json_instruction.md +278 -8
  43. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  44. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  45. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  46. msprobe/docs/FAQ.md +3 -11
  47. msprobe/docs/img/compare_result.png +0 -0
  48. msprobe/docs/img/merge_result.png +0 -0
  49. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  50. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  51. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  52. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  53. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  54. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  55. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  56. msprobe/mindspore/__init__.py +4 -3
  57. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
  58. msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
  59. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  60. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  61. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  62. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  63. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  64. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  65. msprobe/mindspore/common/const.py +61 -0
  66. msprobe/mindspore/common/utils.py +31 -19
  67. msprobe/mindspore/compare/ms_compare.py +27 -19
  68. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  69. msprobe/mindspore/debugger/debugger_config.py +6 -4
  70. msprobe/mindspore/debugger/precision_debugger.py +22 -10
  71. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  72. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  73. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  74. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  75. msprobe/mindspore/dump/jit_dump.py +14 -9
  76. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  77. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  78. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  79. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  80. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  81. msprobe/mindspore/grad_probe/global_context.py +2 -0
  82. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  83. msprobe/mindspore/grad_probe/hook.py +2 -4
  84. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  85. msprobe/mindspore/monitor/module_hook.py +354 -302
  86. msprobe/mindspore/monitor/utils.py +46 -4
  87. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  88. msprobe/mindspore/service.py +23 -17
  89. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  90. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
  91. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  92. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  93. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  94. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  95. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  96. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  97. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  98. msprobe/pytorch/common/utils.py +29 -7
  99. msprobe/pytorch/debugger/precision_debugger.py +10 -1
  100. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  101. msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
  102. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  103. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  104. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  105. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  106. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  107. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  108. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  109. msprobe/pytorch/function_factory.py +1 -1
  110. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  111. msprobe/pytorch/hook_module/api_register.py +131 -0
  112. msprobe/pytorch/hook_module/hook_module.py +19 -14
  113. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  114. msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
  115. msprobe/pytorch/monitor/csv2tb.py +8 -2
  116. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  117. msprobe/pytorch/monitor/module_hook.py +131 -105
  118. msprobe/pytorch/monitor/module_metric.py +3 -0
  119. msprobe/pytorch/monitor/optimizer_collect.py +55 -4
  120. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  121. msprobe/pytorch/monitor/utils.py +68 -1
  122. msprobe/pytorch/online_dispatch/compare.py +0 -2
  123. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  124. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  125. msprobe/pytorch/online_dispatch/utils.py +3 -0
  126. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  127. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  128. msprobe/pytorch/pt_config.py +11 -7
  129. msprobe/pytorch/service.py +11 -8
  130. msprobe/visualization/builder/graph_builder.py +44 -5
  131. msprobe/visualization/builder/msprobe_adapter.py +0 -1
  132. msprobe/visualization/compare/graph_comparator.py +42 -38
  133. msprobe/visualization/compare/mode_adapter.py +0 -19
  134. msprobe/visualization/graph/base_node.py +8 -1
  135. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  136. msprobe/visualization/graph/graph.py +0 -11
  137. msprobe/visualization/graph/node_op.py +1 -2
  138. msprobe/visualization/graph_service.py +1 -1
  139. msprobe/visualization/utils.py +2 -33
  140. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  141. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  142. msprobe/pytorch/hook_module/api_registry.py +0 -166
  143. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  144. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  145. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  146. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  147. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  148. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  149. msprobe/pytorch/parse.py +0 -19
  150. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  151. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  152. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  153. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
--- a/msprobe/visualization/graph/graph.py
+++ b/msprobe/visualization/graph/graph.py
@@ -20,9 +20,6 @@ from msprobe.core.common.log import logger
 from msprobe.core.common.const import Const
 
 
-MAX_RECUR_LEVEL = 100
-
-
 class Graph:
     def __init__(self, model_name, data_path='', dump_data=None):
         self.node_map = {}
@@ -67,7 +64,6 @@ class Graph:
         ancestors_b = node_b.get_ancestors()
         return node_b, ancestors_n, ancestors_b
 
-
     @staticmethod
     def fuzzy_match(node_n, node_b):
         if not node_n or not node_b or not node_n.fuzzy_eq(node_b):
@@ -76,13 +72,6 @@ class Graph:
         ancestors_b = node_b.get_ancestors()
         return node_b, ancestors_n, ancestors_b
 
-    @staticmethod
-    def dfs(node, result):
-        info = node.to_dict()
-        result[node.id] = info
-        for subnode in node.subnodes:
-            Graph.dfs(subnode, result)
-
     @staticmethod
     def split_nodes_by_micro_step(nodes):
         """
--- a/msprobe/visualization/graph/node_op.py
+++ b/msprobe/visualization/graph/node_op.py
@@ -24,7 +24,6 @@ class NodeOp(Enum):
     function_api = 1
     api_collection = 9
 
-
    @staticmethod
    def get_node_op(node_name: str):
        """
@@ -37,5 +36,5 @@ class NodeOp(Enum):
             pattern = op_patterns[index]
             if re.match(pattern, node_name):
                 return op
-        logger.warning(f"Cannot parsing node_name {node_name} into NodeOp, default parsing as module.")
+        logger.warning(f"Cannot parse node_name {node_name} into NodeOp, default parsing as module.")
         return NodeOp.module
--- a/msprobe/visualization/graph_service.py
+++ b/msprobe/visualization/graph_service.py
@@ -159,7 +159,7 @@ def _compare_graph_steps(input_param, args):
     bench_steps = sorted(check_and_return_dir_contents(dump_step_b, Const.STEP))
 
     if npu_steps != bench_steps:
-        logger.error('The number of steps in the two runs are different. Unable to match the steps.')
+        logger.error('The number of steps in the two runs is different. Unable to match the steps.')
         raise CompareException(CompareException.INVALID_PATH_ERROR)
 
     for folder_step in npu_steps:
--- a/msprobe/visualization/utils.py
+++ b/msprobe/visualization/utils.py
@@ -42,14 +42,6 @@ def load_data_json_file(file_path):
     return load_json_file(file_path).get(GraphConst.DATA_KEY, {})
 
 
-def save_json_file(file_path, data):
-    """
-    Save data to a JSON file
-    """
-    with FileOpen(file_path, 'w') as f:
-        f.write(json.dumps(data, indent=4))
-
-
 def get_csv_df(stack_mode, csv_data, compare_mode):
     """
     Call the acc interface to write the CSV
@@ -73,14 +65,6 @@ def str2float(percentage_str):
         return 0
 
 
-def is_integer(s):
-    try:
-        int(s)
-        return True
-    except Exception:
-        return False
-
-
 def check_directory_content(input_path):
     """
     Check whether input_path contains only folders named step{number} (e.g. step0), only folders named rank{number} (e.g. rank0), or only files
@@ -143,14 +127,12 @@ class ToolTip:
         'The closer the maximum relative error is to 0, the smaller the computed error. '
         'When the dump data contains 0 or NaN, the maximum relative error in the comparison result can be inf or NaN; this is normal.'
     )
-    SMALL_VALUE_TIP = '{}, since {} is less than {}, this relative error is not recommended as a reference; please use the absolute error instead'
 
 
 class GraphConst:
     CONSTRUCT_FILE = 'construct.json'
     DUMP_FILE = 'dump.json'
     STACK_FILE = 'stack.json'
-    GRAPH_FILE = 'graph.vis'
     ERROR_KEY = 'error_key'
     SUMMARY_COMPARE = 0
     MD5_COMPARE = 1
@@ -164,35 +146,22 @@ class GraphConst:
     JSON_DATA_KEY = 'dump_data_dir'
     JSON_TASK_KEY = 'task'
     DATA_KEY = 'data'
-    REAL_DATA_TH = 0.1
-    MAX_RELATIVE_ERR_TH = 0.5
     ROUND_TH = 6
     JSON_INDEX_KEY = 'precision_index'
     MATCHED_DISTRIBUTED = 'matched_distributed'
     OVERFLOW_LEVEL = 'overflow_level'
     MAX_INDEX_KEY = 1
     MIN_INDEX_KEY = 0
-    SUGGEST_KEY = 'text'
-    TAG_NA = 'na'
-    OUTPUT_INDEX_TWO = -2
-    OUTPUT_INDEX_THREE = -3
-    OUTPUT_MIN_LEN = 3
     INPUT = '.input.'
     OUTPUT = '.output.'
     STR_MAX_LEN = 50
-    SMALL_VALUE = 1e-3
     MD5_INDEX_LIST = [CompareConst.RESULT]
-    REAL_DATA_INDEX_LIST = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
-                            CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]
-    SUMMARY_INDEX_LIST = [CompareConst.MAX_DIFF, CompareConst.MIN_DIFF, CompareConst.MEAN_DIFF,
-                          CompareConst.NORM_DIFF, CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
-                          CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
-    VALUE_INDEX_LIST = [Const.MAX, Const.MIN, Const.MEAN, Const.NORM]
+    REAL_DATA_INDEX_LIST = CompareConst.ALL_COMPARE_INDEX
+    SUMMARY_INDEX_LIST = CompareConst.SUMMARY_COMPARE_INDEX
     APIS_BETWEEN_MODULES = 'Apis_Between_Modules'
     NULL = 'null'
     NONE = 'None'
     VALUE = 'value'
-    BRACE = '{}'
     DESCRIPTION = 'description'
     COLORS = 'Colors'
     MICRO_STEPS = 'MicroSteps'
--- a/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ /dev/null
@@ -1,207 +0,0 @@
-# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from mindspore import Tensor, ops, mint
-from mindspore.mint.nn import functional
-from mindspore.common._stub_tensor import StubTensor
-from mindspore.communication import comm_func
-
-from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP,
-                                                       HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP,
-                                                       HOOKTorchOP, HOOKTorchTensor, HOOKTorchFunctionalOP,
-                                                       HOOKTorchDistributedOP, HOOKTorchNpuOP,
-                                                       get_wrap_api_list, get_wrap_torch_api_list, setup_hooks)
-from msprobe.core.common.utils import Const
-from msprobe.mindspore.common.utils import is_mindtorch
-
-if is_mindtorch():
-    import torch
-    import torch_npu
-
-
-def stub_method(method):
-    def wrapped_method(*args, **kwargs):
-        return method(*args, **kwargs)
-    return wrapped_method
-
-
-class ApiRegistry:
-    def __init__(self):
-        self.tensor_ori_attr = {}
-        self.stub_tensor_ori_attr = {}
-        self.functional_ori_attr = {}
-        self.mint_ops_ori_attr = {}
-        self.mint_func_ops_ori_attr = {}
-        self.distributed_ori_attr = {}
-        self.norm_inner_ops_ori_attr = {}
-
-        self.torch_ori_attr = {}
-        self.torch_tensor_ori_attr = {}
-        self.torch_functional_ori_attr = {}
-        self.torch_distributed_ori_attr = {}
-        self.torch_npu_ori_attr = {}
-
-        self.tensor_hook_attr = {}
-        self.stub_tensor_hook_attr = {}
-        self.functional_hook_attr = {}
-        self.mint_ops_hook_attr = {}
-        self.mint_func_ops_hook_attr = {}
-        self.distibuted_hook_attr = {}
-        self.norm_inner_ops_hook_attr = {}
-
-        self.torch_hook_attr = {}
-        self.torch_tensor_hook_attr = {}
-        self.torch_functional_hook_attr = {}
-        self.torch_distributed_hook_attr = {}
-        self.torch_npu_hook_attr = {}
-
-        self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
-
-    @staticmethod
-    def store_ori_attr(ori_api_group, api_list, api_ori_attr):
-        for api in api_list:
-            if Const.SEP in api:
-                sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
-                sub_module = getattr(ori_api_group, sub_module_name)
-                ori_api_func = getattr(sub_module, sub_op)
-            else:
-                ori_api_func = getattr(ori_api_group, api)
-            if ori_api_group == StubTensor:
-                api_ori_attr[api] = stub_method(ori_api_func)
-                continue
-            api_ori_attr[api] = ori_api_func
-
-    @staticmethod
-    def set_api_attr(api_group, attr_dict):
-        for api, api_attr in attr_dict.items():
-            if Const.SEP in api:
-                sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
-                sub_module = getattr(api_group, sub_module_name, None)
-                if sub_module is not None:
-                    setattr(sub_module, sub_op, api_attr)
-            else:
-                setattr(api_group, api, api_attr)
-
-    def norm_inner_op_set_hook_func(self):
-        self.set_api_attr(ops, self.norm_inner_ops_hook_attr)
-
-    def norm_inner_op_set_ori_func(self):
-        self.set_api_attr(ops, self.norm_inner_ops_ori_attr)
-
-    def api_set_hook_func(self):
-        if is_mindtorch():
-            self.set_api_attr(torch, self.torch_hook_attr)
-            self.set_api_attr(torch.Tensor, self.torch_tensor_hook_attr)
-            self.set_api_attr(torch.nn.functional, self.torch_functional_hook_attr)
-            self.set_api_attr(torch.distributed, self.torch_distributed_hook_attr)
-            self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_hook_attr)
-            self.set_api_attr(torch_npu, self.torch_npu_hook_attr)
-        else:
-            self.set_api_attr(Tensor, self.tensor_hook_attr)
-            self.set_api_attr(StubTensor, self.stub_tensor_hook_attr)
-            self.set_api_attr(ops, self.functional_hook_attr)
-            self.set_api_attr(mint, self.mint_ops_hook_attr)
-            self.set_api_attr(functional, self.mint_func_ops_hook_attr)
-            self.set_api_attr(comm_func, self.distibuted_hook_attr)
-
-    def api_set_ori_func(self):
-        if is_mindtorch():
-            self.set_api_attr(torch, self.torch_ori_attr)
-            self.set_api_attr(torch.Tensor, self.torch_tensor_ori_attr)
-            self.set_api_attr(torch.nn.functional, self.torch_functional_ori_attr)
-            self.set_api_attr(torch.distributed, self.torch_distributed_ori_attr)
-            self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_ori_attr)
-            self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
-        else:
-            self.set_api_attr(Tensor, self.tensor_ori_attr)
-            self.set_api_attr(StubTensor, self.stub_tensor_ori_attr)
-            self.set_api_attr(ops, self.functional_ori_attr)
-            self.set_api_attr(mint, self.mint_ops_ori_attr)
-            self.set_api_attr(functional, self.mint_func_ops_ori_attr)
-            self.set_api_attr(comm_func, self.distributed_ori_attr)
-
-    def initialize_hook(self, hook):
-        setup_hooks(hook)
-        if is_mindtorch():
-            wrap_torch_api_name = get_wrap_torch_api_list()
-            self.store_ori_attr(torch,
-                                wrap_torch_api_name.torch_api_names, self.torch_ori_attr)
-            self.store_ori_attr(torch.Tensor,
-                                wrap_torch_api_name.tensor_api_names, self.torch_tensor_ori_attr)
-            self.store_ori_attr(torch.nn.functional,
-                                wrap_torch_api_name.functional_api_names, self.torch_functional_ori_attr)
-            self.store_ori_attr(torch.distributed,
-                                wrap_torch_api_name.distributed_api_names, self.torch_distributed_ori_attr)
-            self.store_ori_attr(torch_npu,
-                                wrap_torch_api_name.npu_api_names, self.torch_npu_ori_attr)
-            for attr_name in dir(HOOKTorchOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_hook_attr[api_name] = getattr(HOOKTorchOP, attr_name)
-            for attr_name in dir(HOOKTorchTensor):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_tensor_hook_attr[api_name] = getattr(HOOKTorchTensor, attr_name)
-            for attr_name in dir(HOOKTorchFunctionalOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_functional_hook_attr[api_name] = getattr(HOOKTorchFunctionalOP, attr_name)
-            for attr_name in dir(HOOKTorchDistributedOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_distributed_hook_attr[api_name] = getattr(HOOKTorchDistributedOP, attr_name)
-            for attr_name in dir(HOOKTorchNpuOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_npu_hook_attr[api_name] = getattr(HOOKTorchNpuOP, attr_name)
-            return
-
-        wrap_api_name = get_wrap_api_list()
-        self.store_ori_attr(Tensor, wrap_api_name.tensor_api_names, self.tensor_ori_attr)
-        self.store_ori_attr(StubTensor, wrap_api_name.stub_tensor_api_names, self.stub_tensor_ori_attr)
-        self.store_ori_attr(ops, wrap_api_name.ops_api_names, self.functional_ori_attr)
-        self.store_ori_attr(mint, wrap_api_name.mint_api_names, self.mint_ops_ori_attr)
-        self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr)
-        self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr)
-        self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
-        for attr_name in dir(HOOKTensor):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.tensor_hook_attr[api_name] = getattr(HOOKTensor, attr_name)
-        for attr_name in dir(HOOKStubTensor):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.stub_tensor_hook_attr[api_name] = getattr(HOOKStubTensor, attr_name)
-        for attr_name in dir(HOOKFunctionalOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.functional_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name)
-                if api_name in self.norm_inner_ops:
-                    self.norm_inner_ops_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name)
-        for attr_name in dir(HOOKMintOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.mint_ops_hook_attr[api_name] = getattr(HOOKMintOP, attr_name)
-        for attr_name in dir(HOOKMintNNFunctionalOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.mint_func_ops_hook_attr[api_name] = getattr(HOOKMintNNFunctionalOP, attr_name)
-        for attr_name in dir(HOOKDistributedOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.distibuted_hook_attr[api_name] = getattr(HOOKDistributedOP, attr_name)
-
-
-api_register = ApiRegistry()
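
Note: the MindSpore-specific registry deleted above (and the wrap_api.py helpers deleted below) are superseded in 1.3.0 by the shared msprobe/core/data_dump/api_registry.py (item 17, +176) and msprobe/mindspore/dump/hook_cell/api_register.py (item 72, +142). The removed code revolves around one pattern: save each target API's original attribute (store_ori_attr), swap in a wrapped version that reports to a dump hook (api_set_hook_func), and restore the original when capture stops (api_set_ori_func). Below is a minimal, self-contained sketch of that pattern; the fake_ops namespace and all helper names are illustrative stand-ins, not msprobe's actual API.

import types

# Stand-in for a real API namespace such as mindspore.ops (illustrative only).
fake_ops = types.SimpleNamespace(add=lambda a, b: a + b)

_ori_attr = {}   # original callables, keyed by API name
_hook_attr = {}  # wrapped replacements, keyed by API name

def store_ori_attr(api_group, api_list):
    # Keep a reference to every original callable so it can be restored later.
    for api in api_list:
        _ori_attr[api] = getattr(api_group, api)

def build_hook_attr(hook):
    # Wrap each saved original so the hook observes every call before dispatch.
    for api, ori_func in _ori_attr.items():
        def wrapped(*args, _ori=ori_func, _name=api, **kwargs):
            hook(_name, args, kwargs)
            return _ori(*args, **kwargs)
        _hook_attr[api] = wrapped

def api_set_hook_func(api_group):
    # Swap the wrapped versions in: capture is now active.
    for api, func in _hook_attr.items():
        setattr(api_group, api, func)

def api_set_ori_func(api_group):
    # Restore the originals: capture is now inactive.
    for api, func in _ori_attr.items():
        setattr(api_group, api, func)

store_ori_attr(fake_ops, ["add"])
build_hook_attr(lambda name, args, kwargs: print(f"captured {name}{args}"))
api_set_hook_func(fake_ops)
print(fake_ops.add(1, 2))   # prints "captured add(1, 2)" then 3
api_set_ori_func(fake_ops)  # subsequent calls are no longer observed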
--- a/msprobe/mindspore/dump/hook_cell/wrap_api.py
+++ /dev/null
@@ -1,212 +0,0 @@
-# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-
-from mindspore import Tensor, mint, ops
-from mindspore.common._stub_tensor import StubTensor
-from mindspore.communication import comm_func
-from mindspore.mint.nn import functional
-
-from msprobe.core.common.const import Const
-from msprobe.core.common.file_utils import load_yaml
-from msprobe.mindspore.common.const import Const as MsConst
-from msprobe.mindspore.common.utils import is_mindtorch
-from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
-
-if is_mindtorch():
-    import torch
-    import torch_npu
-
-cur_path = os.path.dirname(os.path.realpath(__file__))
-yaml_path = os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE)
-torch_yaml_path = os.path.join(cur_path, "../../../pytorch/hook_module", MsConst.SUPPORTED_API_LIST_FILE)
-
-
-class HOOKTensor(object):
-    pass
-
-
-class HOOKStubTensor(object):
-    pass
-
-
-class HOOKFunctionalOP(object):
-    pass
-
-
-class HOOKMintOP(object):
-    pass
-
-
-class HOOKMintNNFunctionalOP(object):
-    pass
-
-
-class HOOKDistributedOP(object):
-    pass
-
-
-class HOOKTorchOP(object):
-    pass
-
-
-class HOOKTorchTensor(object):
-    pass
-
-
-class HOOKTorchFunctionalOP(object):
-    pass
-
-
-class HOOKTorchDistributedOP(object):
-    pass
-
-
-class HOOKTorchNpuOP(object):
-    pass
-
-
-class ApiTemplate(HOOKCell):
-    def __init__(self, api_name, api_dict, prefix, hook):
-        self.api_name = api_name
-        self.api_func = api_dict[api_name]
-        self.prefix_api_name = prefix + str(api_name.split(Const.SEP)[-1]) + Const.SEP
-        super().__init__(hook)
-
-    @staticmethod
-    def async_to_sync(output):
-        # Fake handle, used to return after the CommHandle executes the wait method
-        fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()
-        if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"):
-            output[1].wait()
-            output = (output[0], fake_handle)
-        elif hasattr(output, "wait"):
-            output.wait()
-            output = fake_handle
-        return output
-
-    def construct(self, *args, **kwargs):
-        if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
-            return args[0] if args else kwargs.get(Const.INPUT)
-
-        output = self.api_func(*args, **kwargs)
-
-        if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX):
-            if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]:
-                output = self.async_to_sync(output)
-        return output
-
-    def forward(self, *args, **kwargs):
-        if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
-            return args[0] if args else kwargs.get(Const.INPUT)
-        return self.api_func(*args, **kwargs)
-
-
-class WrapApiName:
-    def __init__(self, tensor_api_names, stub_tensor_api_names, ops_api_names, mint_api_names, mint_nn_func_api_names,
-                 distributed_api_names):
-        self.tensor_api_names = tensor_api_names
-        self.stub_tensor_api_names = stub_tensor_api_names
-        self.ops_api_names = ops_api_names
-        self.mint_api_names = mint_api_names
-        self.mint_nn_func_api_names = mint_nn_func_api_names
-        self.distributed_api_names = distributed_api_names
-
-
-class WrapTorchApiName:
-    def __init__(self, torch_api_names, tensor_api_names, functional_api_names, distributed_api_names, npu_api_names):
-        self.torch_api_names = torch_api_names
-        self.tensor_api_names = tensor_api_names
-        self.functional_api_names = functional_api_names
-        self.distributed_api_names = distributed_api_names
-        self.npu_api_names = npu_api_names
-
-
-def get_wrap_api_list():
-    api_list = load_yaml(yaml_path)
-    tensor_api = api_list.get(MsConst.SUPPORTED_TENSOR_LIST_KEY)
-    ops_api = api_list.get(MsConst.SUPPORTED_OPS_LIST_KEY)
-    mint_api = api_list.get(MsConst.SUPPORTED_MINT_LIST_KEY)
-    mint_nn_func_api = api_list.get(MsConst.SUPPORTED__MINT_NN_FUNC_LIST_KEY)
-    distributed_api = api_list.get(MsConst.SUPPORTED_COMM_LIST_KEY)
-    wrap_api_name = WrapApiName(set(tensor_api) & set(dir(Tensor)),
-                                set(tensor_api) & set(dir(StubTensor)),
-                                set(ops_api) & set(dir(ops)),
-                                set(mint_api) & set(dir(mint)),
-                                set(mint_nn_func_api) & set(dir(functional)),
-                                set(distributed_api) & set(dir(comm_func)))
-    return wrap_api_name
-
-
-def get_wrap_torch_api_list():
-    api_list = load_yaml(torch_yaml_path)
-    torch_api = api_list.get("torch")
-    tensor_api = api_list.get("tensor")
-    functional_api = api_list.get("functional")
-    distributed_api = api_list.get("distributed")
-    npu_api = api_list.get("torch_npu")
-    wrap_api_name = WrapTorchApiName(set(torch_api) & set(dir(torch)),
-                                     set(tensor_api) & set(dir(torch.Tensor)),
-                                     set(functional_api) & set(dir(torch.nn.functional)),
-                                     set(distributed_api) & set(dir(torch.distributed)),
-                                     set(npu_api) & set(dir(torch_npu)))
-    return wrap_api_name
-
-
-def wrap_api_func(api_name, api_dict, prefix, hook):
-    def api_function(*args, **kwargs):
-        return ApiTemplate(api_name, api_dict, prefix, hook)(*args, **kwargs)
-    return api_function
-
-
-def wrap_api_func_and_bind(api_list, api_dict, prefix, hook, hook_class):
-    for api_name in api_list:
-        if callable(api_dict[api_name]):
-            setattr(hook_class, Const.ATTR_NAME_PREFIX + api_name, wrap_api_func(api_name, api_dict, prefix, hook))
-
-
-def setup_hooks(hook):
-    if is_mindtorch():
-        torch_wrap_api_name = get_wrap_torch_api_list()
-        wrap_api_func_and_bind(torch_wrap_api_name.torch_api_names,
-                               {f: getattr(torch, f) for f in dir(torch)},
-                               MsConst.TORCH_DATA_PREFIX, hook, HOOKTorchOP)
-        wrap_api_func_and_bind(torch_wrap_api_name.tensor_api_names,
-                               {f: getattr(torch.Tensor, f) for f in dir(torch.Tensor)},
-                               MsConst.TENSOR_DATA_PREFIX, hook, HOOKTorchTensor)
-        wrap_api_func_and_bind(torch_wrap_api_name.functional_api_names,
-                               {f: getattr(torch.nn.functional, f) for f in dir(torch.nn.functional)},
-                               MsConst.OPS_DATA_PREFIX, hook, HOOKTorchFunctionalOP)
-        wrap_api_func_and_bind(torch_wrap_api_name.distributed_api_names,
-                               {f: getattr(torch.distributed, f) for f in dir(torch.distributed)},
-                               MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKTorchDistributedOP)
-        wrap_api_func_and_bind(torch_wrap_api_name.npu_api_names, {f: getattr(torch_npu, f) for f in dir(torch_npu)},
-                               MsConst.TORCH_NPU_DATA_PREFIX, hook, HOOKTorchNpuOP)
-        return
-
-    wrap_api_name = get_wrap_api_list()
-    wrap_api_func_and_bind(wrap_api_name.tensor_api_names, {f: getattr(Tensor, f) for f in dir(Tensor)},
-                           MsConst.TENSOR_DATA_PREFIX, hook, HOOKTensor)
-    wrap_api_func_and_bind(wrap_api_name.stub_tensor_api_names, {f: getattr(StubTensor, f) for f in dir(StubTensor)},
-                           MsConst.STUB_TENSOR_DATA_PREFIX, hook, HOOKStubTensor)
-    wrap_api_func_and_bind(wrap_api_name.ops_api_names, {f: getattr(ops, f) for f in dir(ops)},
-                           MsConst.OPS_DATA_PREFIX, hook, HOOKFunctionalOP)
-    wrap_api_func_and_bind(wrap_api_name.mint_api_names, {f: getattr(mint, f) for f in dir(mint)},
-                           MsConst.MINT_DATA_PREFIX, hook, HOOKMintOP)
-    wrap_api_func_and_bind(wrap_api_name.mint_nn_func_api_names, {f: getattr(functional, f) for f in dir(functional)},
-                           MsConst.MINT_NN_FUNC_DATA_PREFIX, hook, HOOKMintNNFunctionalOP)
-    wrap_api_func_and_bind(wrap_api_name.distributed_api_names, {f: getattr(comm_func, f) for f in dir(comm_func)},
-                           MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKDistributedOP)
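
Note: one detail of the deleted wrap_api.py worth keeping in mind when reading its replacement is ApiTemplate.async_to_sync. To capture complete communication results, an asynchronous op that returns (tensor, handle) is finished on the spot by calling wait() on the real handle, and the caller receives a stub whose wait() is a no-op, so downstream code that waits on the handle keeps working. A runnable sketch of the trick follows; _DemoHandle is a hypothetical stand-in for a real communication handle.

# Sketch of the async_to_sync trick from the deleted ApiTemplate above.
class _DemoHandle:
    def __init__(self):
        self.done = False

    def wait(self):
        self.done = True


def async_to_sync(output):
    # Stub handle handed back to the caller; its wait() is a harmless no-op.
    fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()
    if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"):
        output[1].wait()                   # complete the real communication now
        output = (output[0], fake_handle)  # downstream wait() becomes a no-op
    elif hasattr(output, "wait"):
        output.wait()
        output = fake_handle
    return output


real = _DemoHandle()
tensor, handle = async_to_sync(("payload", real))
assert real.done  # the real handle was waited on eagerly
handle.wait()     # safe no-op for the caller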