mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,30 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
17
+ from collections import defaultdict
2
18
 
3
19
  from mindspore import Tensor
4
- from mindspore.common.api import _MindsporeFunctionExecutor
5
20
  from mindspore._c_expression import PyNativeExecutor_
21
+ from mindspore.common.api import _MindsporeFunctionExecutor
6
22
 
7
23
  from msprobe.mindspore.dump.hook_cell.api_registry import api_register
8
- from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
24
+ from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
9
25
  from msprobe.core.common.const import Const
26
+ from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
27
+ from msprobe.mindspore.dump.hook_cell.api_registry import api_register
10
28
 
11
29
 
12
30
  def dump_jit(name, in_feat, out_feat, is_forward):
@@ -17,19 +35,27 @@ def dump_jit(name, in_feat, out_feat, is_forward):
17
35
  result = ori_args[0:index]
18
36
  else:
19
37
  result = "JitFunction"
20
- if is_forward:
21
- name_template = "Jit." + result + ".forward"
22
- else:
23
- name_template = "Jit." + result + ".backward"
24
38
  if JitDump.need_dump():
25
- JitDump.data_collector.update_api_or_module_name(name_template)
26
- module_input_output = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat)
27
- JitDump.data_collector.forward_data_collect(name_template, {}, pid, module_input_output)
39
+ if is_forward:
40
+ JitDump.jit_count[result] += 1
41
+ name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \
42
+ Const.FORWARD
43
+ JitDump.data_collector.update_api_or_module_name(name_template)
44
+ module_input_output = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat)
45
+ JitDump.data_collector.forward_data_collect(name_template, None, pid, module_input_output)
46
+ else:
47
+ name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \
48
+ Const.BACKWARD
49
+ JitDump.data_collector.update_api_or_module_name(name_template)
50
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat ,grad_output=out_feat)
51
+ JitDump.data_collector.backward_data_collect(name_template, None, pid, module_input_output)
28
52
 
29
53
 
30
54
  class JitDump(_MindsporeFunctionExecutor):
31
55
  dump_config = None
32
56
  jit_enable = False
57
+ jit_dump_switch = True
58
+ jit_count = defaultdict(int)
33
59
 
34
60
  def __init__(self, *args, **kwargs):
35
61
  super().__init__(*args, **kwargs)
@@ -38,11 +64,9 @@ class JitDump(_MindsporeFunctionExecutor):
38
64
  def __call__(self, *args, **kwargs):
39
65
  api_register.api_set_ori_func()
40
66
  out = super().__call__(*args, **kwargs)
41
- if isinstance(args[0], Tensor):
42
- dump_jit({}, args, out, True)
43
- else:
44
- dump_jit(args[0], args[1:], out, True)
45
- JitDump.jit_enable = True
67
+ if JitDump.jit_dump_switch and len(args) > 0:
68
+ dump_jit(args[0], args, out, True)
69
+ JitDump.jit_enable = True
46
70
  api_register.api_set_hook_func()
47
71
  return out
48
72
 
@@ -62,11 +86,11 @@ class JitDump(_MindsporeFunctionExecutor):
62
86
  return False
63
87
  return True
64
88
 
65
- def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
66
- if JitDump.jit_enable:
89
+ def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
90
+ if JitDump.jit_dump_switch and JitDump.jit_enable:
67
91
  api_register.api_set_ori_func()
68
92
  output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
69
- if JitDump.jit_enable:
93
+ if JitDump.jit_dump_switch and JitDump.jit_enable:
70
94
  dump_jit(obj, args, None, False)
71
95
  api_register.api_set_hook_func()
72
96
  return output
@@ -1,8 +1,24 @@
1
- import os
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  import json
3
- from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
4
- from msprobe.mindspore.common.log import logger
17
+ import os
18
+
5
19
  from msprobe.core.common.file_utils import FileOpen, create_directory
20
+ from msprobe.mindspore.common.log import logger
21
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
6
22
 
