mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -0,0 +1,86 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from msprobe.core.common.const import Const
18
+ from msprobe.core.data_dump.scope import BaseScope
19
+ from msprobe.pytorch.common.log import logger
20
+ from msprobe.pytorch.hook_module.api_registry import api_register
21
+
22
+ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
23
+
24
+
25
class ModuleDumper:
    """Dump forward/backward data of one user-selected module on demand.

    ``start_module_dump`` calls ``api_register.api_originality()`` (presumably
    switching API-level hooks off) and registers module-level hooks on the target
    module. ``stop_module_dump`` calls ``api_register.api_modularity()`` and removes
    every hook handle collected in ``hook_handle_list``.
    """

    def __init__(self, service):
        # ``service`` must expose ``module_processor`` and ``build_hook`` (used by register_hook).
        self.service = service
        # Handles of every hook registered by register_hook, removed in stop_module_dump.
        self.hook_handle_list = []

    def start_module_dump(self, module, dump_name):
        """Switch API hooks to their original form and register dump hooks on ``module``.

        If registering the hooks fails, the partial setup is rolled back via
        ``stop_module_dump`` before the exception is re-raised, so API hooks are not
        left disabled and no half-registered module hooks remain.
        """
        api_register.api_originality()
        try:
            self.register_hook(module, dump_name)
        except Exception:
            self.stop_module_dump()
            raise

    def stop_module_dump(self):
        """Restore API hooks and remove every hook registered by ``register_hook``."""
        api_register.api_modularity()
        for hook_handle in self.hook_handle_list:
            if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
                hook_handle.remove()
        self.hook_handle_list.clear()

    def register_hook(self, module, dump_name):
        """Register data-dump hooks and START/STOP node hooks on ``module``.

        Args:
            module: the module to dump.
            dump_name: user-supplied name; the dump prefix becomes
                ``<Module_Type_Module>.<dump_name>.<ClassName>.`` (joined with ``Const.SEP``).
        """
        prefix_name = (
            BaseScope.Module_Type_Module + Const.SEP +
            dump_name + Const.SEP +
            module.__class__.__name__ + Const.SEP
        )
        module_processor = self.service.module_processor
        _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.service.build_hook(
            BaseScope.Module_Type_Module,
            prefix_name
        )

        # Modules using the deprecated register_backward_hook get no backward dump hooks.
        # Computed once: the check only reports True for non-full backward hooks, and
        # every backward hook registered below is a full backward hook.
        skip_backward = module_processor.has_register_backward_hook(module)
        if skip_backward:
            logger.warning(
                f"The {dump_name} module has registered deprecated register_backward_hook, "
                f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
            )

        # Data-dump hooks (the forward hook signature differs between torch versions).
        if torch_version_above_or_equal_2:
            forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
        else:
            if not skip_backward:
                backward_hook_handle = module.register_full_backward_hook(
                    module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
                )
                self.hook_handle_list.append(backward_hook_handle)
            forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
        self.hook_handle_list.append(forward_hook_handle)
        if not skip_backward:
            backward_hook_handle = module.register_full_backward_hook(backward_hook)
            self.hook_handle_list.append(backward_hook_handle)

        # START/STOP node hooks bracketing the forward pass (and, on torch >= 2, the backward pass).
        forward_pre_hook_handle = module.register_forward_pre_hook(
            module_processor.node_hook(prefix_name + Const.FORWARD, Const.START)
        )
        forward_hook_handle = module.register_forward_hook(
            module_processor.node_hook(prefix_name + Const.FORWARD, Const.STOP)
        )
        self.hook_handle_list.extend([forward_pre_hook_handle, forward_hook_handle])
        if torch_version_above_or_equal_2 and not skip_backward:
            backward_pre_hook_handle = module.register_full_backward_pre_hook(
                module_processor.node_hook(prefix_name + Const.BACKWARD, Const.START)
            )
            backward_hook_handle = module.register_full_backward_hook(
                module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
            )
            self.hook_handle_list.extend([backward_pre_hook_handle, backward_hook_handle])
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,12 +17,24 @@ from functools import wraps
17
17
 
18
18
  import torch
19
19
  from msprobe.core.common.const import Const
20
- from msprobe.core.data_dump.scope import ModuleRangeScope
20
+ from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
21
+ from msprobe.pytorch.common.log import logger
22
+ from torch.utils.checkpoint import checkpoint as origin_checkpoint
23
+ from torch.utils.checkpoint import set_checkpoint_early_stop
21
24
  from torch.utils.hooks import BackwardHook
