mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. msprobe/pytorch/doc/{atat → msprobe}精度工具数据dump标准性能基线报告.md +0 -0
@@ -0,0 +1,91 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ import os
17
+ import mindspore as ms
18
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
19
+ from msprobe.core.common.utils import Const, load_yaml
20
+
21
+
22
# The op whitelist ships next to this module. realpath() resolves symlinks,
# so the lookup is independent of the current working directory.
cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
24
+
25
+
26
def load_ops_functions():
    """Snapshot the functional namespaces whose members can be hooked.

    Returns:
        A 3-tuple of dicts, one each for ``mindspore.ops``, ``mindspore.mint``
        and ``mindspore.mint.nn.functional``. Each dict maps every name
        reported by ``dir()`` to the corresponding attribute object.
    """
    def _namespace_members(namespace):
        return {name: getattr(namespace, name) for name in dir(namespace)}

    return (
        _namespace_members(ms.ops),
        _namespace_members(ms.mint),
        _namespace_members(ms.mint.nn.functional),
    )
31
+
32
+
33
def get_functional_ops():
    """Return the hookable op names for each functional namespace.

    The YAML whitelist at ``yaml_path`` lists candidate names under the keys
    ``ops``, ``mint.ops`` and ``mint.nn.functional``. Each list is intersected
    with the names actually present in the installed MindSpore, so whitelisted
    ops that this MindSpore version lacks are skipped.

    Returns:
        ``(functional_ops, mint_ops, mint_func_ops)``, three sets of names.
    """
    ops_func, mint_ops_func, mint_func_ops_func = load_ops_functions()
    config = load_yaml(yaml_path) or {}
    # A missing or empty section means "hook nothing" for that namespace.
    # Without the fallback it would be a TypeError from set(None).
    wrap_functional = config.get("ops") or []
    wrap_mint = config.get("mint.ops") or []
    wrap_mint_functional = config.get("mint.nn.functional") or []
    return (
        set(wrap_functional) & set(ops_func),
        set(wrap_mint) & set(mint_ops_func),
        set(wrap_mint_functional) & set(mint_func_ops_func),
    )
44
+
45
+
46
class HOOKFunctionalOP:
    """Holder class onto which hooked ``mindspore.ops`` wrappers are bound.

    ``wrap_functional_ops_and_bind`` sets one attribute per hooked op. The
    class itself defines no behaviour.
    """
48
+
49
+
50
class HOOKMintOP:
    """Holder class onto which hooked ``mindspore.mint`` wrappers are bound.

    ``wrap_functional_ops_and_bind`` sets one attribute per hooked op. The
    class itself defines no behaviour.
    """
52
+
53
+
54
class HOOKMintNNFunctionalOP:
    """Holder class for hooked ``mindspore.mint.nn.functional`` wrappers.

    ``wrap_functional_ops_and_bind`` sets one attribute per hooked op. The
    class itself defines no behaviour.
    """
56
+
57
+
58
class FunctionalOPTemplate(HOOKCell):
    """HOOKCell that forwards a call to a single functional op so it can be dumped.

    Args:
        op_name: Key of the op in ``op_dict``.
        op_dict: Mapping from op name to the callable to invoke.
        prefix: Dump-name prefix, e.g. ``"Functional."``.
        hook: Hook object handed to ``HOOKCell``.
    """

    def __init__(self, op_name, op_dict, prefix, hook):
        self.op_name = op_name
        self.op_func = op_dict[op_name]
        # Only the last dotted component of the op name goes into the dump name.
        short_name = op_name.split(Const.SEP)[-1]
        self.prefix_op_name_ = f"{prefix}{short_name}{Const.SEP}"
        # These attributes are assigned before HOOKCell.__init__, matching the
        # original ordering (NOTE(review): the base initialiser may read them).
        super().__init__(hook)

    def construct(self, *args, **kwargs):
        # "dropout*" ops are not executed. The input (first positional arg,
        # else the "input" kwarg) is returned unchanged.
        if self.op_name.startswith('dropout'):
            return args[0] if args else kwargs.get('input')
        return self.op_func(*args, **kwargs)
69
+
70
+
71
def wrap_functional_op(op_name, op_dict, prefix, hook):
    """Return a plain function that runs ``op_name`` through a FunctionalOPTemplate.

    A fresh FunctionalOPTemplate cell is constructed on every invocation.
    """
    def op_template(*args, **kwargs):
        hooked_cell = FunctionalOPTemplate(op_name, op_dict, prefix, hook)
        return hooked_cell(*args, **kwargs)

    return op_template