7
23
 
8
24
  class KernelGraphDump:
@@ -1,10 +1,25 @@
1
- import os
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  import json
17
+ import os
3
18
 
4
- from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
5
- from msprobe.mindspore.common.log import logger
6
- from msprobe.core.common.file_utils import FileOpen, create_directory
7
19
  from msprobe.core.common.const import Const
20
+ from msprobe.core.common.file_utils import FileOpen, create_directory
21
+ from msprobe.mindspore.common.log import logger
22
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
8
23
 
9
24
 
10
25
  class KernelKbykDump:
@@ -1,17 +1,32 @@
1
- import os
2
- import inspect
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
3
16
  import importlib
17
+ import inspect
18
+ import os
4
19
 
5
20
  import mindspore as ms
6
21
  from mindspore.communication import comm_func
7
22
 
8
- from msprobe.core.common.file_utils import load_yaml, check_path_length
9
23
  from msprobe.core.common.const import Const
24
+ from msprobe.core.common.file_utils import check_path_length, load_yaml
10
25
  from msprobe.mindspore.common.const import Const as MsConst
11
26
  from msprobe.mindspore.common.const import FreeBenchmarkConst
12
- from msprobe.mindspore.free_benchmark.common.config import Config
13
27
  from msprobe.mindspore.common.log import logger
14
28
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
29
+ from msprobe.mindspore.free_benchmark.common.config import Config
15
30
  from msprobe.mindspore.free_benchmark.decorator.decorator_factory import decorate_forward_function
16
31
 
17
32
 
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.mindspore.common.const import FreeBenchmarkConst
2
17
 
3
18
 
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Optional, Any, Tuple, Dict, Callable
2
17
 
3
18
 
@@ -1,14 +1,28 @@
1
- from typing import Any
2
- from typing import Optional
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
3
16
  from dataclasses import dataclass
17
+ from typing import Any, Optional
4
18
 
5
19
  import mindspore as ms
6
20
  from mindspore import Tensor
7
21
 
8
- from msprobe.mindspore.runtime import Runtime
9
22
  from msprobe.mindspore.common.const import FreeBenchmarkConst
10
- from .config import Config
11
- from .handler_params import HandlerParams
23
+ from msprobe.mindspore.free_benchmark.common.config import Config
24
+ from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
25
+ from msprobe.mindspore.runtime import Runtime
12
26
 
13
27
 
14
28
  class Tools:
@@ -1,6 +1,20 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.mindspore.common.const import Const, FreeBenchmarkConst
1
17
  from msprobe.mindspore.free_benchmark.common.config import Config
2
- from msprobe.mindspore.common.const import Const
3
- from msprobe.mindspore.common.const import FreeBenchmarkConst
4
18
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
5
19
  from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
6
20
  from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
@@ -1,16 +1,31 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
  import sys
3
18
  import traceback
4
19
  from functools import wraps
5
- from typing import Tuple, Dict, List
20
+ from typing import Dict, List, Tuple
6
21
 
7
22
  from mindspore import ops
8
23
 
9
- from msprobe.mindspore.runtime import Runtime
10
24
  from msprobe.mindspore.common.log import logger
11
25
  from msprobe.mindspore.free_benchmark.common.config import Config
12
26
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
13
- from .dec_forward import ForwardSelfChecker
27
+ from msprobe.mindspore.free_benchmark.decorator.dec_forward import ForwardSelfChecker
28
+ from msprobe.mindspore.runtime import Runtime
14
29
 
15
30
 
16
31
  def decorate(original_func, decorate_func, api_name=None):
@@ -1,14 +1,29 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import math
2
17
  from abc import ABC, abstractmethod
3
- from typing import Any, Tuple, Optional
18
+ from typing import Any, Optional, Tuple
4
19
 
5
20
  import mindspore as ms
6
21
  from mindspore import Tensor, ops
7
22
 
