mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,34 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import namedtuple
17
+
1
18
  import torch
2
- from torch.utils.data import dataloader
3
- from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
4
- from msprobe.pytorch.service import Service
5
- from msprobe.pytorch.common.log import logger
6
- from msprobe.pytorch.pt_config import parse_json_config
19
+ from msprobe.core.common.const import Const, FileCheckConst, MsgConst
7
20
  from msprobe.core.common.exceptions import MsprobeException
8
- from msprobe.core.common.const import Const
21
+ from msprobe.core.common.file_utils import FileChecker
22
+ from msprobe.core.common.utils import get_real_step_or_rank
23
+ from msprobe.pytorch.common.log import logger
24
+ from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
9
25
  from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
26
+ from msprobe.pytorch.pt_config import parse_json_config
27
+ from msprobe.pytorch.service import Service
28
+ from torch.utils.data import dataloader
29
+
30
+ ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task",
31
+ "dump_path", "level", "model"])
10
32
 
11
33
 
12
34
  class PrecisionDebugger:
@@ -30,20 +52,26 @@ class PrecisionDebugger:
30
52
  step=None,
31
53
  ):
32
54
  if not hasattr(self, "initialized"):
55
+ config_params = ConfigParameters(config_path,
56
+ task,
57
+ dump_path,
58
+ level,
59
+ model)
60
+ self.check_input_params(config_params)
61
+
33
62
  self.api_origin = False
34
63
  self.initialized = True
35
- self.model = self.check_model_valid(model)
64
+ self.model = model
36
65
  common_config, task_config = parse_json_config(config_path, task)
37
- self.task = common_config.task
66
+ self.task = task if task else common_config.task
38
67
  if self.task == Const.GRAD_PROBE:
39
68
  self.gm = GradientMonitor(common_config, task_config)
40
69
  return
41
70
  if step:
42
- common_config.step = step
71
+ common_config.step = get_real_step_or_rank(step, Const.STEP)
43
72
  self.config = DebuggerConfig(
44
73
  common_config, task_config, task, dump_path, level
45
74
  )
46
- self.config.check_model(self.model)
47
75
  self.service = Service(self.config)
48
76
  self.enable_dataloader = self.config.enable_dataloader
49
77
  if self.enable_dataloader:
@@ -55,20 +83,40 @@ class PrecisionDebugger:
55
83
  return self._instance
56
84
 
57
85
  @staticmethod
58
- def check_model_valid(model):
59
- if not model or isinstance(model, torch.nn.Module):
60
- return model
61
- raise MsprobeException(
62
- MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。"
63
- )
86
+ def check_input_params(args):
87
+ if args.config_path is not None:
88
+ if not isinstance(args.config_path, str):
89
+ raise MsprobeException(
90
+ MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string")
91
+ file_checker = FileChecker(
92
+ file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
93
+ file_checker.common_check()
94
+
95
+ if args.task is not None and args.task not in Const.TASK_LIST:
96
+ raise MsprobeException(
97
+ MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")
98
+
99
+ if args.dump_path is not None:
100
+ if not isinstance(args.dump_path, str):
101
+ raise MsprobeException(
102
+ MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string")
103
+
104
+ if args.level is not None and args.level not in Const.LEVEL_LIST:
105
+ raise MsprobeException(
106
+ MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
107
+
108
+ if args.model is not None and not isinstance(args.model, torch.nn.Module):
109
+ raise MsprobeException(
110
+ MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")
64
111
 
65
112
  @classmethod
66
- def start(cls):
113
+ def start(cls, model=None):
67
114
  instance = cls._instance
115
+ if not instance:
116
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
68
117
  if instance.task in PrecisionDebugger.tasks_not_need_debugger:
69
118
  return
70
- if not instance:
71
- raise Exception("No instance of PrecisionDebugger found.")
119
+ instance.config.check_model(instance, model)
72
120
  if instance.enable_dataloader:
73
121
  logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
74
122
  else:
@@ -85,10 +133,10 @@ class PrecisionDebugger:
85
133
  @classmethod
86
134
  def stop(cls):
87
135
  instance = cls._instance
136
+ if not instance:
137
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
88
138
  if instance.task in PrecisionDebugger.tasks_not_need_debugger:
89
139
  return
90
- if not instance:
91
- raise Exception("PrecisionDebugger instance is not created.")
92
140
  if instance.enable_dataloader:
93
141
  logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
94
142
  else:
@@ -96,16 +144,16 @@ class PrecisionDebugger:
96
144
 
97
145
  @classmethod
98
146
  def step(cls):
147
+ if not cls._instance:
148
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
99
149
  if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger:
100
150
  return
101
- if not cls._instance:
102
- raise Exception("PrecisionDebugger instance is not created.")
103
151
  cls._instance.service.step()
104
152
 
105
153
  @classmethod
106
154
  def monitor(cls, model):
107
155
  if not cls._instance:
108
- raise Exception("PrecisionDebugger instance is not created.")
156
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
109
157
  if cls._instance.task != Const.GRAD_PROBE:
110
158
  return
111
159
  cls._instance.gm.monitor(model)
@@ -1,8 +1,23 @@
1
- from msprobe.pytorch.common.log import logger
2
- from msprobe.core.common.exceptions import FreeBenchmarkException
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ __all__ = ["FreeBenchmarkCheck", "UnequalRow"]
17
+
3
18
  from msprobe.core.common.const import Const
19
+ from msprobe.core.common.exceptions import FreeBenchmarkException
20
+ from msprobe.pytorch.common.log import logger
4
21
 
5
- from .main import FreeBenchmarkCheck
6
22
  from .common.params import UnequalRow
7
-
8
- __all__ = [FreeBenchmarkCheck, UnequalRow]
23
+ from .main import FreeBenchmarkCheck
@@ -1,3 +1,6 @@
1
+ from msprobe.core.common.const import Const
2
+
3
+
1
4
  class PerturbationMode:
2
5
  ADD_NOISE = "add_noise"
3
6
  CHANGE_VALUE = "change_value"
@@ -35,3 +38,28 @@ class FuzzLevel:
35
38
  BASE_LEVEL = "L1"
36
39
  ADV_LEVEL = "L2"
37
40
  REAL_LEVEL = "L3"
41
+
42
+
43
+ class PytorchFreeBenchmarkConst:
44
+ PERTURBATION_MODE_LIST = [
45
+ PerturbationMode.ADD_NOISE,
46
+ PerturbationMode.CHANGE_VALUE,
47
+ PerturbationMode.IMPROVE_PRECISION,
48
+ PerturbationMode.NO_CHANGE,
49
+ PerturbationMode.BIT_NOISE,
50
+ PerturbationMode.TO_CPU,
51
+ ]
52
+ DEFAULT_MODE = PerturbationMode.IMPROVE_PRECISION
53
+ DEVICE_LIST = [DeviceType.NPU, DeviceType.CPU]
54
+ DEFAULT_DEVICE = DeviceType.NPU
55
+ HANDLER_LIST = [HandlerType.CHECK, HandlerType.FIX]
56
+ DEFAULT_HANDLER = HandlerType.CHECK
57
+ FUZZ_LEVEL_LIST = [FuzzLevel.BASE_LEVEL]
58
+ DEFAULT_FUZZ_LEVEL = FuzzLevel.BASE_LEVEL
59
+ FUZZ_STAGE_LIST = [Const.FORWARD, Const.BACKWARD]
60
+ FIX_MODE_LIST = [PerturbationMode.IMPROVE_PRECISION, PerturbationMode.TO_CPU]
61
+ DEFAULT_FUZZ_STAGE = Const.FORWARD
62
+ DEFAULT_PREHEAT_STEP = 15
63
+ DEFAULT_MAX_SAMPLE = 20
64
+ CPU_MODE_LIST = [PerturbationMode.TO_CPU]
65
+ FIX_STAGE_LIST = [Const.FORWARD]
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from dataclasses import dataclass
2
17
  from typing import Any, Callable, Dict, List, Optional, Tuple
3
18
 
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
  from msprobe.pytorch.free_benchmark.common.enums import DeviceType
3
18
 
@@ -75,7 +90,8 @@ class Tools:
75
90
  )
76
91
  return type(origin)(result)
77
92
  return origin
78
-
93
+
94
+
79
95
  class TorchC:
80
96
  sum = torch._C._VariableFunctionsClass.sum
81
97
  isinf = torch._C._VariableFunctionsClass.isinf
@@ -1,8 +1,27 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
  from msprobe.core.common.exceptions import FreeBenchmarkException
3
18
  from msprobe.pytorch.free_benchmark import logger
4
19
  from msprobe.pytorch.free_benchmark.common.constant import CommonField
5
- from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams, data_pre_deal
20
+ from msprobe.pytorch.free_benchmark.common.params import (
21
+ DataParams,
22
+ HandlerParams,
23
+ data_pre_deal,
24
+ )
6
25
  from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
7
26
  from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
8
27
  FuzzHandlerFactory,
@@ -84,7 +103,7 @@ class GradSaver:
84
103
  if self.perturbed_grad_input is None:
85
104
  raise FreeBenchmarkException(
86
105
  FreeBenchmarkException.InvalidGrad,
87
- f"grad not exists : {self.api_name}."
106
+ f"grad not exists : {self.api_name}.",
88
107
  )
89
108
  with torch.no_grad():
90
109
  perturbed_grad = self.perturbed_grad_input[new_grad_index].to(
@@ -94,7 +113,7 @@ class GradSaver:
94
113
  raise FreeBenchmarkException(
95
114
  FreeBenchmarkException.InvalidGrad,
96
115
  f"grad shapes are inconsistent. api:{self.handler_params.api_name}."
97
- f"origin:{origin_grad.shape}, perturbation: {perturbed_grad.shape}"
116
+ f"origin:{origin_grad.shape}, perturbation: {perturbed_grad.shape}",
98
117
  )
99
118
  return perturbed_grad
100
119
 
@@ -150,8 +169,8 @@ class GradSaver:
150
169
  else:
151
170
  _real_input.append(object_)
152
171
  kwargs = self.kwargs.copy()
153
- if 'inplace' in kwargs:
154
- kwargs['inplace'] = False
172
+ if "inplace" in kwargs:
173
+ kwargs["inplace"] = False
155
174
  return self.origin_func(*_real_input, **kwargs)
156
175
 
157
176
  _, grad_input = torch.autograd.functional.vjp(
@@ -159,12 +178,14 @@ class GradSaver:
159
178
  )
160
179
  return grad_input
161
180
 
162
- def calculate_perturbed_grad_input(self, grad_output, need_grad_tensors, inner_args):
181
+ def calculate_perturbed_grad_input(
182
+ self, grad_output, need_grad_tensors, inner_args
183
+ ):
163
184
  data_params = data_pre_deal(
164
185
  self.handler_params.api_name,
165
186
  self.get_grad_input_from_vjp,
166
187
  [need_grad_tensors, grad_output, inner_args],
167
- {}
188
+ {},
168
189
  )
169
190
  layer = LayerFactory.create(
170
191
  self.handler_params.api_name,
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import math
2
17
 
3
18
  import torch
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from abc import ABC
2
17
 
3
18
  import torch
@@ -36,9 +51,9 @@ class FreeBenchmarkCheck(ABC):
36
51
 
37
52
  def update_iter(self, update_iter):
38
53
  self.current_iter = update_iter
39
-
54
+
40
55
  def if_fix(self):
41
- if self.config.handler_type==HandlerType.FIX:
56
+ if self.config.handler_type == HandlerType.FIX:
42
57
  return True
43
58
  return False
44
59
 
@@ -73,9 +88,9 @@ class FreeBenchmarkCheck(ABC):
73
88
  layer.handle(data_params)
74
89
  handler_params = make_handler_params(name, self.config, self.current_iter)
75
90
  handler = FuzzHandlerFactory.create(handler_params)
76
- perturbed_output = handler.handle(data_params)
91
+ perturbed_output = handler.handle(data_params)
77
92
  return perturbed_output, handler.get_unequal_rows()
78
-
93
+
79
94
  def backward(self, name, module, grad_output):
80
95
 
81
96
  if not self.config.fuzz_stage == Const.BACKWARD:
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from abc import ABC, abstractmethod
2
17
  from typing import Any
3
18
 
@@ -1,14 +1,29 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.pytorch.free_benchmark import FreeBenchmarkException
2
17
  from msprobe.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode
3
- from msprobe.pytorch.free_benchmark.perturbed_layers.npu.improve_precision import (
4
- ImprovePrecisionLayer,
5
- )
6
18
  from msprobe.pytorch.free_benchmark.perturbed_layers.npu.add_noise import AddNoiseLayer
7
19
  from msprobe.pytorch.free_benchmark.perturbed_layers.npu.bit_noise import BitNoiseLayer
8
- from msprobe.pytorch.free_benchmark.perturbed_layers.npu.no_change import NoChangeLayer
9
20
  from msprobe.pytorch.free_benchmark.perturbed_layers.npu.change_value import (
10
21
  ChangeValueLayer,
11
22
  )
23
+ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.improve_precision import (
24
+ ImprovePrecisionLayer,
25
+ )
26
+ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.no_change import NoChangeLayer
12
27
  from msprobe.pytorch.free_benchmark.perturbed_layers.run_cpu import CpuLayer
13
28
 
14
29
 
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
  from msprobe.pytorch.free_benchmark import logger
3
18
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
  from msprobe.pytorch.free_benchmark import logger
3
18
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
  from msprobe.pytorch.free_benchmark import logger
3
18
  from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -54,10 +69,19 @@ class ChangeValueLayer(NpuBaseLayer):
54
69
  """
55
70
  判断是否需要添加扰动, 首尾值交换
56
71
  """
57
- if tensor_obj.size(0) < 2:
72
+ # 对于维度大于1的张量、要求1维至少大于1且0维和1维至少一个长度大于2
73
+ if tensor_obj.ndim > 1:
74
+ if tensor_obj.size(1) == 0 or (tensor_obj.size(1) < 2 and tensor_obj.size(0) < 2):
75
+ logger.info_on_rank_0(
76
+ f"[msprobe] Free Benchmark: For {self.api_name} with ndim {tensor_obj.ndim}, "
77
+ f"at least one of 0-dimension or 1-dimension greater than 1. Cancel change value."
78
+ )
79
+ return False
80
+ # 不支持维度等于0的张量、对于维度等于1的张量、要求0维长度大于2
81
+ elif tensor_obj.dim() == 0 or tensor_obj.size(0) < 2:
58
82
  logger.info_on_rank_0(
59
83
  f"[msprobe] Free Benchmark: For {self.api_name}, "
60
- f"size 0 must greater than 1. Cancel change value."
84
+ f"0-dimension must greater than 1. Cancel change value."
61
85
  )
62
86
  return False
63
87
  return True
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
  from msprobe.core.common.const import Const
3
18
  from msprobe.pytorch.free_benchmark import logger
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
  from msprobe.pytorch.free_benchmark import logger
3
18
  from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from abc import abstractmethod
2
17
  from typing import Any
3
18
 
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
  from msprobe.pytorch.free_benchmark import logger
3
18
  from msprobe.pytorch.free_benchmark.common.params import DataParams