75
+
76
+
77
def wrap_functional_ops_and_bind(ops, op_dict, prefix, hook, hook_class):
    """Bind a hooked wrapper for every callable op in ``ops`` onto ``hook_class``.

    Each wrapper is stored under ``Const.ATTR_NAME_PREFIX + op_name``. Names
    whose ``op_dict`` entry is not callable are skipped.
    """
    callable_ops = (name for name in ops if callable(op_dict[name]))
    for name in callable_ops:
        wrapper = wrap_functional_op(name, op_dict, prefix, hook)
        setattr(hook_class, Const.ATTR_NAME_PREFIX + name, wrapper)
81
+
82
+
83
def setup_hooks(hook):
    """Bind hooked wrappers for all whitelisted functional ops.

    Wrapped ``mindspore.ops``, ``mindspore.mint`` and
    ``mindspore.mint.nn.functional`` ops are attached to HOOKFunctionalOP,
    HOOKMintOP and HOOKMintNNFunctionalOP, with dump prefixes "Functional.",
    "Mint." and "MintFunctional." respectively.
    """
    functional_ops, mint_ops, mint_func_ops = get_functional_ops()
    # Reuse the shared namespace snapshot instead of rebuilding the same three
    # name->object dicts inline.
    ops_func, mint_ops_func, mint_func_ops_func = load_ops_functions()
    wrap_functional_ops_and_bind(
        functional_ops, ops_func, "Functional.", hook, HOOKFunctionalOP)
    wrap_functional_ops_and_bind(
        mint_ops, mint_ops_func, "Mint.", hook, HOOKMintOP)
    wrap_functional_ops_and_bind(
        mint_func_ops, mint_func_ops_func, "MintFunctional.", hook, HOOKMintNNFunctionalOP)
91
+
@@ -0,0 +1,63 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ import os
17
+ import mindspore as ms
18
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
19
+ from msprobe.core.common.utils import Const, load_yaml
20
+
21
+
22
# The Tensor op whitelist ships next to this module. realpath() resolves
# symlinks, so the lookup is independent of the current working directory.
cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
24
+
25
+
26
# Snapshot of every ms.Tensor attribute (name -> object), taken at import time.
# TensorOPTemplate.construct dispatches through this mapping, so it keeps using
# the objects captured here even if ms.Tensor attributes are replaced later.
# Built with a comprehension so no loop variable leaks into the module namespace.
TensorFunc = {name: getattr(ms.Tensor, name) for name in dir(ms.Tensor)}
29
+
30
+
31
def get_tensor_ops():
    """Return the whitelisted ``ms.Tensor`` method names present in this MindSpore.

    Reads the ``tensor`` section of the whitelist at ``yaml_path`` and
    intersects it with ``dir(ms.Tensor)``. A missing or empty section yields
    an empty set instead of a TypeError from ``set(None)``.
    """
    yaml_data = load_yaml(yaml_path) or {}
    wrap_tensor_ops = yaml_data.get('tensor') or []
    return set(wrap_tensor_ops) & set(dir(ms.Tensor))
36
+
37
+
38
class HOOKTensor:
    """Holder class onto which hooked ``ms.Tensor`` method wrappers are bound.

    ``wrap_tensor_ops_and_bind`` sets one attribute per hooked method. The
    class itself defines no behaviour.
    """
40
+
41
+
42
class TensorOPTemplate(HOOKCell):
    """HOOKCell that forwards a call to one ``ms.Tensor`` method so it can be dumped.

    The dump name prefix is ``"Tensor.<op_name>" + Const.SEP``. The call is
    dispatched through the module-level ``TensorFunc`` snapshot.
    """

    def __init__(self, op_name, hook):
        self.op_name_ = op_name
        self.prefix_op_name_ = f"Tensor.{op_name}{Const.SEP}"
        # Attributes are assigned before HOOKCell.__init__, as in the original.
        super().__init__(hook)

    def construct(self, *args, **kwargs):
        return TensorFunc[str(self.op_name_)](*args, **kwargs)
51
+
52
+
53
def wrap_tensor_op(op_name, hook):
    """Return a plain function that runs Tensor method ``op_name`` through a TensorOPTemplate.

    A fresh TensorOPTemplate cell is constructed on every invocation.
    """
    def tensor_op_template(*args, **kwargs):
        hooked_cell = TensorOPTemplate(op_name, hook)
        return hooked_cell(*args, **kwargs)

    return tensor_op_template
