mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,9 +1,25 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Any
2
17
 
18
+ from msprobe.core.common.exceptions import FreeBenchmarkException
19
+ from msprobe.pytorch.free_benchmark import logger
3
20
  from msprobe.pytorch.free_benchmark.common.params import DataParams
4
21
  from msprobe.pytorch.free_benchmark.common.utils import Tools
5
22
  from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
6
- from msprobe.pytorch.free_benchmark import logger
7
23
 
8
24
 
9
25
  class FixHandler(FuzzHandler):
@@ -16,9 +32,9 @@ class FixHandler(FuzzHandler):
16
32
  return Tools.convert_fuzz_output_to_origin(
17
33
  data_params.original_result, data_params.perturbed_result
18
34
  )
19
- except Exception as e:
20
- logger.warning_on_rank_0(
35
+ except FreeBenchmarkException as e:
36
+ logger.warning(
21
37
  f"[msprobe] Free Benchmark: For {self.params.api_name} "
22
- f"Fix output failed. "
38
+ f"Fix output failed because of: \n{e}"
23
39
  )
24
- return data_params.original_result
40
+ return data_params.original_result
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.pytorch.free_benchmark import FreeBenchmarkException
2
17
  from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig
3
18
  from msprobe.pytorch.free_benchmark.common.enums import HandlerType
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import math
2
17
  from typing import Any
3
18
 
@@ -118,8 +133,10 @@ class PreheatHandler(FuzzHandler):
118
133
  """
119
134
  # 每一步样本数
120
135
  total_count = preheat_counter.get_one_step_used_api(self.pure_name)
121
- sample_count_per_step = self._get_sample_count_per_step()
122
136
  need_sample_set = set()
137
+ if total_count == 0:
138
+ return need_sample_set
139
+ sample_count_per_step = self._get_sample_count_per_step()
123
140
  prehead_step = self.params.preheat_config.get("preheat_step")
124
141
  for i in range(1, sample_count_per_step + 1):
125
142
  count = (prehead_step * (i - 1) + self.params.step) % total_count
@@ -136,9 +153,7 @@ class PreheatHandler(FuzzHandler):
136
153
 
137
154
  def _adjust_threshold_for_dtype(self, dtype_str, compare_result):
138
155
  con_ratio = [ratio for ratio, is_consistent in compare_result if is_consistent]
139
- incon_ratio = [
140
- ratio for ratio, is_consistent in compare_result if not is_consistent
141
- ]
156
+ incon_ratio = [ratio for ratio, is_consistent in compare_result if not is_consistent]
142
157
  old_thd = preheat_counter.get_api_thd(self.pure_name, dtype_str)
143
158
  new_thd = old_thd
144
159
  # 正例负例都存在
@@ -1,4 +1,18 @@
1
- from msprobe.pytorch.common.utils import logger
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  from msprobe.pytorch.bench_functions.apply_adam_w import npu_apply_adam_w
3
17
  from msprobe.pytorch.bench_functions.confusion_transpose import npu_confusion_transpose, \
4
18
  npu_confusion_transpose_backward
@@ -12,7 +26,8 @@ from msprobe.pytorch.bench_functions.rms_norm import npu_rms_norm, npu_rms_norm_
12
26
  from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotary_mul_backward
13
27
  from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \
14
28
  npu_scaled_masked_softmax_backward
15
- from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward, swish_grad, swish
29
+ from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward
30
+ from msprobe.pytorch.common.utils import logger
16
31
 
17
32
 
18
33
  class Register(dict):
@@ -0,0 +1,84 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from msprobe.core.common.const import Const
19
+ from msprobe.core.common.exceptions import MsprobeException
20
+ from msprobe.core.data_dump.scope import BaseScope
21
+ from msprobe.pytorch.common.log import logger
22
+ from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger
23
+ from msprobe.pytorch.hook_module.api_registry import api_register
24
+ from msprobe.pytorch.service import torch_version_above_or_equal_2
25
+
26
+ hook_handle_list = []
27
+
28
+
29
+ def module_dump(module, dump_name):
30
+ if not isinstance(module, nn.Module):
31
+ logger.error("The parameter module in module_dump must be a Module subclass.")
32
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
33
+ if not isinstance(dump_name, str):
34
+ logger.error("The parameter dump_name in module_dump must be a str type.")
35
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
36
+
37
+ api_register.api_originality()
38
+ register_hook(module, dump_name)
39
+
40
+
41
+ def module_dump_end():
42
+ api_register.api_modularity()
43
+ remove_hook()
44
+ hook_handle_list.clear()
45
+
46
+
47
+ def register_hook(module, dump_name):
48
+ prefix = BaseScope.Module_Type_Module + Const.SEP + dump_name + Const.SEP + module.__class__.__name__ + Const.SEP
49
+
50
+ pdg = PrecisionDebugger()
51
+ _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = \
52
+ pdg.service.build_hook(BaseScope.Module_Type_Module, prefix)
53
+
54
+ if torch_version_above_or_equal_2:
55
+ forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
56
+ hook_handle_list.append(forward_hook_handle)
57
+ else:
58
+ pdg.service.check_register_full_backward_hook(module)
59
+ full_backward_hook_handle = module.register_full_backward_hook(
60
+ pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
61
+ forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
62
+ hook_handle_list.extend([full_backward_hook_handle, forward_hook_handle])
63
+ pdg.service.check_register_full_backward_hook(module)
64
+ full_backward_hook_handle = module.register_full_backward_hook(backward_hook)
65
+
66
+ forward_pre_hook_handle = module.register_forward_pre_hook(
67
+ pdg.service.module_processor.node_hook(prefix + Const.FORWARD, Const.START))
68
+ forward_hook_handle = module.register_forward_hook(
69
+ pdg.service.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
70
+ hook_handle_list.extend([full_backward_hook_handle, forward_pre_hook_handle, forward_hook_handle])
71
+
72
+ if torch_version_above_or_equal_2:
73
+ backward_pre_hook_handle = module.register_full_backward_pre_hook(
74
+ pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
75
+ pdg.service.check_register_full_backward_hook(module)
76
+ full_backward_hook_handle = module.register_full_backward_hook(
77
+ pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
78
+ hook_handle_list.extend([backward_pre_hook_handle, full_backward_hook_handle])
79
+
80
+
81
+ def remove_hook():
82
+ for hook_handle in hook_handle_list:
83
+ if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
84
+ hook_handle.remove()
@@ -1,15 +1,31 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
  from collections import defaultdict
3
18
 
4
19
  import torch
5
- if int(torch.__version__.split('.')[0]) >= 2:
6
- from torch.optim.optimizer import register_optimizer_step_pre_hook
7
- from msprobe.pytorch.grad_probe.grad_stat_csv import GradStatCsv
8
- from msprobe.core.grad_probe.utils import check_numeral_list_ascend, data_in_list_target
20
+ from msprobe.core.common.file_utils import remove_path, save_npy, write_csv, create_directory
9
21
  from msprobe.core.grad_probe.constant import level_adp
22
+ from msprobe.core.grad_probe.utils import check_numeral_list_ascend, data_in_list_target
10
23
  from msprobe.pytorch.common.log import logger
11
- from msprobe.core.common.file_utils import remove_path, save_npy, write_csv, create_directory
12
24
  from msprobe.pytorch.common.utils import get_rank_id, print_rank_0
25
+ from msprobe.pytorch.grad_probe.grad_stat_csv import GradStatCsv
26
+
27
+ if int(torch.__version__.split('.')[0]) >= 2:
28
+ from torch.optim.optimizer import register_optimizer_step_pre_hook
13
29
 
14
30
 
15
31
  class GradientMonitor:
@@ -75,7 +91,7 @@ class GradientMonitor:
75
91
  output_lines.append(grad_info)
76
92
  if self._level_adp["have_grad_direction"]:
77
93
  GradientMonitor.save_grad_direction(param_name, grad,
78
- f'{self._output_path}/rank{self._rank}/step{self._step}')
94
+ f'{self._output_path}/rank{self._rank}/step{self._step}')
79
95
  output_dirpath = os.path.join(self._output_path, f"rank{getattr(self, '_rank')}")
80
96
  if not os.path.isdir(output_dirpath):
81
97
  create_directory(output_dirpath)
@@ -87,5 +103,6 @@ class GradientMonitor:
87
103
  output_lines.insert(0, header_result)
88
104
  write_csv(output_lines, output_path)
89
105
  logger.info(f"write grad data to {output_path}")
106
+
90
107
  if int(torch.__version__.split('.')[0]) >= 2:
91
108
  register_optimizer_step_pre_hook(optimizer_pre_step_hook)
@@ -1,11 +1,27 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from abc import ABC, abstractmethod
2
17
  from collections import namedtuple
3
18
  import hashlib
19
+ from functools import wraps
4
20
  import torch
5
21
  from msprobe.core.grad_probe.constant import GradConst
6
22
 
7
- CSV_header_input = namedtuple("CSV_header_input", ["bounds"])
8
- CSV_content_input = namedtuple("CSV_content_input", ["grad", "bounds"])
23
+ CsvHeaderInput = namedtuple("CsvHeaderInput", ["bounds"])
24
+ CsvContentInput = namedtuple("CsvContentInput", ["grad", "bounds"])
9
25
 
10
26
 
11
27
  class GradStatCsv:
@@ -15,7 +31,7 @@ class GradStatCsv:
15
31
  def generate_csv_header(level, bounds):
16
32
  header = ["param_name"]
17
33
  for key in level["header"]:
18
- csv_header_input = CSV_header_input(bounds=bounds)
34
+ csv_header_input = CsvHeaderInput(bounds=bounds)
19
35
  header.extend(GradStatCsv.csv[key].generate_csv_header(csv_header_input))
20
36
  return header
21
37
 
@@ -23,7 +39,7 @@ class GradStatCsv:
23
39
  def generate_csv_line(param_name, level, grad, bounds):
24
40
  line = [param_name]
25
41
  for key in level["header"]:
26
- csv_content_input = CSV_content_input(grad=grad, bounds=bounds)
42
+ csv_content_input = CsvContentInput(grad=grad, bounds=bounds)
27
43
  line.extend(GradStatCsv.csv[key].generate_csv_content(csv_content_input))
28
44
  return line
29
45
 
@@ -37,20 +53,24 @@ def register_csv_item(key, cls=None):
37
53
 
38
54
 
39
55
  class CsvItem(ABC):
56
+ @staticmethod
40
57
  @abstractmethod
41
58
  def generate_csv_header(csv_header_input):
42
59
  pass
43
60
 
61
+ @staticmethod
44
62
  @abstractmethod
45
63
  def generate_csv_content(csv_content_input):
46
64
  pass
47
65
 
48
66
 
49
67
  @register_csv_item(GradConst.MD5)
50
- class CSV_md5(CsvItem):
68
+ class CsvMd5(CsvItem):
69
+ @staticmethod
51
70
  def generate_csv_header(csv_header_input):
52
71
  return ["MD5"]
53
72
 
73
+ @staticmethod
54
74
  def generate_csv_content(csv_content_input):
55
75
  grad = csv_content_input.grad
56
76
  tensor_bytes = grad.cpu().detach().float().numpy().tobytes()
@@ -59,7 +79,8 @@ class CSV_md5(CsvItem):
59
79
 
60
80
 
61
81
  @register_csv_item(GradConst.DISTRIBUTION)
62
- class CSV_distribution(CsvItem):
82
+ class CsvDistribution(CsvItem):
83
+ @staticmethod
63
84
  def generate_csv_header(csv_header_input):
64
85
  bounds = csv_header_input.bounds
65
86
  intervals = []
@@ -73,6 +94,7 @@ class CSV_distribution(CsvItem):
73
94
 
74
95
  return intervals
75
96
 
97
+ @staticmethod
76
98
  def generate_csv_content(csv_content_input):
77
99
  grad = csv_content_input.grad
78
100
  bounds = csv_content_input.bounds
@@ -90,40 +112,48 @@ class CSV_distribution(CsvItem):
90
112
 
91
113
 
92
114
  @register_csv_item(GradConst.MAX)
93
- class CSV_max(CsvItem):
115
+ class CsvMax(CsvItem):
116
+ @staticmethod
94
117
  def generate_csv_header(csv_header_input):
95
118
  return ["max"]
96
119
 
120
+ @staticmethod
97
121
  def generate_csv_content(csv_content_input):
98
122
  grad = csv_content_input.grad
99
123
  return [torch.max(grad).cpu().detach().float().numpy().tolist()]
100
124
 
101
125
 
102
126
  @register_csv_item(GradConst.MIN)
103
- class CSV_max(CsvItem):
127
+ class CsvMin(CsvItem):
128
+ @staticmethod
104
129
  def generate_csv_header(csv_header_input):
105
130
  return ["min"]
106
131
 
132
+ @staticmethod
107
133
  def generate_csv_content(csv_content_input):
108
134
  grad = csv_content_input.grad
109
135
  return [torch.min(grad).cpu().detach().float().numpy().tolist()]
110
136
 
111
137
 
112
138
  @register_csv_item(GradConst.NORM)
113
- class CSV_max(CsvItem):
139
+ class CsvNorm(CsvItem):
140
+ @staticmethod
114
141
  def generate_csv_header(csv_header_input):
115
142
  return ["norm"]
116
143
 
144
+ @staticmethod
117
145
  def generate_csv_content(csv_content_input):
118
146
  grad = csv_content_input.grad
119
147
  return [torch.norm(grad).cpu().detach().float().numpy().tolist()]
120
148
 
121
149
 
122
150
  @register_csv_item(GradConst.SHAPE)
123
- class CSV_shape(CsvItem):
151
+ class CsvShape(CsvItem):
152
+ @staticmethod
124
153
  def generate_csv_header(csv_header_input):
125
154
  return ["shape"]
126
155
 
156
+ @staticmethod
127
157
  def generate_csv_content(csv_content_input):
128
158
  grad = csv_content_input.grad
129
159
  return [list(grad.shape)]
@@ -1 +1,16 @@
1
- from .wrap_functional import remove_dropout
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from .wrap_functional import remove_dropout
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,7 +12,6 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
17
15
 
18
16
  import torch
19
17
  import torch.distributed as dist
@@ -107,7 +105,14 @@ class ApiRegistry:
107
105
  if not is_gpu:
108
106
  self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
109
107
 
110
- def initialize_hook(self, hook):
108
+ def initialize_hook(self, hook, online_run_ut=False):
109
+ """
110
+ initialize_hook
111
+ Args:
112
+ hook (_type_): initialize_hook
113
+ online_run_ut (bool): default False, whether online run_ut or not.
114
+ If online_run_ut is True, the hook will not wrap the aten ops.
115
+ """
111
116
  self.store_ori_attr(torch.Tensor, get_tensor_ops(), self.tensor_ori_attr)
112
117
  wrap_tensor.wrap_tensor_ops_and_bind(hook)
113
118
  for attr_name in dir(wrap_tensor.HOOKTensor):
@@ -137,7 +142,7 @@ class ApiRegistry:
137
142
  self.npu_distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP,
138
143
  attr_name)
139
144
 
140
- if torch_version_above_2:
145
+ if torch_version_above_2 and not online_run_ut:
141
146
  self.store_ori_attr(torch.ops.aten, get_aten_ops(), self.aten_ori_attr)
142
147
  wrap_aten.wrap_aten_ops_and_bind(hook)
143
148
  for attr_name in dir(wrap_aten.HOOKAtenOP):
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,7 +12,6 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
17
15
 
18
16
  import functools
19
17
  import threading
@@ -58,12 +56,12 @@ class HOOKModule(nn.Module):
58
56
  self.register_forward_hook(forward_hook)
59
57
  self.register_backward_hook(backward_hook)
60
58
 
61
- def __call__(self, *input, **kwargs):
59
+ def __call__(self, *args, **kwargs):
62
60
  changed = False
63
61
  if not self.stop_hook:
64
62
  HOOKModule.inner_stop_hook[self.current_thread] = True
65
63
  changed = True
66
- result = self._call_func(*input, **kwargs)
64
+ result = self._call_func(*args, **kwargs)
67
65
  if changed:
68
66
  HOOKModule.inner_stop_hook[self.current_thread] = False
69
67
  return result
@@ -72,28 +70,28 @@ class HOOKModule(nn.Module):
72
70
  def reset_module_stats(cls):
73
71
  cls.module_count = {}
74
72
 
75
- def _call_func(self, *input, **kwargs):
73
+ def _call_func(self, *args, **kwargs):
76
74
  full_backward_hooks, non_full_backward_hooks = [], []
77
75
  if len(self._backward_hooks) > 0:
78
76
  full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
79
77
  for hook in self._forward_pre_hooks.values():
80
- result_input, result_kwargs = hook(self, input, kwargs)
81
- if result_input is not None:
82
- if not isinstance(result_input, tuple):
83
- result_input = (result_input,)
84
- input = result_input
78
+ result_args, result_kwargs = hook(self, args, kwargs)
79
+ if result_args is not None:
80
+ if not isinstance(result_args, tuple):
81
+ result_args = (result_args,)
82
+ args = result_args
85
83
  if result_kwargs is not None:
86
84
  kwargs = result_kwargs
87
85
  bw_hook = None
88
86
  if len(full_backward_hooks) > 0:
89
87
  bw_hook = full_hooks.BackwardHook(self, full_backward_hooks)
90
- input = bw_hook.setup_input_hook(input)
88
+ args = bw_hook.setup_input_hook(args)
91
89
  if torch._C._get_tracing_state():
92
- result = self._slow_forward(*input, **kwargs)
90
+ result = self._slow_forward(*args, **kwargs)
93
91
  else:
94
- result = self.forward(*input, **kwargs)
92
+ result = self.forward(*args, **kwargs)
95
93
  for hook in self._forward_hooks.values():
96
- hook_result = hook(self, input, kwargs, result)
94
+ hook_result = hook(self, args, kwargs, result)
97
95
  if hook_result is not None:
98
96
  result = hook_result
99
97
  if bw_hook:
@@ -116,5 +114,5 @@ class HOOKModule(nn.Module):
116
114
  wrapper = functools.partial(hook, self)
117
115
  functools.update_wrapper(wrapper, hook)
118
116
  grad_fn.register_hook(wrapper)
119
- self._maybe_warn_non_full_backward_hook(input, result, grad_fn)
117
+ self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
120
118
  return result
@@ -1130,6 +1130,7 @@ torch_npu:
1130
1130
  - npu_prompt_flash_attention
1131
1131
  - npu_lstm
1132
1132
  - npu_apply_adam
1133
+ - npu_apply_adam_w
1133
1134
 
1134
1135
  aten:
1135
1136
  - signbit
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,7 +12,6 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
17
15
 
18
16
  import os
19
17
  from msprobe.core.common.file_utils import load_yaml
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,7 +12,6 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
17
15
 
18
16
  import os
19
17
  import torch
@@ -24,6 +22,7 @@ from msprobe.core.common.const import Const
24
22
  from msprobe.core.common.file_utils import load_yaml
25
23
  from msprobe.pytorch.function_factory import npu_custom_grad_functions
26
24
 
25
+
27
26
  cur_path = os.path.dirname(os.path.realpath(__file__))
28
27
  yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
29
28
  ops = load_yaml(yaml_path)
@@ -50,6 +49,8 @@ class AtenOPTemplate(HOOKModule):
50
49
  def __init__(self, op, hook, need_hook=True):
51
50
  if isinstance(op, torch._ops.OpOverloadPacket):
52
51
  op_name_ = op._qualified_op_name.split("::")[-1]
52
+ elif isinstance(op, str):
53
+ op_name_ = str(op)
53
54
  else:
54
55
  op_name_ = op.name().split("::")[-1]
55
56
  overload_name = op._overloadname
@@ -76,13 +77,13 @@ class AtenOPTemplate(HOOKModule):
76
77
 
77
78
 
78
79
  class AtenOPPacketTemplate():
79
- def __init__(self, opPacket, hook):
80
- self.opPacket = opPacket
80
+ def __init__(self, op_packet, hook):
81
+ self.op_packet = op_packet
81
82
  self.hook = hook
82
83
 
83
84
  def __getattr__(self, key):
84
85
  try:
85
- attr = getattr(self.opPacket, key)
86
+ attr = getattr(self.op_packet, key)
86
87
  except AttributeError as e:
87
88
  raise AttributeError(f"AtenOPPacketTemplate or OpOverloadPacket does not have attribute '{key}'.") from e
88
89
  if isinstance(attr, torch._ops.OpOverload):
@@ -92,10 +93,10 @@ class AtenOPPacketTemplate():
92
93
 
93
94
  @torch_device_guard
94
95
  def __call__(self, *args, **kwargs):
95
- return AtenOPTemplate(self.opPacket, self.hook)(*args, **kwargs)
96
+ return AtenOPTemplate(self.op_packet, self.hook)(*args, **kwargs)
96
97
 
97
98
  def overloads(self):
98
- return self.opPacket.overloads()
99
+ return self.op_packet.overloads()
99
100
 
100
101
 
101
102
  def wrap_aten_op(op, hook):