mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,24 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
 
3
18
  from msprobe.core.common.const import Const
19
+ from msprobe.core.common.file_utils import create_directory
4
20
  from msprobe.mindspore.common.const import Const as MsConst
5
21
  from msprobe.mindspore.common.const import FreeBenchmarkConst
6
- from msprobe.core.common.file_utils import create_directory
7
22
 
8
23
 
9
24
  class DebuggerConfig:
@@ -51,16 +66,4 @@ class DebuggerConfig:
51
66
  self.file_format = "npy"
52
67
  if not self.check_mode:
53
68
  self.check_mode = "all"
54
- self._check_rank()
55
- self._check_step()
56
69
  return True
57
-
58
- def _check_rank(self):
59
- for rank_id in self.rank:
60
- if not isinstance(rank_id, int) or rank_id < 0:
61
- raise ValueError(f"rank {self.rank} must be a positive integer.")
62
-
63
- def _check_step(self):
64
- for s in self.step:
65
- if not isinstance(s, int) or s < 0:
66
- raise ValueError(f"step element {s} must be a positive integer.")
@@ -1,17 +1,31 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
 
3
18
  import mindspore as ms
4
19
  from mindspore._c_expression import MSContext
5
20
 
6
- from msprobe.mindspore.service import Service
7
- from msprobe.mindspore.ms_config import parse_json_config
8
- from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
9
- from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
10
- from msprobe.core.common.const import Const
21
+ from msprobe.core.common.const import Const, MsgConst
11
22
  from msprobe.mindspore.common.const import Const as MsConst
12
- from msprobe.mindspore.runtime import Runtime
13
-
23
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
14
24
  from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor
25
+ from msprobe.mindspore.ms_config import parse_json_config
26
+ from msprobe.mindspore.runtime import Runtime
27
+ from msprobe.mindspore.service import Service
28
+ from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
15
29
 
16
30
 
17
31
  class PrecisionDebugger:
@@ -65,11 +79,11 @@ class PrecisionDebugger:
65
79
  def start(cls, model=None):
66
80
  instance = cls._instance
67
81
  if not instance:
68
- raise Exception("No instance of PrecisionDebugger found.")
82
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
69
83
  if instance.task in PrecisionDebugger.task_not_need_service:
70
84
  return
71
85
 
72
- instance.config.execution_mode = instance._get_execution_mode()
86
+ instance.config.execution_mode = cls._get_execution_mode()
73
87
  if cls._need_service():
74
88
  if not instance.service:
75
89
  instance.service = Service(instance.config)
@@ -82,11 +96,21 @@ class PrecisionDebugger:
82
96
  instance.first_start = True
83
97
  Runtime.is_running = True
84
98
 
99
+ @classmethod
100
+ def forward_backward_dump_end(cls):
101
+ instance = cls._instance
102
+ if not instance:
103
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
104
+ if instance.task in PrecisionDebugger.task_not_need_service:
105
+ return
106
+ if instance.service:
107
+ instance.service.forward_backward_dump_end()
108
+
85
109
  @classmethod
86
110
  def stop(cls):
87
111
  instance = cls._instance
88
112
  if not instance:
89
- raise Exception("PrecisionDebugger instance is not created.")
113
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
90
114
  if instance.task == Const.GRAD_PROBE:
91
115
  instance.gm.stop()
92
116
  if instance.task in PrecisionDebugger.task_not_need_service:
@@ -99,7 +123,7 @@ class PrecisionDebugger:
99
123
  def step(cls):
100
124
  instance = cls._instance
101
125
  if not instance:
102
- raise Exception("PrecisionDebugger instance is not created.")
126
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
103
127
  if instance.task in PrecisionDebugger.task_not_need_service:
104
128
  return
105
129
  if instance.service:
@@ -110,7 +134,7 @@ class PrecisionDebugger:
110
134
  def monitor(cls, opt):
111
135
  instance = cls._instance
112
136
  if not instance:
113
- raise Exception("PrecisionDebugger instance is not created.")
137
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
114
138
  if instance.task != Const.GRAD_PROBE:
115
139
  return
116
140
  instance.gm.monitor(opt)
@@ -119,7 +143,7 @@ class PrecisionDebugger:
119
143
  def _need_service(cls):
120
144
  instance = cls._instance
121
145
  if not instance:
122
- raise Exception("No instance of PrecisionDebugger found.")
146
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
123
147
  if instance.config.execution_mode != MsConst.PYNATIVE_MODE:
124
148
  return False