57
+
58
+
59
def wrap_tensor_ops_and_bind(hook):
    """Attach a hooked wrapper to HOOKTensor for each callable whitelisted Tensor op.

    Wrappers are stored under ``Const.ATTR_NAME_PREFIX + op_name``. Names whose
    ``TensorFunc`` entry is not callable are skipped.
    """
    for name in get_tensor_ops():
        if not callable(TensorFunc[name]):
            continue
        attr_name = Const.ATTR_NAME_PREFIX + str(name)
        setattr(HOOKTensor, attr_name, wrap_tensor_op(name, hook))
@@ -0,0 +1,56 @@
1
+ import os
2
+ from mindspore.common.api import _MindsporeFunctionExecutor
3
+ from mindspore._c_expression import PyNativeExecutor_
4
+ from msprobe.mindspore.dump.hook_cell.api_registry import api_register
5
+ from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
6
+
7
+
8
def dump_jit(name, in_feat, out_feat, is_forward):
    """Record the inputs and outputs of one jit call with the registered data collector.

    Args:
        name: Object describing the jit function. The text of ``str(name)``
            before its first ``"<"`` becomes the dump name. The fallback
            ``"JitFunction"`` is used when there is no such non-empty prefix.
        in_feat: Inputs, recorded as ``args``.
        out_feat: Output of the call.
        is_forward: Selects the ``.forward`` or ``.backward`` suffix.

    Nothing is recorded when no collector has been registered through
    ``JitDump.set_data_collector`` (or when it was set to a falsy value).
    """
    data_collector = getattr(JitDump, "data_collector", None)
    # Check the collector first. Previously visit_and_clear_overflow_status was
    # called before this guard, so a missing collector raised AttributeError.
    if not data_collector:
        return
    pid = os.getpid()
    ori_args = str(name)
    index = ori_args.find("<")
    # find() returns -1 when "<" is absent; 0 means an empty prefix. Both fall back.
    result = ori_args[:index] if index > 0 else "JitFunction"
    suffix = ".forward" if is_forward else ".backward"
    name_template = "Jit." + result + suffix
    data_collector.visit_and_clear_overflow_status(name_template)
    module_input_output = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat)
    data_collector.forward_data_collect(name_template, {}, pid, module_input_output)
24
+
25
+
26
class JitDump(_MindsporeFunctionExecutor):
    """MindSpore function executor that dumps the inputs and outputs of jit calls.

    While the compiled function runs, the registered APIs are switched to their
    original (unhooked) implementations via ``api_register.api_set_ori_func``.
    They are switched back to the hooked versions afterwards (presumably so that
    per-API hooks do not fire inside the jit call).
    """
    dump_config = None
    # Registered via set_data_collector(). Defaults to None so dump_jit can
    # detect "no collector" instead of hitting a missing attribute.
    data_collector = None
    # Set to True after the first forward jit call; gates dumping in grad().
    jit_enable = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._executor = PyNativeExecutor_.get_instance()

    def __call__(self, *args, **kwargs):
        api_register.api_set_ori_func()
        try:
            out = super().__call__(*args, **kwargs)
            dump_jit(args[0], args[1], out, True)
            JitDump.jit_enable = True
        finally:
            # Always restore the hooked APIs. Otherwise an exception from the
            # jit call or the dump would leave hooks disabled for later calls.
            api_register.api_set_hook_func()
        return out

    @classmethod
    def set_config(cls, value):
        cls.dump_config = value

    @classmethod
    def set_data_collector(cls, value):
        cls.data_collector = value

    def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
        """Run the PyNative grad and, once a forward jit call has happened, dump it as backward."""
        if JitDump.jit_enable:
            api_register.api_set_ori_func()
        output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
        if JitDump.jit_enable:
            dump_jit(obj, args, output, False)
            api_register.api_set_hook_func()
        return output