8
- from msprobe.mindspore.common.log import logger
9
- from msprobe.mindspore.free_benchmark.common.utils import Tools
10
23
  from msprobe.mindspore.common.const import FreeBenchmarkConst
24
+ from msprobe.mindspore.common.log import logger
11
25
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
26
+ from msprobe.mindspore.free_benchmark.common.utils import Tools
12
27
 
13
28
 
14
29
  class BaseHandler(ABC):
@@ -1,14 +1,29 @@
1
- from typing import Any
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  from dataclasses import asdict
17
+ from typing import Any
3
18
 
4
19
  from mindspore import Tensor, ops
5
20
 
21
+ from msprobe.core.data_dump.json_writer import DataWriter
6
22
  from msprobe.mindspore.common.log import logger
7
23
  from msprobe.mindspore.free_benchmark.common.config import Config
8
- from msprobe.mindspore.free_benchmark.handler.base_handler import BaseHandler
9
24
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
10
25
  from msprobe.mindspore.free_benchmark.common.utils import make_unequal_row
11
- from msprobe.core.data_dump.json_writer import DataWriter
26
+ from msprobe.mindspore.free_benchmark.handler.base_handler import BaseHandler
12
27
 
13
28
 
14
29
  class CheckHandler(BaseHandler):
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Any
2
17
 
3
18
  from mindspore import Tensor
@@ -1,8 +1,23 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.mindspore.common.const import FreeBenchmarkConst
1
17
  from msprobe.mindspore.common.log import logger
2
18
  from msprobe.mindspore.free_benchmark.common.config import Config
3
- from msprobe.mindspore.common.const import FreeBenchmarkConst
4
- from .check_handler import CheckHandler
5
- from .fix_handler import FixHandler
19
+ from msprobe.mindspore.free_benchmark.handler.check_handler import CheckHandler
20
+ from msprobe.mindspore.free_benchmark.handler.fix_handler import FixHandler
6
21
 
7
22
 
8
23
  class HandlerFactory:
@@ -1,11 +1,26 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Any
2
17
 
3
18
  from mindspore import Tensor, ops
4
19
 
20
+ from msprobe.mindspore.common.const import FreeBenchmarkConst
5
21
  from msprobe.mindspore.common.log import logger
6
- from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
7
22
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
8
- from msprobe.mindspore.common.const import FreeBenchmarkConst
23
+ from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
9
24
 
10
25
 
11
26
  class AddNoisePerturbation(BasePerturbation):
@@ -43,25 +58,25 @@ class AddNoisePerturbation(BasePerturbation):
43
58
 
44
59
  return inputs
45
60
 
46
- def _get_noise(self, input):
61
+ def _get_noise(self, tensor):
47
62
  """
48
63
  得到要添加的噪声值
49
64
 
50
65
  """
51
66
  if self.is_fuzzed:
52
67
  return False
53
- if not ops.is_floating_point(input) or ops.numel(input) == 0:
68
+ if not ops.is_floating_point(tensor) or ops.numel(tensor) == 0:
54
69
  return False
55
70
 
56
- pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(input.dtype)
71
+ pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(tensor.dtype)
57
72
  if not pert_value:
58
73
  return False
59
74
  else:
60
75
  self.perturbation_value = pert_value
61
76
 
62
- max_val = ops.max(ops.abs(input))[0].item()
77
+ max_val = ops.max(ops.abs(tensor))[0].item()
63
78
  if max_val < pert_value:
64
79
  return False
65
80
 
66
- noise = ops.full(input.shape, self.perturbation_value, dtype=input.dtype)
81
+ noise = ops.full(tensor.shape, self.perturbation_value, dtype=tensor.dtype)
67
82
  return noise
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Any
2
17
 
3
18
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
@@ -1,10 +1,25 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Any
2
17
 
3
18
  import numpy as np
4
19
  from mindspore import Tensor, ops
5
20
 
6
- from msprobe.mindspore.common.log import logger
7
21
  from msprobe.mindspore.common.const import FreeBenchmarkConst