125
149
  else:
@@ -1,7 +1,22 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.mindspore.common.const import Const
2
17
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
3
- from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump
4
18
  from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
19
+ from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump
5
20
 
6
21
 
7
22
  class DumpToolFactory:
@@ -16,9 +16,10 @@
16
16
  from mindspore import Tensor, ops, mint
17
17
  from mindspore.mint.nn import functional
18
18
  from mindspore.common._stub_tensor import StubTensor
19
+ from mindspore.communication import comm_func
19
20
 
20
21
  from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP,
21
- HOOKMintOP, HOOKMintNNFunctionalOP,
22
+ HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP,
22
23
  get_wrap_api_list, setup_hooks)
23
24
  from msprobe.core.common.utils import Const
24
25
 
@@ -30,6 +31,7 @@ class ApiRegistry:
30
31
  self.functional_ori_attr = {}
31
32
  self.mint_ops_ori_attr = {}
32
33
  self.mint_func_ops_ori_attr = {}
34
+ self.distributed_ori_attr = {}
33
35
  self.norm_inner_ops_ori_attr = {}
34
36
 
35
37
  self.tensor_hook_attr = {}
@@ -37,6 +39,7 @@ class ApiRegistry:
37
39
  self.functional_hook_attr = {}
38
40
  self.mint_ops_hook_attr = {}
39
41
  self.mint_func_ops_hook_attr = {}
42
+ self.distibuted_hook_attr = {}
40
43
  self.norm_inner_ops_hook_attr = {}
41
44
 
42
45
  self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
@@ -74,6 +77,7 @@ class ApiRegistry:
74
77
  self.set_api_attr(ops, self.functional_hook_attr)
75
78
  self.set_api_attr(mint, self.mint_ops_hook_attr)
76
79
  self.set_api_attr(functional, self.mint_func_ops_hook_attr)
80
+ self.set_api_attr(comm_func, self.distibuted_hook_attr)
77
81
 
78
82
  def api_set_ori_func(self):
79
83
  self.set_api_attr(Tensor, self.tensor_ori_attr)
@@ -81,6 +85,7 @@ class ApiRegistry:
81
85
  self.set_api_attr(ops, self.functional_ori_attr)
82
86
  self.set_api_attr(mint, self.mint_ops_ori_attr)
83
87
  self.set_api_attr(functional, self.mint_func_ops_ori_attr)
88
+ self.set_api_attr(comm_func, self.distributed_ori_attr)
84
89
 
85
90
  def initialize_hook(self, hook):
86
91
  wrap_api_name = get_wrap_api_list()
@@ -89,6 +94,7 @@ class ApiRegistry:
89
94
  self.store_ori_attr(ops, wrap_api_name.ops_api_names, self.functional_ori_attr)
90
95
  self.store_ori_attr(mint, wrap_api_name.mint_api_names, self.mint_ops_ori_attr)
91
96
  self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr)
97
+ self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr)
92
98
  self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
93
99
  setup_hooks(hook)
94
100
  for attr_name in dir(HOOKTensor):
@@ -113,6 +119,10 @@ class ApiRegistry:
113
119
  if attr_name.startswith(Const.ATTR_NAME_PREFIX):
114
120
  api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
115
121
  self.mint_func_ops_hook_attr[api_name] = getattr(HOOKMintNNFunctionalOP, attr_name)
122
+ for attr_name in dir(HOOKDistributedOP):
123
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
124
+ api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
125
+ self.distibuted_hook_attr[api_name] = getattr(HOOKDistributedOP, attr_name)
116
126
 
117
127
 
118
128
  api_register = ApiRegistry()