@@ -0,0 +1,65 @@
1
+ import os
2
+ import json
3
+
4
+ from msprobe.core.common.utils import make_dump_path_if_not_exists
5
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
6
+ from msprobe.core.common.log import logger
7
+ from msprobe.core.common.file_check import FileOpen
8
+ from msprobe.core.common.const import Const
9
+
10
+
11
class KernelKbykDump:
    """Build a MindSpore kernel-by-kernel dump config and activate it.

    The constructor translates a DebuggerConfig into a JSON document with
    ``common_dump_settings`` and ``e2e_dump_settings`` sections. ``handle``
    writes it to ``<dump_path>/kernel_kbyk_dump.json`` and points the
    ``MINDSPORE_DUMP_CONFIG`` environment variable at that file.
    """
    COMMON_SETTINGS = "common_dump_settings"
    E2E_SETTINGS = "e2e_dump_settings"

    def __init__(self, config: DebuggerConfig):
        # Defaults: dump_mode 0 with an empty kernel list, all iterations,
        # devices 0-7, statistics only, and input_output 0.
        # Key order is kept so the emitted JSON is unchanged.
        common_set = {
            "dump_mode": 0,
            "path": "",
            "net_name": "Net",
            "iteration": "all",
            "saved_data": "statistic",
            "input_output": 0,
            "kernels": [],
            "support_device": [0, 1, 2, 3, 4, 5, 6, 7],
        }
        e2e_set = {
            "enable": True,
            "trans_flag": True,
        }

        if config.list:
            # dump_mode 1 is paired with an explicit kernel list.
            common_set["dump_mode"] = 1
            common_set["kernels"] = config.list
        common_set["path"] = config.dump_path
        if config.step:
            # Steps are encoded as a "|"-separated string, e.g. "0|2|5".
            common_set["iteration"] = "|".join(str(s) for s in config.step)
        if config.rank:
            common_set["support_device"] = config.rank
        if config.task == Const.TENSOR:
            common_set["saved_data"] = Const.TENSOR
        if len(config.data_mode) == 1:
            # input_output: 1 when only inputs are requested, 2 when only outputs.
            if config.data_mode[0] == Const.INPUT:
                common_set["input_output"] = 1
            elif config.data_mode[0] == Const.OUTPUT:
                common_set["input_output"] = 2

        self.dump_json = {
            KernelKbykDump.COMMON_SETTINGS: common_set,
            KernelKbykDump.E2E_SETTINGS: e2e_set,
        }

    def handle(self):
        """Write the dump config into the dump directory and activate it via env vars."""
        dump_dir = self.dump_json[KernelKbykDump.COMMON_SETTINGS]["path"]
        make_dump_path_if_not_exists(dump_dir)
        json_path = os.path.join(dump_dir, "kernel_kbyk_dump.json")
        with FileOpen(json_path, 'w') as f:
            json.dump(self.dump_json, f)
        logger.info(json_path + " has been created.")

        os.environ["MINDSPORE_DUMP_CONFIG"] = json_path
        # Remove any ACL dump config path if present
        # (NOTE(review): presumably to avoid two dump configs being active).
        os.environ.pop("MS_ACL_DUMP_CFG_PATH", None)
File without changes
@@ -0,0 +1,116 @@
1
+ import os
2
+ import inspect
3
+ import importlib
4
+
5
+ import mindspore as ms
6
+ from mindspore.communication import comm_func
7
+
8
+ from msprobe.core.common.utils import load_yaml
9
+ from msprobe.core.common.const import Const
10
+ from msprobe.mindspore.common.const import FreeBenchmarkConst
11
+ from msprobe.mindspore.free_benchmark.common.config import Config
12
+ from msprobe.core.common.file_check import check_path_length
13
+ from msprobe.mindspore.common.log import logger
14
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
15
+ from msprobe.mindspore.free_benchmark.decorator.decorator_factory import decorate_forward_function
16
+
17
+
18
class ApiPyNativeSelFCheck:
    """Free-benchmark self-check of APIs in PyNative mode.

    The constructor copies the debugger settings into the global ``Config`` and
    resolves the set of APIs to check. ``handle`` replaces each of those APIs
    with a decorated wrapper via ``hijack``.
    """

    def __init__(self, config: DebuggerConfig):
        Config.is_enable = True
        Config.handler_type = config.handler_type
        Config.pert_type = config.pert_type
        Config.stage = config.stage
        Config.dump_level = config.dump_level
        Config.steps = config.step
        Config.ranks = config.rank
        Config.dump_path = os.path.join(config.dump_path, "free_benchmark.csv")
        check_path_length(Config.dump_path)

        supported = get_supported_ops()
        requested = config.list
        # An empty/unset list means "every supported API". Otherwise only the
        # requested names that are also supported are kept.
        self.api_list = set(requested) & supported if requested else supported

    def handle(self):
        for name in self.api_list:
            hijack(name)
40
+
41
+
42
def get_supported_ops():
    """Return the fully qualified API names that the free benchmark can wrap.

    Candidates are read from ``data/support_wrap_ops.yaml``: for every
    ``(section, prefix)`` pair in ``FreeBenchmarkConst.API_PREFIX_DICT``, the
    names listed in that section are prefixed. A candidate is kept only if the
    same prefixed name exists in ``mindspore.ops``, ``mindspore.Tensor``,
    ``mindspore.mint``, ``mindspore.mint.nn.functional`` or
    ``mindspore.communication.comm_func``.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    yaml_data = load_yaml(os.path.join(here, "data", "support_wrap_ops.yaml"))

    candidates = set()
    for section, prefix in FreeBenchmarkConst.API_PREFIX_DICT.items():
        candidates.update(prefix + op for op in (yaml_data.get(section) or []))

    namespaces = (
        (FreeBenchmarkConst.OPS_PREFIX, ms.ops),
        (FreeBenchmarkConst.Tensor_PREFIX, ms.Tensor),
        (FreeBenchmarkConst.MINT_PREFIX, ms.mint),
        (FreeBenchmarkConst.MINT_NN_FUNC_PREFIX, ms.mint.nn.functional),
        (FreeBenchmarkConst.COMM_PREFIX, comm_func),
    )
    available = {prefix + name for prefix, namespace in namespaces for name in dir(namespace)}

    return candidates & available
76
+
77
+
78
def get_decorate_func():
    """Return the decorator used to wrap each hijacked API."""
    decorator = decorate_forward_function
    return decorator
80
+
81
+
82
def is_func_support_decorate(orig_func):
    """Return True if ``orig_func`` may be wrapped, i.e. it is callable but not a class."""
    if inspect.isclass(orig_func):
        return False
    return callable(orig_func)
84
+
85
+
86
def get_wrapper_obj(orig_func, api_name):
    """Return ``orig_func`` decorated for ``api_name`` if decoratable, otherwise unchanged."""
    if not is_func_support_decorate(orig_func):
        return orig_func
    decorate = get_decorate_func()
    return decorate(orig_func, api_name)
92
+
93
+
94
def get_module(api_name):
    """Resolve a dotted API name to ``(owner_object, original_attribute)``.

    The first component is imported as a module. Intermediate components are
    looked up as attributes, and the corresponding dotted submodule is imported
    first when the attribute is not present yet.
    """
    parts = api_name.split(Const.SEP)
    func_name = parts[-1]
    owner = importlib.import_module(parts[0])
    for depth, attr in enumerate(parts[1:-1], start=2):
        if not hasattr(owner, attr):
            # Importing a dotted submodule binds it as an attribute of its
            # parent, so the getattr below can then find it.
            importlib.import_module(Const.SEP.join(parts[:depth]))
        owner = getattr(owner, attr)
    orig_func = getattr(owner, func_name)

    return owner, orig_func
105
+
106
+
107
def hijack(api_name):
    """Replace ``api_name`` on its owning object with its decorated wrapper.

    Blank names are ignored. Any failure is logged and swallowed, so callers
    looping over many APIs carry on with the rest.
    """
    if not api_name.strip():
        return
    try:
        owner, origin_func = get_module(api_name)
        wrapped_obj = get_wrapper_obj(origin_func, api_name)
        setattr(owner, api_name.split(Const.SEP)[-1], wrapped_obj)
    except Exception as e:
        logger.error(f"Failed decorator {api_name}: {e}")
File without changes
@@ -0,0 +1,12 @@
1
+ from msprobe.mindspore.common.const import FreeBenchmarkConst
2
+
3
+
4
class Config:
    """Process-wide free-benchmark settings, held as class attributes.

    The values below are defaults. ApiPyNativeSelFCheck overwrites them from
    the user's DebuggerConfig.
    """
    is_enable: bool = False
    handler_type = FreeBenchmarkConst.DEFAULT_HANDLER_TYPE
    pert_type = FreeBenchmarkConst.DEFAULT_PERT_TYPE
    stage = FreeBenchmarkConst.DEFAULT_STAGE
    dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
    # Class-level lists shared by every reader of Config; the visible code
    # replaces them wholesale rather than mutating them.
    steps: list = []
    ranks: list = []
    # Path of the results CSV file.
    dump_path: str = ""
@@ -0,0 +1,17 @@
1
+ from typing import Optional, Any, Tuple, Dict, Callable
2
+
3
+
4
class HandlerParams:
    """Parameter bundle passed between the free-benchmark handlers.

    Every field has a class-level default. Instances override fields by plain
    attribute assignment.
    """
    # Call arguments of the API under test.
    args: Optional[Tuple] = None
    kwargs: Optional[Dict] = None
    index: Optional[int] = None
    # Results of the unperturbed and perturbed runs.
    original_result: Optional[Any] = None
    fuzzed_result: Optional[Any] = None
    is_consistent: Optional[bool] = True
    save_flag: Optional[bool] = True
    fuzzed_value: Optional[Any] = None
    original_func: Optional[Callable] = None
@@ -0,0 +1,71 @@
1
+ from typing import Any
2
+ from typing import Optional
3
+ from dataclasses import dataclass
4
+
5
+ import mindspore as ms
6
+ from mindspore import Tensor
7
+
8
+ from msprobe.mindspore.runtime import Runtime
9
+ from msprobe.mindspore.common.const import FreeBenchmarkConst
10
+ from .config import Config
11
+ from .handler_params import HandlerParams
12
+
13
+
14
class Tools:
    """Static helpers used when checking free-benchmark results."""

    @staticmethod
    def get_first_tensor_dtype(tensor_seq: Any):
        """Return the dtype of ``tensor_seq`` itself, or of its first Tensor element.

        Raises:
            Exception: if neither ``tensor_seq`` nor any element of a
                list/tuple ``tensor_seq`` is a Tensor.
        """
        if isinstance(tensor_seq, Tensor):
            return tensor_seq.dtype
        if isinstance(tensor_seq, (list, tuple)):
            first = next((item for item in tensor_seq if isinstance(item, Tensor)), None)
            if first is not None:
                return first.dtype
        raise Exception("The sequence does not contain tensors.")

    @staticmethod
    def get_default_error_threshold(dtype):
        """Return the error threshold to use for ``dtype``.

        The NO_CHANGE perturbation has its own threshold. Otherwise the per-dtype
        table is consulted, and unlisted dtypes fall back to the float32 entry.
        """
        if Config.pert_type == FreeBenchmarkConst.NO_CHANGE:
            return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD
        thresholds = FreeBenchmarkConst.ERROR_THRESHOLD
        return thresholds.get(dtype, thresholds.get(ms.float32))
31
+
32
+
33
@dataclass
class UnequalRow:
    """One result row describing an API output flagged as inconsistent.

    Field order is significant. It fixes the dataclass constructor's positional
    order (and the order of any ``asdict``-based export).
    """
    rank: Optional[int] = None
    pert_type: Optional[str] = None
    stage: Optional[str] = None
    step: Optional[int] = None
    api_name: Optional[str] = None
    # make_unequal_row stores ratio - 1 here.
    max_rel: Optional[float] = None
    # make_unequal_row stores the tensor's dtype/shape objects here
    # (NOTE(review): annotated as str but not converted).
    dtype: Optional[str] = None
    shape: Optional[str] = None
    output_index: Optional[int] = None
44
+
45
+
46
def make_unequal_row(
    api_name: str,
    params: HandlerParams,
    ratio: Optional[float] = None,
    index: Optional[int] = None,
):
    """Build an UnequalRow for an inconsistent output of ``api_name``.

    Args:
        api_name: Name of the checked API.
        params: Handler parameters; ``params.original_result`` is inspected.
        ratio: Comparison ratio. When it is a float, ``max_rel`` is set to
            ``ratio - 1``.
        index: Position of the offending output when the result is indexable.
            ``None`` means the whole result is inspected.

    Returns:
        An UnequalRow with the run context (perturbation type, stage, step,
        rank) and, when the selected output is a Tensor, its dtype and shape.
    """
    row = UnequalRow(
        api_name=api_name,
        pert_type=Config.pert_type,
        output_index=index,
        stage=Config.stage,
        step=Runtime.step_count
    )
    if isinstance(ratio, float):
        row.max_rel = ratio - 1
    original_tensor = params.original_result
    # Compare against None rather than truthiness. Index 0 (the first element
    # of a sequence result) must also select its element; `if index:` skipped
    # it and left dtype/shape empty for that row.
    if index is not None:
        original_tensor = original_tensor[index]
    if isinstance(original_tensor, Tensor):
        row.dtype = original_tensor.dtype
        row.shape = original_tensor.shape
    # Runtime.rank_id == -1 is recorded as "no rank".
    row.rank = Runtime.rank_id if Runtime.rank_id != -1 else None
    return row