mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
msprobe/pytorch/free_benchmark/result_handlers/base_handler.py
@@ -1,8 +1,23 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 from abc import ABC, abstractmethod
 from typing import Any, Optional, Tuple
-import numpy as np
 
+import numpy as np
 import torch
 from msprobe.core.common.const import Const
 from msprobe.pytorch.free_benchmark import logger
@@ -35,7 +50,9 @@ class FuzzHandler(ABC):
             origin_ouput = origin_ouput.values
             perturbed_output = perturbed_output.values
         if hasattr(perturbed_output, "dtype"):
-            abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype, FuzzThreshold.F32_THD)
+            abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
+                perturbed_output.dtype, FuzzThreshold.F32_THD
+            )
         else:
             abs_tol = FuzzThreshold.F32_THD
         return (
@@ -53,16 +70,22 @@ class FuzzHandler(ABC):
         :return origin_output_chunks: list of original output chunks after splitting
         :return perturbed_output_chunks: list of perturbed output chunks after splitting
         """
-        single_output_mem = origin_output.element_size() * origin_output.nelement() / Const.ONE_MB
+        single_output_mem = (
+            origin_output.element_size() * origin_output.nelement() / Const.ONE_MB
+        )
         if single_output_mem == 0 or origin_output.ndim == 0:
             return [origin_output], [perturbed_output]
         # Relation between tensor size and chunk count: chunks_exp = math.log(M, 2) - 4, chunks = 2**chunks_exp (M is the compared tensor size in MB)
         chunks_exp = int(math.log(single_output_mem, 2)) - 4
-        chunks = 2 ** chunks_exp
+        chunks = 2**chunks_exp
         chunks = max(chunks, 1)
         chunks = min(chunks, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK)
-        origin_output_chunks = TorchC.tensor_split(TorchC.reshape(origin_output, (-1,)), chunks)
-        perturbed_output_chunks = TorchC.tensor_split(TorchC.reshape(perturbed_output, (-1,)), chunks)
+        origin_output_chunks = TorchC.tensor_split(
+            TorchC.reshape(origin_output, (-1,)), chunks
+        )
+        perturbed_output_chunks = TorchC.tensor_split(
+            TorchC.reshape(perturbed_output, (-1,)), chunks
+        )
         return origin_output_chunks, perturbed_output_chunks
 
     @staticmethod
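For orientation, the chunking heuristic above is easy to check by hand: chunks_exp = log2(M) - 4, so a 64 MB tensor splits into 2^(6-4) = 4 chunks. A minimal standalone sketch of the same arithmetic (plain Python; the cap value below is an assumption, the real one lives in ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK):

```python
import math

TENSOR_SPLIT_MAX_CHUNK = 128  # assumed cap; msprobe reads this from ThresholdConfig

def chunk_count(size_mb: float) -> int:
    """chunks_exp = log2(M) - 4, chunks = 2**chunks_exp, clamped to [1, cap]."""
    if size_mb <= 0:
        return 1
    chunks = 2 ** (int(math.log(size_mb, 2)) - 4)
    return min(max(chunks, 1), TENSOR_SPLIT_MAX_CHUNK)

assert chunk_count(64) == 4    # log2(64) = 6 -> 2**(6-4) = 4 chunks
assert chunk_count(1) == 1     # small tensors are not split
```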
@@ -80,14 +103,16 @@ class FuzzHandler(ABC):
         pass
 
     def get_ratio_from_specific_norm(
-            self, origin_output, perturbed_output, norm_type, abs_tol
+        self, origin_output, perturbed_output, norm_type, abs_tol
     ):
         if norm_type == NormType.ENDLESS_NORM:
             return self.calculate_error(origin_output, perturbed_output, abs_tol)
         return ThresholdConfig.COMP_CONSISTENT
 
     def calculate_error(self, origin_output, perturbed_output, abs_tol):
-        origin_output_chunks, perturbed_output_chunks = self.tensor_split_for_error_calculate(origin_output, perturbed_output)
+        origin_output_chunks, perturbed_output_chunks = (
+            self.tensor_split_for_error_calculate(origin_output, perturbed_output)
+        )
         norm1 = -np.inf
         norm2 = -np.inf
         norm3 = np.inf
@@ -95,11 +120,25 @@ class FuzzHandler(ABC):
             if chunk_origin.nelement() == 0:
                 break
             chunk_perturbed = perturbed_output_chunks[i]
-            ratio_tensor1 = TorchC.where(TorchC.abs(chunk_perturbed) > abs_tol,
-                TorchC.div(TorchC.clamp(chunk_origin, min=abs_tol), TorchC.clamp(chunk_perturbed, min=abs_tol)), 1)
-            ratio_tensor2 = TorchC.where(TorchC.abs(chunk_origin) > abs_tol,
-                TorchC.div(TorchC.clamp(chunk_perturbed, min=abs_tol), TorchC.clamp(chunk_origin, min=abs_tol)), 1)
-            norm_values = TorchC.stack([TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)])
+            ratio_tensor1 = TorchC.where(
+                TorchC.abs(chunk_perturbed) > abs_tol,
+                TorchC.div(
+                    TorchC.clamp(chunk_origin, min=abs_tol),
+                    TorchC.clamp(chunk_perturbed, min=abs_tol),
+                ),
+                1,
+            )
+            ratio_tensor2 = TorchC.where(
+                TorchC.abs(chunk_origin) > abs_tol,
+                TorchC.div(
+                    TorchC.clamp(chunk_perturbed, min=abs_tol),
+                    TorchC.clamp(chunk_origin, min=abs_tol),
+                ),
+                1,
+            )
+            norm_values = TorchC.stack(
+                [TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)]
+            )
             max_ratio1, max_ratio2 = norm_values.tolist()
             norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1))
             norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
@@ -126,13 +165,13 @@ class FuzzHandler(ABC):
         if self.params.fuzz_stage == Const.BACKWARD:
             abs_tol = ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND
         else:
-            abs_tol = abs_tol ** 0.5
+            abs_tol = abs_tol**0.5
         return self.get_ratio_from_specific_norm(
             origin_output, perturbed_output, norm_type, abs_tol
         )
 
     def npu_compare(
-            self, origin_output, perturbed_output
+        self, origin_output, perturbed_output
     ) -> Tuple[bool, Optional[float]]:
 
         if isinstance(perturbed_output, int):
@@ -189,7 +228,7 @@ class FuzzHandler(ABC):
                 max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
             )
             data_params.is_consistent = (
-                    is_consistent and data_params.is_consistent
+                is_consistent and data_params.is_consistent
             )
             if not is_consistent and data_params.grad_unequal_flag:
                 self.unequal_rows.append(
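Taken together, the FuzzHandler hunks above are formatting-only; the underlying consistency metric compares original and perturbed outputs chunk by chunk via their worst-case element-wise ratio in both directions. A rough standalone model in plain torch (no TorchC indirection, threshold values hypothetical):

```python
import torch

def worst_case_ratio(origin: torch.Tensor, perturbed: torch.Tensor, abs_tol: float) -> float:
    # Ratio origin/perturbed where the perturbed value is significant, else a neutral 1
    r1 = torch.where(perturbed.abs() > abs_tol,
                     origin.clamp(min=abs_tol) / perturbed.clamp(min=abs_tol),
                     torch.ones_like(origin))
    # Same in the other direction, so the metric is symmetric
    r2 = torch.where(origin.abs() > abs_tol,
                     perturbed.clamp(min=abs_tol) / origin.clamp(min=abs_tol),
                     torch.ones_like(origin))
    return max(r1.max().item(), r2.max().item())

x = torch.tensor([1.0, 2.0, 3.0])
print(worst_case_ratio(x, x * 1.01, abs_tol=1e-9))  # ~1.01, consistent under a 1.1-style threshold
```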
msprobe/pytorch/free_benchmark/result_handlers/check_handler.py
@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Any
 
 from msprobe.pytorch.free_benchmark import logger
msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py
@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Any
 
 from msprobe.pytorch.free_benchmark.common.params import DataParams
msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py
@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from msprobe.pytorch.free_benchmark import FreeBenchmarkException
 from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig
 from msprobe.pytorch.free_benchmark.common.enums import HandlerType
msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py
@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 from typing import Any
 
@@ -118,8 +133,10 @@ class PreheatHandler(FuzzHandler):
         """
         # number of samples used in each step
         total_count = preheat_counter.get_one_step_used_api(self.pure_name)
-        sample_count_per_step = self._get_sample_count_per_step()
         need_sample_set = set()
+        if total_count == 0:
+            return need_sample_set
+        sample_count_per_step = self._get_sample_count_per_step()
         prehead_step = self.params.preheat_config.get("preheat_step")
         for i in range(1, sample_count_per_step + 1):
             count = (prehead_step * (i - 1) + self.params.step) % total_count
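The reordering above is a real fix, not just formatting: the loop below the guard computes indices with `% total_count`, so a step in which no API calls were counted would previously raise ZeroDivisionError. A minimal reproduction of the fixed selection logic (names simplified, values hypothetical):

```python
def pick_sample_indices(total_count, samples_per_step, preheat_step, step):
    need = set()
    if total_count == 0:      # the 1.1.0 guard: avoid `% 0` in the loop below
        return need
    for i in range(1, samples_per_step + 1):
        need.add((preheat_step * (i - 1) + step) % total_count)
    return need

print(pick_sample_indices(0, 3, preheat_step=5, step=2))   # set(), instead of ZeroDivisionError
print(pick_sample_indices(10, 3, preheat_step=5, step=2))  # {2, 7}: indices 2, 7, 12 % 10
```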
@@ -136,9 +153,7 @@ class PreheatHandler(FuzzHandler):
 
     def _adjust_threshold_for_dtype(self, dtype_str, compare_result):
         con_ratio = [ratio for ratio, is_consistent in compare_result if is_consistent]
-        incon_ratio = [
-            ratio for ratio, is_consistent in compare_result if not is_consistent
-        ]
+        incon_ratio = [ratio for ratio, is_consistent in compare_result if not is_consistent]
         old_thd = preheat_counter.get_api_thd(self.pure_name, dtype_str)
         new_thd = old_thd
         # both positive and negative samples exist
msprobe/pytorch/function_factory.py
@@ -1,4 +1,18 @@
-from msprobe.pytorch.common.utils import logger
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from msprobe.pytorch.bench_functions.apply_adam_w import npu_apply_adam_w
 from msprobe.pytorch.bench_functions.confusion_transpose import npu_confusion_transpose, \
     npu_confusion_transpose_backward
@@ -12,7 +26,8 @@ from msprobe.pytorch.bench_functions.rms_norm import npu_rms_norm, npu_rms_norm_
 from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotary_mul_backward
 from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \
     npu_scaled_masked_softmax_backward
-from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward, swish_grad, swish
+from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward
+from msprobe.pytorch.common.utils import logger
 
 
 class Register(dict):
msprobe/pytorch/functional/module_dump.py (new file)
@@ -0,0 +1,84 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+from msprobe.core.common.const import Const
+from msprobe.core.common.exceptions import MsprobeException
+from msprobe.core.data_dump.scope import BaseScope
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger
+from msprobe.pytorch.hook_module.api_registry import api_register
+from msprobe.pytorch.service import torch_version_above_or_equal_2
+
+hook_handle_list = []
+
+
+def module_dump(module, dump_name):
+    if not isinstance(module, nn.Module):
+        logger.error("The parameter module in module_dump must be a Module subclass.")
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+    if not isinstance(dump_name, str):
+        logger.error("The parameter dump_name in module_dump must be a str type.")
+        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
+
+    api_register.api_originality()
+    register_hook(module, dump_name)
+
+
+def module_dump_end():
+    api_register.api_modularity()
+    remove_hook()
+    hook_handle_list.clear()
+
+
+def register_hook(module, dump_name):
+    prefix = BaseScope.Module_Type_Module + Const.SEP + dump_name + Const.SEP + module.__class__.__name__ + Const.SEP
+
+    pdg = PrecisionDebugger()
+    _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = \
+        pdg.service.build_hook(BaseScope.Module_Type_Module, prefix)
+
+    if torch_version_above_or_equal_2:
+        forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
+        hook_handle_list.append(forward_hook_handle)
+    else:
+        pdg.service.check_register_full_backward_hook(module)
+        full_backward_hook_handle = module.register_full_backward_hook(
+            pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
+        forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
+        hook_handle_list.extend([full_backward_hook_handle, forward_hook_handle])
+    pdg.service.check_register_full_backward_hook(module)
+    full_backward_hook_handle = module.register_full_backward_hook(backward_hook)
+
+    forward_pre_hook_handle = module.register_forward_pre_hook(
+        pdg.service.module_processor.node_hook(prefix + Const.FORWARD, Const.START))
+    forward_hook_handle = module.register_forward_hook(
+        pdg.service.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
+    hook_handle_list.extend([full_backward_hook_handle, forward_pre_hook_handle, forward_hook_handle])
+
+    if torch_version_above_or_equal_2:
+        backward_pre_hook_handle = module.register_full_backward_pre_hook(
+            pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
+        pdg.service.check_register_full_backward_hook(module)
+        full_backward_hook_handle = module.register_full_backward_hook(
+            pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
+        hook_handle_list.extend([backward_pre_hook_handle, full_backward_hook_handle])
+
+
+def remove_hook():
+    for hook_handle in hook_handle_list:
+        if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
+            hook_handle.remove()
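module_dump/module_dump_end form the new public API for dumping a single nn.Module (replacing the removed dump_module.py). A usage sketch, assuming a PrecisionDebugger has already been constructed with a valid config; the config path and layer below are placeholders:

```python
import torch
import torch.nn as nn
from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger
from msprobe.pytorch.functional.module_dump import module_dump, module_dump_end

PrecisionDebugger(config_path="./config.json")  # hypothetical config; module_dump reuses this instance

layer = nn.Linear(8, 8)
module_dump(layer, "MyLinear")    # restore raw torch APIs, hook only this module
out = layer(torch.randn(4, 8))    # this forward (and its backward) is captured
module_dump_end()                 # re-wrap the APIs and remove the module hooks
```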
msprobe/pytorch/grad_probe/grad_stat_csv.py
@@ -100,7 +100,7 @@ class CSV_max(CsvItem):
 
 
 @register_csv_item(GradConst.MIN)
-class CSV_max(CsvItem):
+class CSV_min(CsvItem):
     def generate_csv_header(csv_header_input):
         return ["min"]
 
@@ -110,7 +110,7 @@ class CSV_max(CsvItem):
 
 
 @register_csv_item(GradConst.NORM)
-class CSV_max(CsvItem):
+class CSV_norm(CsvItem):
     def generate_csv_header(csv_header_input):
         return ["norm"]
 
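The renames above fix a copy-paste slip: all three CSV item classes were called CSV_max, which only worked because the register_csv_item decorator captures each class object at definition time, leaving just the module-level name shadowed. A stripped-down model of the pattern (hypothetical names):

```python
csv_items = {}

def register_csv_item(key):
    def wrap(cls):
        csv_items[key] = cls   # class object captured here, keyed by stat name
        return cls
    return wrap

@register_csv_item("max")
class CSV_max:
    header = ["max"]

@register_csv_item("min")
class CSV_max:                 # duplicate name shadows the binding above...
    header = ["min"]

print(csv_items["max"].header, csv_items["min"].header)  # ['max'] ['min'] -- registry kept both
```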
msprobe/pytorch/hook_module/__init__.py
@@ -1 +1,16 @@
-from .wrap_functional import remove_dropout
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .wrap_functional import remove_dropout
msprobe/pytorch/hook_module/api_registry.py
@@ -1,8 +1,7 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
 
 import torch
 import torch.distributed as dist
@@ -107,7 +105,14 @@ class ApiRegistry:
         if not is_gpu:
             self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
 
-    def initialize_hook(self, hook):
+    def initialize_hook(self, hook, online_run_ut=False):
+        """
+        initialize_hook
+        Args:
+            hook (_type_): initialize_hook
+            online_run_ut (bool): default False, whether online run_ut or not.
+                If online_run_ut is True, the hook will not wrap the aten ops.
+        """
         self.store_ori_attr(torch.Tensor, get_tensor_ops(), self.tensor_ori_attr)
         wrap_tensor.wrap_tensor_ops_and_bind(hook)
         for attr_name in dir(wrap_tensor.HOOKTensor):
@@ -137,7 +142,7 @@ class ApiRegistry:
                 self.npu_distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP,
                                                                         attr_name)
 
-        if torch_version_above_2:
+        if torch_version_above_2 and not online_run_ut:
             self.store_ori_attr(torch.ops.aten, get_aten_ops(), self.aten_ori_attr)
             wrap_aten.wrap_aten_ops_and_bind(hook)
             for attr_name in dir(wrap_aten.HOOKAtenOP):
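The net effect of the new flag: in online run_ut mode the dispatcher-level torch.ops.aten wrappers are skipped, so replayed APIs execute against the original kernels. A call-site sketch (build_hook here is a placeholder for the service's hook factory):

```python
# Sketch only: skip aten wrapping when tensors are replayed via online run_ut
api_register.initialize_hook(build_hook, online_run_ut=True)
```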
msprobe/pytorch/hook_module/hook_module.py
@@ -1,8 +1,7 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
 
 import functools
 import threading
@@ -58,12 +56,12 @@ class HOOKModule(nn.Module):
             self.register_forward_hook(forward_hook)
             self.register_backward_hook(backward_hook)
 
-    def __call__(self, *input, **kwargs):
+    def __call__(self, *args, **kwargs):
         changed = False
         if not self.stop_hook:
             HOOKModule.inner_stop_hook[self.current_thread] = True
             changed = True
-        result = self._call_func(*input, **kwargs)
+        result = self._call_func(*args, **kwargs)
         if changed:
             HOOKModule.inner_stop_hook[self.current_thread] = False
         return result
@@ -72,28 +70,28 @@ class HOOKModule(nn.Module):
     def reset_module_stats(cls):
         cls.module_count = {}
 
-    def _call_func(self, *input, **kwargs):
+    def _call_func(self, *args, **kwargs):
         full_backward_hooks, non_full_backward_hooks = [], []
         if len(self._backward_hooks) > 0:
             full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
         for hook in self._forward_pre_hooks.values():
-            result_input, result_kwargs = hook(self, input, kwargs)
-            if result_input is not None:
-                if not isinstance(result_input, tuple):
-                    result_input = (result_input,)
-                input = result_input
+            result_args, result_kwargs = hook(self, args, kwargs)
+            if result_args is not None:
+                if not isinstance(result_args, tuple):
+                    result_args = (result_args,)
+                args = result_args
             if result_kwargs is not None:
                 kwargs = result_kwargs
         bw_hook = None
         if len(full_backward_hooks) > 0:
             bw_hook = full_hooks.BackwardHook(self, full_backward_hooks)
-            input = bw_hook.setup_input_hook(input)
+            args = bw_hook.setup_input_hook(args)
         if torch._C._get_tracing_state():
-            result = self._slow_forward(*input, **kwargs)
+            result = self._slow_forward(*args, **kwargs)
         else:
-            result = self.forward(*input, **kwargs)
+            result = self.forward(*args, **kwargs)
         for hook in self._forward_hooks.values():
-            hook_result = hook(self, input, kwargs, result)
+            hook_result = hook(self, args, kwargs, result)
             if hook_result is not None:
                 result = hook_result
         if bw_hook:
@@ -116,5 +114,5 @@ class HOOKModule(nn.Module):
                 wrapper = functools.partial(hook, self)
                 functools.update_wrapper(wrapper, hook)
                 grad_fn.register_hook(wrapper)
-            self._maybe_warn_non_full_backward_hook(input, result, grad_fn)
+            self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
         return result
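Renaming *input to *args across HOOKModule stops shadowing Python's builtin input() and matches torch.nn.Module's own convention; behaviour is unchanged. A condensed, framework-free model of the forward-pre-hook protocol that _call_func re-implements (a pre-hook may rewrite args and kwargs; a non-tuple return is wrapped into a 1-tuple):

```python
def run_pre_hooks(hooks, args, kwargs):
    # Mirrors the loop in _call_func, minus the module `self` argument
    for hook in hooks:
        result_args, result_kwargs = hook(args, kwargs)
        if result_args is not None:
            if not isinstance(result_args, tuple):
                result_args = (result_args,)   # single value -> 1-tuple
            args = result_args
        if result_kwargs is not None:
            kwargs = result_kwargs
    return args, kwargs

double_first = lambda args, kwargs: ((args[0] * 2, *args[1:]), kwargs)
print(run_pre_hooks([double_first], (3, 4), {}))  # ((6, 4), {})
```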
msprobe/pytorch/hook_module/utils.py
@@ -1,8 +1,7 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
 
 import os
 from msprobe.core.common.file_utils import load_yaml
msprobe/pytorch/hook_module/wrap_aten.py
@@ -1,8 +1,7 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
 
 import os
 import torch
@@ -24,6 +22,7 @@ from msprobe.core.common.const import Const
 from msprobe.core.common.file_utils import load_yaml
 from msprobe.pytorch.function_factory import npu_custom_grad_functions
 
+
 cur_path = os.path.dirname(os.path.realpath(__file__))
 yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
 ops = load_yaml(yaml_path)
@@ -50,6 +49,8 @@ class AtenOPTemplate(HOOKModule):
     def __init__(self, op, hook, need_hook=True):
         if isinstance(op, torch._ops.OpOverloadPacket):
             op_name_ = op._qualified_op_name.split("::")[-1]
+        elif isinstance(op, str):
+            op_name_ = str(op)
         else:
             op_name_ = op.name().split("::")[-1]
             overload_name = op._overloadname
@@ -76,13 +77,13 @@ class AtenOPTemplate(HOOKModule):
 
 
 class AtenOPPacketTemplate():
-    def __init__(self, opPacket, hook):
-        self.opPacket = opPacket
+    def __init__(self, op_packet, hook):
+        self.op_packet = op_packet
         self.hook = hook
 
     def __getattr__(self, key):
         try:
-            attr = getattr(self.opPacket, key)
+            attr = getattr(self.op_packet, key)
         except AttributeError as e:
             raise AttributeError(f"AtenOPPacketTemplate or OpOverloadPacket does not have attribute '{key}'.") from e
         if isinstance(attr, torch._ops.OpOverload):
@@ -92,10 +93,10 @@ class AtenOPPacketTemplate():
 
     @torch_device_guard
     def __call__(self, *args, **kwargs):
-        return AtenOPTemplate(self.opPacket, self.hook)(*args, **kwargs)
+        return AtenOPTemplate(self.op_packet, self.hook)(*args, **kwargs)
 
     def overloads(self):
-        return self.opPacket.overloads()
+        return self.op_packet.overloads()
 
 
 def wrap_aten_op(op, hook):
msprobe/pytorch/hook_module/wrap_distributed.py
@@ -1,8 +1,7 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
 
 import os
 from functools import wraps
@@ -23,6 +21,7 @@ from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.common.utils import torch_device_guard
 from msprobe.core.common.const import Const
 from msprobe.core.common.file_utils import load_yaml
+from msprobe.core.common.inplace_op_checker import InplaceOpChecker
 
 
 cur_path = os.path.dirname(os.path.realpath(__file__))
@@ -50,7 +49,7 @@ class DistributedOPTemplate(HOOKModule):
         self.op_name_ = op_name
         self.prefix_op_name_ = "Distributed" + Const.SEP + str(op_name) + Const.SEP
         super().__init__(build_hook)
-        if not self.stop_hook and self.op_name_ in Const.INPLACE_LIST:
+        if not self.stop_hook and InplaceOpChecker.check(self.op_name_, InplaceOpChecker.OP_DISTRIBUTED):
             self.op_is_inplace = True
 
     @torch_device_guard
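This replaces the hardcoded Const.INPLACE_LIST with the YAML-backed checker added in this release (msprobe/core/common/inplace_op_checker.py plus inplace_ops.yaml, items 12-13 in the file list). The checker's real layout is not shown in this excerpt; a plausible minimal shape, for orientation only:

```python
# Hypothetical sketch -- the actual implementation lives in
# msprobe/core/common/inplace_op_checker.py (not shown in this diff excerpt).
import os
import yaml

class InplaceOpChecker:
    OP_DISTRIBUTED = "distributed"   # category key, assumed
    _ops = None

    @classmethod
    def check(cls, op_name, category):
        if cls._ops is None:
            path = os.path.join(os.path.dirname(__file__), "inplace_ops.yaml")
            with open(path) as f:
                cls._ops = yaml.safe_load(f)   # assumed shape: {category: [op names]}
        return op_name in cls._ops.get(category, [])
```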