@@ -0,0 +1,206 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ import os
17
+
18
+ import mindspore as ms
19
+ from mindspore.common.tensor import Tensor
20
+ from mindspore import ops
21
+
22
+ from msprobe.mindspore.common.log import logger
23
+ from msprobe.core.common.utils import Const, DumpException
24
+ from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
25
+ ModuleBackwardInputs, ModuleBackwardOutputs
26
+
27
+
28
+ class PrimitiveHookService:
29
+ def __init__(self, service_instance):
30
+ self.primitive_counters = {}
31
+ self.service_instance = service_instance
32
+
33
+ def wrap_primitive(self, origin_func, primitive_name):
34
+ """
35
+ 包装原始的 primitive 函数,添加输入和输出的 hook 以捕获前向和反向数据。
36
+
37
+ Args:
38
+ origin_func (callable): 原始 的 primitive 函数。
39
+ primitive_name (str): 原始的 primitive 名称。
40
+
41
+ Returns:
42
+ callable: 包装后的 primitive 函数。
43
+ """
44
+ def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
45
+ """
46
+ 创建反向 hook 函数,用于捕获梯度。
47
+
48
+ Args:
49
+ captured_grads (list): 用于保存捕获的梯度。
50
+ num_tensors (int): 张量数量。
51
+ updated_primitive_name (str): 更新后的 primitive 名称。
52
+ hook_type (str): hook 类型 (输入/输出)。
53
+
54
+ Returns:
55
+ callable: 反向 hook 函数。
56
+ """
57
+ def backward_hook(grad):
58
+
59
+ captured_grads.append(grad)
60
+ backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}"
61
+
62
+ try:
63
+ if len(captured_grads) == num_tensors and hook_type == Const.INPUT:
64
+ self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
65
+ new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
66
+ self.service_instance.data_collector.backward_output_data_collect(
67
+ backward_primitive_name, self, os.getpid(), new_module_input_output
68
+ )
69
+ captured_grads.clear()
70
+ elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
71
+ self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
72
+ new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
73
+ self.service_instance.data_collector.backward_input_data_collect(
74
+ backward_primitive_name, self, os.getpid(), new_module_input_output
75
+ )
76
+ captured_grads.clear()
77
+
78
+ except Exception as exception:
79
+ logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, "
80
+ f"updated_primitive_name: {updated_primitive_name}")
81
+ raise DumpException(DumpException.BACKWARD_DATA_COLLECTION_ERROR) from exception
82
+
83
+ return backward_hook
84
+
85
+ def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
86
+ """
87
+ 针对前向输入添加 hook。
88
+
89
+ Args:
90
+ args (tuple): primitive 输入参数。
91
+ captured_grads_input (list): 捕获的输入梯度。
92
+ updated_primitive_name (str): 更新后的 primitive 名称。
93
+
94
+ Returns:
95
+ list: 添加了 hook 的输入。
96
+ """
97
+ hooked_inputs = []
98
+ num_tensors = sum(isinstance(arg, Tensor) for arg in args)
99
+ input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name,
100
+ Const.INPUT)
101
+ for arg in args:
102
+ if isinstance(arg, Tensor):
103
+ arg_hooked = ops.HookBackward(input_backward_hook)(arg)
104
+ hooked_inputs.append(arg_hooked)
105
+ else:
106
+ hooked_inputs.append(arg)
107
+ return hooked_inputs
108
+
109
+ def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
110
+ """
111
+ 针对前向输出添加 hook。
112
+
113
+ Args:
114
+ out (Tensor/tuple): primitive 输出。
115
+ captured_grads_output (list): 捕获的输出梯度。
116
+ updated_primitive_name (str): 更新后的 primitive 名称。
117
+
118
+ Returns:
119
+ Tensor/tuple: 添加了 hook 的输出。
120
+ """
121
+ if isinstance(out, tuple):
122
+ num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out)
123
+ else:
124
+ num_output_tensors = 1
125
+ output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors,
126
+ updated_primitive_name, Const.OUTPUT)
127
+
128
+ if isinstance(out, Tensor):
129
+ return ops.HookBackward(output_backward_hook)(out)
130
+ elif isinstance(out, tuple):
131
+ hooked_outputs = []
132
+ for tensor in out:
133
+ if isinstance(tensor, Tensor):
134
+ hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
135
+ else:
136
+ hooked_outputs.append(tensor)
137
+ return tuple(hooked_outputs)
138
+ return out
139
+
140
+ def wrapped_primitive_call(instance_self, *args, **kwargs):
141
+ """
142
+ 包装后的 primitive 调用函数,添加输入和输出的 hook。
143
+
144
+ Args:
145
+ instance_self (object): primitive 的实例。
146
+ *args: primitive 输入参数。
147
+ **kwargs: primitive 关键字参数。
148
+
149
+ Returns:
150
+ Tensor/tuple: primitive 的返回值。
151
+ """
152
+ self.update_primitive_counters(primitive_name)
153
+ current_count = self.primitive_counters.get(primitive_name, 0)
154
+ updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{primitive_name}{Const.SEP}{current_count}"
155
+
156
+ if not self.service_instance.primitive_switch:
157
+ return origin_func(*args, **kwargs)
158
+
159
+ captured_grads_input, captured_grads_output = [], []
160
+
161
+ try:
162
+ hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
163
+ except Exception as exception:
164
+ logger.error(f"This is a primitive op dump error during input hooking: {exception}, "
165
+ f"primitive_name: {primitive_name}")
166
+ raise DumpException(DumpException.INPUT_HOOK_ERROR) from exception
167
+
168
+ try:
169
+ out = origin_func(*hooked_inputs, **kwargs)
170
+ except Exception as exception:
171
+ logger.error(f"This is a primitive op dump error during function call: {exception}, "
172
+ f"primitive_name: {primitive_name}")
173
+ raise DumpException(DumpException.FUNCTION_CALL_ERROR) from exception
174
+
175
+ forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
176
+ self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
177
+ if self.service_instance.data_collector:
178
+ module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
179
+ try:
180
+ self.service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
181
+ os.getpid(), module_input_output)
182
+ except Exception as exception:
183
+ logger.error(f"This is a primitive op dump error during forward data collection: {exception}, "
184
+ f"primitive_name: {primitive_name}")
185
+ raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
186
+
187
+ if self.service_instance.data_collector.if_return_forward_new_output():
188
+ out = self.service_instance.data_collector.get_forward_new_output()
189
+
190
+ try:
191
+ out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
192
+ except Exception as exception:
193
+ logger.error(f"This is a primitive op dump error during output hooking: {exception}, "
194
+ f"primitive_name: {primitive_name}")
195
+ raise DumpException(DumpException.OUTPUT_HOOK_ERROR) from exception
196
+
197
+ return out
198
+
199
+ return wrapped_primitive_call
200
+
201
+ def update_primitive_counters(self, primitive_name):
202
+ if primitive_name not in self.primitive_counters:
203
+ self.primitive_counters[primitive_name] = 0
204
+ else:
205
+ self.primitive_counters[primitive_name] += 1
206
+
@@ -185,6 +185,7 @@ ops:
185
185
  - float_power