22
+ from msprobe.mindspore.common.log import logger
8
23
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
9
24
  from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
10
25
 
@@ -45,19 +60,19 @@ class BitNoisePerturbation(BasePerturbation):
45
60
  params.args = args
46
61
  return self.get_fuzzed_result(params)
47
62
 
48
- def _get_bit_len_type(self, input):
63
+ def _get_bit_len_type(self, tensor):
49
64
  if self.is_fuzzed:
50
65
  return False
51
- if not isinstance(input, Tensor) or not ops.is_floating_point(input) or \
52
- input.numel() == 0:
66
+ if not isinstance(tensor, Tensor) or not ops.is_floating_point(tensor) or \
67
+ tensor.numel() == 0:
53
68
  return False
54
- bit_len_type = FreeBenchmarkConst.PERT_BIT_DICT.get(input.dtype)
69
+ bit_len_type = FreeBenchmarkConst.PERT_BIT_DICT.get(tensor.dtype)
55
70
  if not bit_len_type:
56
71
  return False
57
- pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(input.dtype)
72
+ pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(tensor.dtype)
58
73
  if not pert_value:
59
74
  return False
60
- max_val = ops.max(ops.abs(input))[0].item()
75
+ max_val = ops.max(ops.abs(tensor))[0].item()
61
76
  if max_val < pert_value:
62
77
  return False
63
78
  return bit_len_type
@@ -1,14 +1,39 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Any
2
17
 
3
- from mindspore import Tensor
18
+ from mindspore import Tensor, ops
4
19
 
5
20
  from msprobe.mindspore.common.log import logger
6
- from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
7
21
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
22
+ from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
8
23
 
9
24
 
10
25
  class ExchangeValuePerturbation(BasePerturbation):
11
26
 
27
+ @staticmethod
28
+ def _check_tensor_shape(inputs):
29
+ dims = len(inputs.shape)
30
+ if dims == 1 and inputs.shape[0] > 1:
31
+ return True
32
+ if dims > 1 and inputs.shape[1] > 0:
33
+ if inputs.shape[0] > 1 or inputs.shape[1] > 1:
34
+ return True
35
+ return False
36
+
12
37
  def handle(self, params: HandlerParams) -> Any:
13
38
  """
14
39
  返回首尾交换后的api输出
@@ -25,22 +50,23 @@ class ExchangeValuePerturbation(BasePerturbation):
25
50
  返回首尾交换后的api输入
26
51
 
27
52
  """
28
- if isinstance(inputs, Tensor):
29
- if not self.is_fuzzed and len(inputs.shape) > 0 and inputs.shape[0] > 1:
30
- result = inputs.copy()
31
- if len(inputs.shape) == 1:
32
- first_element = inputs[0]
33
- last_element = inputs[-1]
34
- result[0] = last_element
35
- result[-1] = first_element
36
- else:
37
- first_element = inputs[0][0]
38
- last_element = inputs[-1][-1]
39
- result[0][0] = last_element
40
- result[-1][-1] = first_element
41
-
42
- self.is_fuzzed = True
43
- return result
53
+ if isinstance(inputs, Tensor) and ops.is_floating_point(inputs):
54
+ if self.is_fuzzed or not self._check_tensor_shape(inputs):
55
+ return inputs
56
+ result = inputs.copy()
57
+ if len(inputs.shape) == 1:
58
+ first_element = inputs[0]
59
+ last_element = inputs[-1]
60
+ result[0] = last_element
61
+ result[-1] = first_element
62
+ else:
63
+ first_element = inputs[0][0]
64
+ last_element = inputs[-1][-1]
65
+ result[0][0] = last_element
66
+ result[-1][-1] = first_element
67
+
68
+ self.is_fuzzed = True
69
+ return result
44
70
 
45
71
  if isinstance(inputs, dict):
46
72
  return {k: self.exchange_value(v) for k, v in inputs.items()}