22
25
 
23
26
  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
24
27
 
25
28
 
29
+ def checkpoint_without_early_stop(*args, **kwargs):
30
+ with set_checkpoint_early_stop(False):
31
+ return origin_checkpoint(*args, **kwargs)
32
+
33
+
34
def replace_checkpoint():
    """Install ``checkpoint_without_early_stop`` as ``torch.utils.checkpoint.checkpoint``.

    This rebinds the module attribute only: code that looks the function up through
    ``torch.utils.checkpoint`` afterwards gets the wrapper, while names bound earlier
    (e.g. via ``from torch.utils.checkpoint import checkpoint``) keep the original.
    """
    checkpoint_module = torch.utils.checkpoint
    setattr(checkpoint_module, "checkpoint", checkpoint_without_early_stop)
36
+
37
+
26
38
  class ModuleProcesser:
27
39
  module_count = {}
28
40
  module_stack = []
@@ -30,13 +42,11 @@ class ModuleProcesser:
30
42
  module_node = {}
31
43
 
32
44
def __init__(self, scope):
    """Set up module-scope tracking and apply the torch patches needed for module dumping.

    Args:
        scope: stored as ``self.scope`` only when it is a ``ModuleRangeScope`` or
            ``MixRangeScope``; any other value results in ``self.scope = None``.
    """
    self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
    # The BackwardHook patches rewrite class attributes process-wide. Apply them only
    # once: repeating them on every instantiation would stack another layer of
    # clone/filter wrappers around the already-wrapped methods each time.
    if not getattr(ModuleProcesser, "_backward_hook_patched", False):
        BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
        BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
        BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook)
        ModuleProcesser._backward_hook_patched = True
    # Route torch.utils.checkpoint.checkpoint through the wrapper that disables early stop.
    replace_checkpoint()
40
50
 
41
51
  @staticmethod
42
52
  def filter_tensor_and_tuple(func):
@@ -66,7 +76,7 @@ class ModuleProcesser:
66
76
  return ModuleProcesser.clone_if_tensor(result)
67
77
 
68
78
  return clone_return_value_func
69
-
79
+
70
80
  @staticmethod
71
81
  def clone_if_tensor(result):
72
82
  if isinstance(result, torch.Tensor):
@@ -88,6 +98,22 @@ class ModuleProcesser:
88
98
  ModuleProcesser.module_count[module_name] += 1
89
99
  return ModuleProcesser.module_count[module_name]
90
100
 
101
+ @staticmethod
102
+ def has_register_backward_hook(module):
103
+ return hasattr(module, '_backward_hooks') and \
104
+ len(module._backward_hooks) > 0 and \
105
+ module._is_full_backward_hook is False
106
+
107
+ @staticmethod
108
+ def get_modules_and_names(models):
109
+ modules_and_names_with_index = {}
110
+ if isinstance(models, (list, tuple)):
111
+ for index, model in enumerate(models):
112
+ modules_and_names_with_index[str(index)] = model.named_modules()
113
+ else:
114
+ modules_and_names_with_index["-1"] = models.named_modules()
115
+ return modules_and_names_with_index
116
+
91
117
  @classmethod
92
118
  def reset_module_stats(cls):
93
119
  cls.module_count = {}
@@ -95,6 +121,42 @@ class ModuleProcesser:
95
121
  cls.api_parent_node = ""
96
122
  cls.module_node = {}
97
123
 
124
def register_module_hook(self, models, build_hook):
    """Register data-dump hooks and START/STOP node hooks on every submodule of ``models``.

    Args:
        models: one model or a list/tuple of models exposing ``named_modules()``.
            For a list/tuple, each model's index is inserted into the dump prefix.
            The root model itself is not hooked, only its submodules.
        build_hook: called as ``build_hook(BaseScope.Module_Type_Module, prefix_name)``
            and expected to return a 4-tuple ``(pre_forward_hook, forward_hook,
            backward_hook, forward_hook_torch_version_below_2)``; the first element
            is not used here.
    """
    logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.")
    modules_and_names_with_index = self.get_modules_and_names(models)
    for index, modules_and_names in modules_and_names_with_index.items():
        # "-1" is the sentinel get_modules_and_names uses for a single (non-list) model.
        model = models if index == "-1" else models[int(index)]
        module_index = (index + Const.SEP) if index != "-1" else ""
        for name, module in modules_and_names:
            if module is model:
                continue
            prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
                           name + Const.SEP + module.__class__.__name__ + Const.SEP)
            _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook(
                BaseScope.Module_Type_Module,
                prefix_name
            )

            # Computed once: the check only reports True for non-full backward hooks,
            # and every backward hook registered below is a full backward hook.
            skip_backward = self.has_register_backward_hook(module)
            if skip_backward:
                logger.warning(
                    f"The {prefix_name[:-1]} has registered deprecated register_backward_hook, "
                    f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
                )

            # Data-dump hooks (the forward hook signature differs between torch versions).
            if torch_version_above_or_equal_2:
                module.register_forward_hook(forward_hook, with_kwargs=True)
            else:
                if not skip_backward:
                    module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
                module.register_forward_hook(forward_hook_torch_version_below_2)
            if not skip_backward:
                module.register_full_backward_hook(backward_hook)

            # START/STOP node hooks bracketing forward (and, on torch >= 2, backward).
            module.register_forward_pre_hook(self.node_hook(prefix_name + Const.FORWARD, Const.START))
            module.register_forward_hook(self.node_hook(prefix_name + Const.FORWARD, Const.STOP))
            if torch_version_above_or_equal_2 and not skip_backward:
                module.register_full_backward_pre_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.START))
                module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
159
+
98
160
  def node_hook(self, name_prefix, start_or_stop, **kwargs):
99
161
 
100
162
  def pre_hook(module, input, output=None):
@@ -103,7 +165,10 @@ class ModuleProcesser:
103
165
  except IndexError as e:
104
166
  index = None
105
167
  pass
106
- module.mindstudio_reserved_name = full_name = name_prefix + Const.SEP + str(index)
168
+ full_name = name_prefix + Const.SEP + str(index)
169
+ if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
170
+ module.mindstudio_reserved_name = []
171
+ module.mindstudio_reserved_name.append(full_name)
107
172
  if self.module_stack:
108
173
  ModuleProcesser.module_node[full_name] = self.module_stack[-1]
109
174
  else:
@@ -122,8 +187,11 @@ class ModuleProcesser:
122
187
  ModuleProcesser.api_parent_node = self.module_stack[-1]
123
188
  else:
124
189
  ModuleProcesser.api_parent_node = None
190
+ if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
191
+ raise RuntimeError(f"module reserve name is None when pop")
192
+ current_name = module.mindstudio_reserved_name.pop()
125
193
  if self.scope:
126
- self.scope.end_module(module.mindstudio_reserved_name)
194
+ self.scope.end_module(current_name)
127
195
 
128
196
  def backward_hook(module, input, output=None):
129
197
  try:
@@ -131,7 +199,10 @@ class ModuleProcesser:
131
199
  except IndexError as e:
132
200
  index = None
133
201
  pass
134
- module.mindstudio_reserved_name = full_name = name_prefix + Const.SEP + str(index)
202
+ full_name = name_prefix + Const.SEP + str(index)
203
+ if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
204
+ module.mindstudio_reserved_name = []
205
+ module.mindstudio_reserved_name.append(full_name)
135
206
  forward_full_name = full_name.replace(Const.BACKWARD, Const.FORWARD)
136
207
  ModuleProcesser.module_node[full_name] = ModuleProcesser.module_node[forward_full_name].replace(
137
208
  Const.FORWARD, Const.BACKWARD) if ModuleProcesser.module_node[forward_full_name] else None
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Dict
2
17
 
3
18
  import numpy as np
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from collections import defaultdict
2
17
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
3
18
 
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.core.common.const import Const
2
17
 
3
18
 
@@ -17,6 +17,7 @@ from dataclasses import dataclass
17
17
  from typing import Any, Callable, Dict, List, Optional, Tuple
18
18
 
19
19
  import torch
20
+ from msprobe.core.common.exceptions import FreeBenchmarkException
20
21
  from msprobe.pytorch.free_benchmark import logger
21
22
  from msprobe.pytorch.free_benchmark.common.enums import (
22
23
  DeviceType,
@@ -38,7 +39,6 @@ class DataParams:
38
39
  origin_func: Optional[Callable] = None
39
40
  api_type: Optional[str] = None
40
41
  fuzz_stage: Optional[str] = None
41
- grad_unequal_flag: Optional[bool] = True
42
42
 
43
43
 
44
44
  @dataclass
@@ -126,9 +126,17 @@ def make_unequal_row(
126
126
  )
127
127
  if isinstance(ratio, float):
128
128
  row.max_rel = ratio - 1
129
+ if isinstance(ratio, str):
130
+ row.max_rel = ratio
129
131
  origin_tensor = data_params.original_result
130
132
  perturbed_tensor = data_params.perturbed_result
131
- if index:
133
+ if index is not None:
134
+ if index >= len(origin_tensor) or index >= len(perturbed_tensor):
135
+ err_msg = f"When generating unequal results, index {index} of output is out of bounds. please check!"
136
+ raise FreeBenchmarkException(
137
+ FreeBenchmarkException.OutputIndexError,
138
+ error_info=err_msg,
139
+ )
132
140
  origin_tensor = origin_tensor[index]
133
141
  perturbed_tensor = perturbed_tensor[index]
134
142
  row.output_index = index
@@ -13,7 +13,10 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+
16
17
  import torch
18
+ from msprobe.core.common.exceptions import FreeBenchmarkException
19
+ from msprobe.core.common.utils import recursion_depth_decorator
17
20
  from msprobe.pytorch.free_benchmark.common.enums import DeviceType
18
21
 
19
22
 
@@ -51,6 +54,7 @@ class Tools:
51
54
  return api_name.rsplit(".", 2)[0]
52
55
 
53
56
  @staticmethod
57
+ @recursion_depth_decorator("FreeBenchmark: Tools.convert_device_and_dtype")
54
58
  def convert_device_and_dtype(
55
59
  tensor_seq, device: str = DeviceType.CPU, change_dtype: bool = False
56
60
  ):
@@ -73,23 +77,41 @@ class Tools:
73
77
  return tensor_seq
74
78
 
75
79
  @staticmethod
80
+ @recursion_depth_decorator("FreeBenchmark: Tools.convert_fuzz_output_to_origin")
76
81
  def convert_fuzz_output_to_origin(origin, perturbed):
77
- if isinstance(origin, torch.Tensor):
82
+ if isinstance(origin, torch.Tensor) and isinstance(perturbed, torch.Tensor):
78
83
  origin.data = perturbed.to(origin.dtype).to(origin.device)
79
84
  return origin
80
- if isinstance(origin, dict):
85
+ if isinstance(origin, dict) and isinstance(perturbed, dict):
81
86
  output = dict()
82
87
  for key, value in origin.items():
88
+ if key not in perturbed:
89
+ err_msg = f"'{key}' not in perturbed output."
90
+ raise FreeBenchmarkException(
91
+ FreeBenchmarkException.InvalidPerturbedOutput,
92
+ error_info=err_msg,
93
+ )
83
94
  output[key] = Tools.convert_fuzz_output_to_origin(value, perturbed[key])
84
95
  return output
85
- if isinstance(origin, (tuple, list)):
96
+ if isinstance(origin, (tuple, list)) and isinstance(perturbed, (tuple, list)):
86
97
  result = list()
98
+ if len(perturbed) != len(origin):
99
+ err_msg = (
100
+ f"length of perturbed output ({len(perturbed)}) is different "
101
+ f"from the length of original output ({len(origin)})."
102
+ )
103
+ raise FreeBenchmarkException(
104
+ FreeBenchmarkException.InvalidPerturbedOutput, error_info=err_msg
105
+ )
87
106
  for index_, value in enumerate(origin):
88
107
  result.append(
89
108
  Tools.convert_fuzz_output_to_origin(value, perturbed[index_])
90
109
  )
91
110
  return type(origin)(result)
92
- return origin
111
+ err_msg = f"conversion of two outputs with types ({type(origin)}, {type(perturbed)}) is not supported."
112
+ raise FreeBenchmarkException(
113
+ FreeBenchmarkException.UnsupportedType, error_info=err_msg
114
+ )
93
115
 
94
116
 
95
117
  class TorchC:
@@ -102,6 +124,7 @@ class TorchC:
102
124
  abs = torch._C._VariableFunctionsClass.abs
103
125
  where = torch._C._VariableFunctionsClass.where
104
126
  div = torch._C._VariableFunctionsClass.div
127
+ mul = torch._C._VariableFunctionsClass.mul
105
128
  max = torch._C._VariableFunctionsClass.max
106
129
  min = torch._C._VariableFunctionsClass.min
107
130
  gt = torch._C._VariableFunctionsClass.gt
@@ -116,3 +139,5 @@ class TorchC:
116
139
  tensor_split = torch._C._VariableFunctionsClass.tensor_split
117
140
  stack = torch._C._VariableFunctionsClass.stack
118
141
  reshape = torch._C._VariableFunctionsClass.reshape
142
+ nan_to_num = torch._C._VariableFunctionsClass.nan_to_num
143
+ aminmax = torch._C._VariableFunctionsClass.aminmax
@@ -82,13 +82,11 @@ class GradSaver:
82
82
  data_params = DataParams()
83
83
  data_params.original_result = origin_grad
84
84
  data_params.perturbed_result = perturbed_grad
85
- data_params.grad_unequal_flag = False
86
85
  data_params.valid_input_index = index
87
86
  try:
88
87
  handler.handle(data_params)
89
88
  if not data_params.is_consistent:
90
89
  self.is_compare = False
91
- data_params.grad_unequal_flag = True
92
90
  data_params.is_consistent = True
93
91
  data_params.perturbed_result = self.perturbed_grad_input
94
92
  data_params.original_result = self.origin_grad_input
@@ -102,8 +100,13 @@ class GradSaver:
102
100
  def check_grad_input(self, origin_grad, new_grad_index):
103
101
  if self.perturbed_grad_input is None:
104
102
  raise FreeBenchmarkException(
105
- FreeBenchmarkException.InvalidGrad,
106
- f"grad not exists : {self.api_name}.",
103
+ FreeBenchmarkException.InvalidPerturbedOutput,
104
+ f"perturbed grad not exists for {self.api_name}.",
105
+ )
106
+ if len(self.perturbed_grad_input) <= new_grad_index:
107
+ raise FreeBenchmarkException(
108
+ FreeBenchmarkException.InvalidPerturbedOutput,
109
+ f"perturbed grad index {new_grad_index} is out of bounds for {self.api_name}.",
107
110
  )
108
111
  with torch.no_grad():
109
112
  perturbed_grad = self.perturbed_grad_input[new_grad_index].to(
@@ -111,7 +114,7 @@ class GradSaver:
111
114
  )
112
115
  if origin_grad.shape != perturbed_grad.shape:
113
116
  raise FreeBenchmarkException(
114
- FreeBenchmarkException.InvalidGrad,
117
+ FreeBenchmarkException.InvalidPerturbedOutput,
115
118
  f"grad shapes are inconsistent. api:{self.handler_params.api_name}."
116
119
  f"origin:{origin_grad.shape}, perturbation: {perturbed_grad.shape}",
117
120
  )
@@ -164,6 +167,18 @@ class GradSaver:
164
167
  index_ = 0
165
168
  for object_ in inner_args:
166
169
  if object_ is CommonField.HOLD_PLACE:
170
+ if index_ >= len(inputs):
171
+ err_msg = (
172
+ f"[msprobe] Free benchmark: When getting input from vjp, "
173
+ f" the input index ({index_}) is out of bounds ({len(inputs)})."
174
+ )
175
+ logger.error_log_with_exp(
176
+ err_msg,
177
+ FreeBenchmarkException(
178
+ FreeBenchmarkException.InvalidGrad,
179
+ error_info=err_msg,
180
+ ),
181
+ )
167
182
  _real_input.append(inputs[index_])
168
183
  index_ += 1
169
184
  else:
@@ -16,6 +16,7 @@
16
16
  import math
17
17
 
18
18
  import torch
19
+ from msprobe.core.common.utils import recursion_depth_decorator
19
20
  from msprobe.pytorch.free_benchmark import logger
20
21
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
21
22
  from msprobe.pytorch.free_benchmark.common.utils import TorchC
@@ -67,6 +68,7 @@ class SingleCompare:
67
68
  return False
68
69
  return True
69
70
 
71
+ @recursion_depth_decorator("FreeBenchmark: SingleCompare.compare_seq")
70
72
  def compare_seq(self, actual, golden):
71
73
  if isinstance(golden, torch.Tensor):
72
74
  return self.compare_tensor_seq(actual, golden)
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import torch
17
+ from msprobe.core.common.utils import recursion_depth_decorator
17
18
  from msprobe.pytorch.free_benchmark import logger
18
19
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
19
20
  from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -26,6 +27,7 @@ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import
26
27
 
27
28
  class AddNoiseLayer(NpuBaseLayer):
28
29
 
30
+ @recursion_depth_decorator("FreeBenchmark: AddNoiseLayer.add_noise")
29
31
  def add_noise(self, tensor_obj):
30
32
  if isinstance(tensor_obj, torch.Tensor):
31
33
  self.perturbed_value = ThresholdConfig.PERTURBATION_VALUE_DICT.get(
@@ -99,7 +101,7 @@ class AddNoiseLayer(NpuBaseLayer):
99
101
  if max_val < abs_tol:
100
102
  logger.warning_on_rank_0(
101
103
  f"[msprobe] Free Benchmark: For {self.api_name}, "
102
- f"Maximun value is less than the minimun threshold. Cancel add noise."
104
+ f"Maximun value is less than the minimun threshold. Cancel add noise."
103
105
  )
104
106
  return False
105
107
  return True
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import torch
17
+ from msprobe.core.common.utils import recursion_depth_decorator
17
18
  from msprobe.pytorch.free_benchmark import logger
18
19
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
19
20
  from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -31,6 +32,7 @@ class BitNoiseLayer(NpuBaseLayer):
31
32
  self.bit_tail: int = 1
32
33
  self.bit_type = None
33
34
 
35
+ @recursion_depth_decorator("FreeBenchmark: BitNoiseLayer.add_bit_noise")
34
36
  def add_bit_noise(self, tensor_obj):
35
37
  """
36
38
  对输入添加噪声
@@ -79,14 +81,14 @@ class BitNoiseLayer(NpuBaseLayer):
79
81
  判断是否需要添加扰动, bit翻转
80
82
  """
81
83
  if not self.bit_type:
82
- logger.info_on_rank_0(
84
+ logger.warning_on_rank_0(
83
85
  f"[msprobe] Free Benchmark: For {self.api_name}, "
84
86
  f"dtype unsupported. Cancel perturbation."
85
87
  )
86
88
  return False
87
89
  if tensor_obj.numel() == 0:
88
90
  logger.warning_on_rank_0(
89
- f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0"
91
+ f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0."
90
92
  f" Cancel adding noise."
91
93
  )
92
94
  return False
@@ -102,9 +104,9 @@ class BitNoiseLayer(NpuBaseLayer):
102
104
  )
103
105
  max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
104
106
  if max_val < abs_tol:
105
- logger.info_on_rank_0(
107
+ logger.warning_on_rank_0(
106
108
  f"[msprobe] Free Benchmark: For {self.api_name}, "
107
- f"Maximun value is less than the minimun threshold. Cancel add noise."
109
+ f"Maximun value is less than the minimun threshold. Cancel add noise."
108
110
  )
109
111
  return False
110
112
  return True
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import torch
17
+ from msprobe.core.common.utils import recursion_depth_decorator
17
18
  from msprobe.pytorch.free_benchmark import logger
18
19
  from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
19
20
  from msprobe.pytorch.free_benchmark.common.params import DataParams
@@ -29,6 +30,7 @@ class ChangeValueLayer(NpuBaseLayer):
29
30
  self.head: int = 0
30
31
  self.tail: int = -1
31
32
 
33
+ @recursion_depth_decorator("FreeBenchmark: ChangeValueLayer.change_value")
32
34
  def change_value(self, tensor_obj):
33
35
  """
34
36
  交换张量首尾
@@ -15,6 +15,7 @@
15
15
 
16
16
  import torch
17
17
  from msprobe.core.common.const import Const
18
+ from msprobe.core.common.utils import recursion_depth_decorator
18
19
  from msprobe.pytorch.free_benchmark import logger
19
20
  from msprobe.pytorch.free_benchmark.common.constant import CommonField
20
21
  from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -26,6 +27,9 @@ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import
26
27
 
27
28
  class ImprovePrecisionLayer(NpuBaseLayer):
28
29
 
30
+ @recursion_depth_decorator(
31
+ "FreeBenchmark: ImprovePrecisionLayer.improve_tensor_precision"
32
+ )
29
33
  def improve_tensor_precision(self, tensor_obj):
30
34
  if (
31
35
  isinstance(tensor_obj, torch.Tensor)