186
186
  - fmod
187
187
  - frac
188
+ - flash_attention_score
188
189
  - gcd
189
190
  - hypot
190
191
  - igamma
@@ -876,16 +877,60 @@ mint.ops:
876
877
  - zeros
877
878
  - zeros_ex
878
879
  - zeros_like
879
-
880
- mint.nn:
881
- - Dropout
882
- - Embedding
883
- - Fold
884
- - LayerNorm
885
- - Linear
886
- - MaxPool2d
887
- - Unfold
888
- - Upsample
880
+ - inverse
881
+ - select
882
+ - item
883
+ - unsqueeze
884
+ - median
885
+ - floor
886
+ - histc
887
+ - special
888
+ - arctan2
889
+ - sign
890
+ - concat
891
+ - atanh
892
+ - greater_equal
893
+ - eye
894
+ - fix
895
+ - argmin
896
+ - asinh
897
+ - atan
898
+ - nan_to_num
899
+ - tan
900
+ - round
901
+ - cosh
902
+ - norm
903
+ - roll
904
+ - log1p
905
+ - reshape
906
+ - arccos
907
+ - outer
908
+ - arcsin
909
+ - rand_like
910
+ - acosh
911
+ - multinomial
912
+ - logical_xor
913
+ - acos
914
+ - linalg
915
+ - sinc
916
+ - arcsinh
917
+ - asin
918
+ - narrow
919
+ - arctanh
920
+ - trace
921
+ - erfc
922
+ - bernoulli
923
+ - expm1
924
+ - logaddexp
925
+ - sinh
926
+ - arccosh
927
+ - atan2
928
+ - rand
929
+ - arange
930
+ - trunc
931
+ - arctan
932
+ - swapaxes
933
+ - transpose
889
934
 
890
935
  mint.nn.functional:
891
936
  - absolute_import
@@ -920,3 +965,30 @@ mint.nn.functional:
920
965
  - softplus
921
966
  - tanh
922
967
  - unfold
968
+ - mse_loss
969
+ - adaptive_avg_pool1d
970
+ - binary_cross_entropy
971
+ - adaptive_avg_pool2d
972
+ - hardsigmoid
973
+ - selu
974
+ - softshrink
975
+ - prelu
976
+ - logsigmoid
977
+ - hardswish
978
+ - mish
979
+ - log_softmax
980
+ - hardshrink
981
+ - l1_loss
982
+ - elu
983
+
984
+ communication.comm_func:
985
+ - all_reduce
986
+ - all_gather_into_tensor
987
+ - reduce
988
+ - reduce_scatter_tensor
989
+ - all_to_all_single_with_output_shape
990
+ - all_to_all_with_output_shape
991
+ - batch_isend_irecv
992
+ - broadcast
993
+ - gather_into_tensor
994
+ - scatter_tensor
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,19 +12,18 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
17
15
 
