mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,8 +1,27 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
  from msprobe.core.common.exceptions import FreeBenchmarkException
3
18
  from msprobe.pytorch.free_benchmark import logger
4
19
  from msprobe.pytorch.free_benchmark.common.constant import CommonField
5
- from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams, data_pre_deal
20
+ from msprobe.pytorch.free_benchmark.common.params import (
21
+ DataParams,
22
+ HandlerParams,
23
+ data_pre_deal,
24
+ )
6
25
  from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
7
26
  from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import (
8
27
  FuzzHandlerFactory,
@@ -83,8 +102,13 @@ class GradSaver:
83
102
  def check_grad_input(self, origin_grad, new_grad_index):
84
103
  if self.perturbed_grad_input is None:
85
104
  raise FreeBenchmarkException(
86
- FreeBenchmarkException.InvalidGrad,
87
- f"grad not exists : {self.api_name}."
105
+ FreeBenchmarkException.InvalidPerturbedOutput,
106
+ f"perturbed grad not exists for {self.api_name}.",
107
+ )
108
+ if len(self.perturbed_grad_input) <= new_grad_index:
109
+ raise FreeBenchmarkException(
110
+ FreeBenchmarkException.InvalidPerturbedOutput,
111
+ f"perturbed grad index {new_grad_index} is out of bounds for {self.api_name}.",
88
112
  )
89
113
  with torch.no_grad():
90
114
  perturbed_grad = self.perturbed_grad_input[new_grad_index].to(
@@ -92,9 +116,9 @@ class GradSaver:
92
116
  )
93
117
  if origin_grad.shape != perturbed_grad.shape:
94
118
  raise FreeBenchmarkException(
95
- FreeBenchmarkException.InvalidGrad,
119
+ FreeBenchmarkException.InvalidPerturbedOutput,
96
120
  f"grad shapes are inconsistent. api:{self.handler_params.api_name}."
97
- f"origin:{origin_grad.shape}, perturbation: {perturbed_grad.shape}"
121
+ f"origin:{origin_grad.shape}, perturbation: {perturbed_grad.shape}",
98
122
  )
99
123
  return perturbed_grad
100
124
 
@@ -145,13 +169,25 @@ class GradSaver:
145
169
  index_ = 0
146
170
  for object_ in inner_args:
147
171
  if object_ is CommonField.HOLD_PLACE:
172
+ if index_ >= len(inputs):
173
+ err_msg = (
174
+ f"[msprobe] Free benchmark: When getting input from vjp, "
175
+ f" the input index ({index_}) is out of bounds ({len(inputs)})."
176
+ )
177
+ logger.error_log_with_exp(
178
+ err_msg,
179
+ FreeBenchmarkException(
180
+ FreeBenchmarkException.InvalidGrad,
181
+ error_info=err_msg,
182
+ ),
183
+ )
148
184
  _real_input.append(inputs[index_])
149
185
  index_ += 1
150
186
  else:
151
187
  _real_input.append(object_)
152
188
  kwargs = self.kwargs.copy()
153
- if 'inplace' in kwargs:
154
- kwargs['inplace'] = False
189
+ if "inplace" in kwargs:
190
+ kwargs["inplace"] = False
155
191
  return self.origin_func(*_real_input, **kwargs)
156
192
 
157
193
  _, grad_input = torch.autograd.functional.vjp(
@@ -159,12 +195,14 @@ class GradSaver:
159
195
  )
160
196
  return grad_input
161
197
 
162
- def calculate_perturbed_grad_input(self, grad_output, need_grad_tensors, inner_args):
198
+ def calculate_perturbed_grad_input(
199
+ self, grad_output, need_grad_tensors, inner_args
200
+ ):
163
201
  data_params = data_pre_deal(
164
202
  self.handler_params.api_name,
165
203
  self.get_grad_input_from_vjp,
166
204
  [need_grad_tensors, grad_output, inner_args],
167
- {}
205
+ {},
168
206
  )
169
207
  layer = LayerFactory.create(
170
208
  self.handler_params.api_name,
@@ -1,6 +1,22 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import math
2
17
 
3
18
  import torch
19
+ from msprobe.core.common.utils import recursion_depth_decorator
4
20
  from msprobe.pytorch.free_benchmark import logger
5
21
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
6
22
  from msprobe.pytorch.free_benchmark.common.utils import TorchC
@@ -52,6 +68,7 @@ class SingleCompare:
52
68
  return False
53
69
  return True
54
70
 
71
+ @recursion_depth_decorator("FreeBenchmark: SingleCompare.compare_seq")
55
72
  def compare_seq(self, actual, golden):
56
73
  if isinstance(golden, torch.Tensor):
57
74
  return self.compare_tensor_seq(actual, golden)
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from abc import ABC
2
17
 
3
18
  import torch
@@ -36,9 +51,9 @@ class FreeBenchmarkCheck(ABC):
36
51
 
37
52
  def update_iter(self, update_iter):
38
53
  self.current_iter = update_iter
39
-
54
+
40
55
  def if_fix(self):
41
- if self.config.handler_type==HandlerType.FIX:
56
+ if self.config.handler_type == HandlerType.FIX:
42
57
  return True
43
58
  return False
44
59
 
@@ -73,9 +88,9 @@ class FreeBenchmarkCheck(ABC):
73
88
  layer.handle(data_params)
74
89
  handler_params = make_handler_params(name, self.config, self.current_iter)
75
90
  handler = FuzzHandlerFactory.create(handler_params)
76
- perturbed_output = handler.handle(data_params)
91
+ perturbed_output = handler.handle(data_params)
77
92
  return perturbed_output, handler.get_unequal_rows()
78
-
93
+
79
94
  def backward(self, name, module, grad_output):
80
95
 
81
96
  if not self.config.fuzz_stage == Const.BACKWARD:
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from abc import ABC, abstractmethod
2
17
  from typing import Any
3
18
 
@@ -1,14 +1,29 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.pytorch.free_benchmark import FreeBenchmarkException
2
17
  from msprobe.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode
3
- from msprobe.pytorch.free_benchmark.perturbed_layers.npu.improve_precision import (
4
- ImprovePrecisionLayer,
5
- )
6
18
  from msprobe.pytorch.free_benchmark.perturbed_layers.npu.add_noise import AddNoiseLayer
7
19
  from msprobe.pytorch.free_benchmark.perturbed_layers.npu.bit_noise import BitNoiseLayer
8
- from msprobe.pytorch.free_benchmark.perturbed_layers.npu.no_change import NoChangeLayer
9
20
  from msprobe.pytorch.free_benchmark.perturbed_layers.npu.change_value import (
10
21
  ChangeValueLayer,
11
22
  )
23
+ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.improve_precision import (
24
+ ImprovePrecisionLayer,
25
+ )
26
+ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.no_change import NoChangeLayer
12
27
  from msprobe.pytorch.free_benchmark.perturbed_layers.run_cpu import CpuLayer
13
28
 
14
29
 
@@ -1,4 +1,20 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
17
+ from msprobe.core.common.utils import recursion_depth_decorator
2
18
  from msprobe.pytorch.free_benchmark import logger
3
19
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
4
20
  from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -11,6 +27,7 @@ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import
11
27
 
12
28
  class AddNoiseLayer(NpuBaseLayer):
13
29
 
30
+ @recursion_depth_decorator("FreeBenchmark: AddNoiseLayer.add_noise")
14
31
  def add_noise(self, tensor_obj):
15
32
  if isinstance(tensor_obj, torch.Tensor):
16
33
  self.perturbed_value = ThresholdConfig.PERTURBATION_VALUE_DICT.get(
@@ -84,7 +101,7 @@ class AddNoiseLayer(NpuBaseLayer):
84
101
  if max_val < abs_tol:
85
102
  logger.warning_on_rank_0(
86
103
  f"[msprobe] Free Benchmark: For {self.api_name}, "
87
- f"Maximun value is less than the minimun threshold. Cancel add noise."
104
+ f"Maximun value is less than the minimun threshold. Cancel add noise."
88
105
  )
89
106
  return False
90
107
  return True
@@ -1,4 +1,20 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
17
+ from msprobe.core.common.utils import recursion_depth_decorator
2
18
  from msprobe.pytorch.free_benchmark import logger
3
19
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
4
20
  from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -16,6 +32,7 @@ class BitNoiseLayer(NpuBaseLayer):
16
32
  self.bit_tail: int = 1
17
33
  self.bit_type = None
18
34
 
35
+ @recursion_depth_decorator("FreeBenchmark: BitNoiseLayer.add_bit_noise")
19
36
  def add_bit_noise(self, tensor_obj):
20
37
  """
21
38
  对输入添加噪声
@@ -64,14 +81,14 @@ class BitNoiseLayer(NpuBaseLayer):
64
81
  判断是否需要添加扰动, bit翻转
65
82
  """
66
83
  if not self.bit_type:
67
- logger.info_on_rank_0(
84
+ logger.warning_on_rank_0(
68
85
  f"[msprobe] Free Benchmark: For {self.api_name}, "
69
86
  f"dtype unsupported. Cancel perturbation."
70
87
  )
71
88
  return False
72
89
  if tensor_obj.numel() == 0:
73
90
  logger.warning_on_rank_0(
74
- f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0"
91
+ f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0."
75
92
  f" Cancel adding noise."
76
93
  )
77
94
  return False
@@ -87,9 +104,9 @@ class BitNoiseLayer(NpuBaseLayer):
87
104
  )
88
105
  max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
89
106
  if max_val < abs_tol:
90
- logger.info_on_rank_0(
107
+ logger.warning_on_rank_0(
91
108
  f"[msprobe] Free Benchmark: For {self.api_name}, "
92
- f"Maximun value is less than the minimun threshold. Cancel add noise."
109
+ f"Maximun value is less than the minimun threshold. Cancel add noise."
93
110
  )
94
111
  return False
95
112
  return True
@@ -1,4 +1,20 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
17
+ from msprobe.core.common.utils import recursion_depth_decorator
2
18
  from msprobe.pytorch.free_benchmark import logger
3
19
  from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
4
20
  from msprobe.pytorch.free_benchmark.common.params import DataParams
@@ -14,6 +30,7 @@ class ChangeValueLayer(NpuBaseLayer):
14
30
  self.head: int = 0
15
31
  self.tail: int = -1
16
32
 
33
+ @recursion_depth_decorator("FreeBenchmark: ChangeValueLayer.change_value")
17
34
  def change_value(self, tensor_obj):
18
35
  """
19
36
  交换张量首尾
@@ -54,10 +71,19 @@ class ChangeValueLayer(NpuBaseLayer):
54
71
  """
55
72
  判断是否需要添加扰动, 首尾值交换
56
73
  """
57
- if tensor_obj.size(0) < 2:
74
+ # 对于维度大于1的张量、要求1维至少大于1且0维和1维至少一个长度大于2
75
+ if tensor_obj.ndim > 1:
76
+ if tensor_obj.size(1) == 0 or (tensor_obj.size(1) < 2 and tensor_obj.size(0) < 2):
77
+ logger.info_on_rank_0(
78
+ f"[msprobe] Free Benchmark: For {self.api_name} with ndim {tensor_obj.ndim}, "
79
+ f"at least one of 0-dimension or 1-dimension greater than 1. Cancel change value."
80
+ )
81
+ return False
82
+ # 不支持维度等于0的张量、对于维度等于1的张量、要求0维长度大于2
83
+ elif tensor_obj.dim() == 0 or tensor_obj.size(0) < 2:
58
84
  logger.info_on_rank_0(
59
85
  f"[msprobe] Free Benchmark: For {self.api_name}, "
60
- f"size 0 must greater than 1. Cancel change value."
86
+ f"0-dimension must greater than 1. Cancel change value."
61
87
  )
62
88
  return False
63
89
  return True
@@ -1,5 +1,21 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
  from msprobe.core.common.const import Const
18
+ from msprobe.core.common.utils import recursion_depth_decorator
3
19
  from msprobe.pytorch.free_benchmark import logger
4
20
  from msprobe.pytorch.free_benchmark.common.constant import CommonField
5
21
  from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -11,6 +27,9 @@ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import
11
27
 
12
28
  class ImprovePrecisionLayer(NpuBaseLayer):
13
29
 
30
+ @recursion_depth_decorator(
31
+ "FreeBenchmark: ImprovePrecisionLayer.improve_tensor_precision"
32
+ )
14
33
  def improve_tensor_precision(self, tensor_obj):
15
34
  if (
16
35
  isinstance(tensor_obj, torch.Tensor)
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
  from msprobe.pytorch.free_benchmark import logger
3
18
  from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from abc import abstractmethod
2
17
  from typing import Any
3
18
 
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
  from msprobe.pytorch.free_benchmark import logger
3
18
  from msprobe.pytorch.free_benchmark.common.params import DataParams
@@ -1,10 +1,26 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import math
2
17
  from abc import ABC, abstractmethod
3
18
  from typing import Any, Optional, Tuple
4
- import numpy as np
5
19
 
20
+ import numpy as np
6
21
  import torch
7
22
  from msprobe.core.common.const import Const
23
+ from msprobe.core.common.exceptions import FreeBenchmarkException
8
24
  from msprobe.pytorch.free_benchmark import logger
9
25
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
10
26
  from msprobe.pytorch.free_benchmark.common.enums import (
@@ -35,7 +51,9 @@ class FuzzHandler(ABC):
35
51
  origin_ouput = origin_ouput.values
36
52
  perturbed_output = perturbed_output.values
37
53
  if hasattr(perturbed_output, "dtype"):
38
- abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype, FuzzThreshold.F32_THD)
54
+ abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
55
+ perturbed_output.dtype, FuzzThreshold.F32_THD
56
+ )
39
57
  else:
40
58
  abs_tol = FuzzThreshold.F32_THD
41
59
  return (
@@ -53,16 +71,22 @@ class FuzzHandler(ABC):
53
71
  :return origin_output_chunks: 切块后原始输出列表
54
72
  :return perturbed_output_chunks: 切块后扰动后输出列表
55
73
  """
56
- single_output_mem = origin_output.element_size() * origin_output.nelement() / Const.ONE_MB
74
+ single_output_mem = (
75
+ origin_output.element_size() * origin_output.nelement() / Const.ONE_MB
76
+ )
57
77
  if single_output_mem == 0 or origin_output.ndim == 0:
58
78
  return [origin_output], [perturbed_output]
59
79
  # 张量大小和批数之间的关系:chunks_exp=math.log(M,2)-4, chunks=2**chunks_exp (M为对比张量数据大小[Mb])
60
80
  chunks_exp = int(math.log(single_output_mem, 2)) - 4
61
- chunks = 2 ** chunks_exp
81
+ chunks = 2**chunks_exp
62
82
  chunks = max(chunks, 1)
63
83
  chunks = min(chunks, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK)
64
- origin_output_chunks = TorchC.tensor_split(TorchC.reshape(origin_output, (-1,)), chunks)
65
- perturbed_output_chunks = TorchC.tensor_split(TorchC.reshape(perturbed_output, (-1,)), chunks)
84
+ origin_output_chunks = TorchC.tensor_split(
85
+ TorchC.reshape(origin_output, (-1,)), chunks
86
+ )
87
+ perturbed_output_chunks = TorchC.tensor_split(
88
+ TorchC.reshape(perturbed_output, (-1,)), chunks
89
+ )
66
90
  return origin_output_chunks, perturbed_output_chunks
67
91
 
68
92
  @staticmethod
@@ -80,14 +104,24 @@ class FuzzHandler(ABC):
80
104
  pass
81
105
 
82
106
  def get_ratio_from_specific_norm(
83
- self, origin_output, perturbed_output, norm_type, abs_tol
107
+ self, origin_output, perturbed_output, norm_type, abs_tol
84
108
  ):
85
109
  if norm_type == NormType.ENDLESS_NORM:
86
110
  return self.calculate_error(origin_output, perturbed_output, abs_tol)
87
111
  return ThresholdConfig.COMP_CONSISTENT
88
112
 
89
113
  def calculate_error(self, origin_output, perturbed_output, abs_tol):
90
- origin_output_chunks, perturbed_output_chunks = self.tensor_split_for_error_calculate(origin_output, perturbed_output)
114
+ origin_output_chunks, perturbed_output_chunks = (
115
+ self.tensor_split_for_error_calculate(origin_output, perturbed_output)
116
+ )
117
+ if len(origin_output_chunks) != len(perturbed_output_chunks):
118
+ err_msg = (
119
+ f"For {self.params.api_name}, the number of compare tensor chunks is different: "
120
+ f"{len(origin_output_chunks)} != {len(perturbed_output_chunks)}. please check!"
121
+ )
122
+ raise FreeBenchmarkException(
123
+ FreeBenchmarkException.OutputIndexError, err_msg
124
+ )
91
125
  norm1 = -np.inf
92
126
  norm2 = -np.inf
93
127
  norm3 = np.inf
@@ -95,11 +129,25 @@ class FuzzHandler(ABC):
95
129
  if chunk_origin.nelement() == 0:
96
130
  break
97
131
  chunk_perturbed = perturbed_output_chunks[i]
98
- ratio_tensor1 = TorchC.where(TorchC.abs(chunk_perturbed) > abs_tol,
99
- TorchC.div(TorchC.clamp(chunk_origin, min=abs_tol), TorchC.clamp(chunk_perturbed, min=abs_tol)), 1)
100
- ratio_tensor2 = TorchC.where(TorchC.abs(chunk_origin) > abs_tol,
101
- TorchC.div(TorchC.clamp(chunk_perturbed, min=abs_tol), TorchC.clamp(chunk_origin, min=abs_tol)), 1)
102
- norm_values = TorchC.stack([TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)])
132
+ ratio_tensor1 = TorchC.where(
133
+ TorchC.abs(chunk_perturbed) > abs_tol,
134
+ TorchC.div(
135
+ TorchC.clamp(chunk_origin, min=abs_tol),
136
+ TorchC.clamp(chunk_perturbed, min=abs_tol),
137
+ ),
138
+ 1,
139
+ )
140
+ ratio_tensor2 = TorchC.where(
141
+ TorchC.abs(chunk_origin) > abs_tol,
142
+ TorchC.div(
143
+ TorchC.clamp(chunk_perturbed, min=abs_tol),
144
+ TorchC.clamp(chunk_origin, min=abs_tol),
145
+ ),
146
+ 1,
147
+ )
148
+ norm_values = TorchC.stack(
149
+ [TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)]
150
+ )
103
151
  max_ratio1, max_ratio2 = norm_values.tolist()
104
152
  norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1))
105
153
  norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
@@ -126,13 +174,13 @@ class FuzzHandler(ABC):
126
174
  if self.params.fuzz_stage == Const.BACKWARD:
127
175
  abs_tol = ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND
128
176
  else:
129
- abs_tol = abs_tol ** 0.5
177
+ abs_tol = abs_tol**0.5
130
178
  return self.get_ratio_from_specific_norm(
131
179
  origin_output, perturbed_output, norm_type, abs_tol
132
180
  )
133
181
 
134
182
  def npu_compare(
135
- self, origin_output, perturbed_output
183
+ self, origin_output, perturbed_output
136
184
  ) -> Tuple[bool, Optional[float]]:
137
185
 
138
186
  if isinstance(perturbed_output, int):
@@ -150,6 +198,7 @@ class FuzzHandler(ABC):
150
198
  f"[msprobe] Free Benchmark: For {self.params.api_name} "
151
199
  f"The compare for output type {type(perturbed_output)} is not supported"
152
200
  )
201
+ return True, 1
153
202
 
154
203
  threshold = self.get_threshold(Tools.get_first_tensor_dtype(origin_output))
155
204
  ratio = self.ratio_calculate(
@@ -189,7 +238,7 @@ class FuzzHandler(ABC):
189
238
  max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
190
239
  )
191
240
  data_params.is_consistent = (
192
- is_consistent and data_params.is_consistent
241
+ is_consistent and data_params.is_consistent
193
242
  )
194
243
  if not is_consistent and data_params.grad_unequal_flag:
195
244
  self.unequal_rows.append(
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Any
2
17
 
3
18
  from msprobe.pytorch.free_benchmark import logger