18
16
  import os
19
17
 
20
- from mindspore import Tensor, ops, mint
21
- from mindspore.mint.nn import functional
18
+ from mindspore import Tensor, mint, ops
22
19
  from mindspore.common._stub_tensor import StubTensor
20
+ from mindspore.communication import comm_func
21
+ from mindspore.mint.nn import functional
23
22
 
24
- from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
25
23
  from msprobe.core.common.const import Const
26
- from msprobe.mindspore.common.const import Const as MsConst
27
24
  from msprobe.core.common.file_utils import load_yaml
28
-
25
+ from msprobe.mindspore.common.const import Const as MsConst
26
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
29
27
 
30
28
  cur_path = os.path.dirname(os.path.realpath(__file__))
31
29
  yaml_path = os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE)
@@ -51,6 +49,10 @@ class HOOKMintNNFunctionalOP(object):
51
49
  pass
52
50
 
53
51
 
52
class HOOKDistributedOP(object):
    """Empty holder class for the ``mindspore.communication.comm_func`` APIs.

    ``setup_hooks`` hands this class to ``wrap_api_func_and_bind`` together
    with the distributed API names; presumably the wrapped functions are
    attached to it as attributes (confirm against ``wrap_api_func_and_bind``).
    """
54
+
55
+
54
56
  class ApiTemplate(HOOKCell):
55
57
  def __init__(self, api_name, api_dict, prefix, hook):
56
58
  self.api_name = api_name
@@ -65,12 +67,14 @@ class ApiTemplate(HOOKCell):
65
67
 
66
68
 
67
69
class WrapApiName:
    """Plain record of the API-name collections to hook, one per API family.

    ``get_wrap_api_list`` builds an instance from sets of names; the values
    are stored as given, without copying or validation.
    """

    def __init__(
        self,
        tensor_api_names,
        stub_tensor_api_names,
        ops_api_names,
        mint_api_names,
        mint_nn_func_api_names,
        distributed_api_names,
    ):
        # Names matched against Tensor and StubTensor respectively.
        self.tensor_api_names = tensor_api_names
        self.stub_tensor_api_names = stub_tensor_api_names
        # Names matched against the ops, mint and mint.nn.functional modules.
        self.ops_api_names = ops_api_names
        self.mint_api_names = mint_api_names
        self.mint_nn_func_api_names = mint_nn_func_api_names
        # Names matched against mindspore.communication.comm_func.
        self.distributed_api_names = distributed_api_names
74
78
 
75
79
 
76
80
  def get_wrap_api_list():
@@ -79,11 +83,13 @@ def get_wrap_api_list():
79
83
  ops_api = api_list.get(MsConst.SUPPORTED_OPS_LIST_KEY)
80
84
  mint_api = api_list.get(MsConst.SUPPORTED_MINT_LIST_KEY)
81
85
  mint_nn_func_api = api_list.get(MsConst.SUPPORTED__MINT_NN_FUNC_LIST_KEY)
86
+ distributed_api = api_list.get(MsConst.SUPPORTED_COMM_LIST_KEY)
82
87
  wrap_api_name = WrapApiName(set(tensor_api) & set(dir(Tensor)),
83
88
  set(tensor_api) & set(dir(StubTensor)),
84
89
  set(ops_api) & set(dir(ops)),
85
90
  set(mint_api) & set(dir(mint)),
86
- set(mint_nn_func_api) & set(dir(functional)))
91
+ set(mint_nn_func_api) & set(dir(functional)),
92
+ set(distributed_api) & set(dir(comm_func)))
87
93
  return wrap_api_name
88
94
 
89
95
 
@@ -111,3 +117,5 @@ def setup_hooks(hook):
111
117
  MsConst.MINT_DATA_PREFIX, hook, HOOKMintOP)
112
118
  wrap_api_func_and_bind(wrap_api_name.mint_nn_func_api_names, {f: getattr(functional, f) for f in dir(functional)},
113
119
  MsConst.MINT_NN_FUNC_DATA_PREFIX, hook, HOOKMintNNFunctionalOP)
120
+ wrap_api_func_and_bind(wrap_api_name.distributed_api_names, {f: getattr(comm_func, f) for f in dir(comm_func)},
121
+ MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKDistributedOP)