mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -1,146 +0,0 @@
1
- import re
2
-
3
- from msprobe.core.common.const import Const
4
- from msprobe.core.common.log import logger
5
- from msprobe.core.common.utils import CompareException
6
-
7
-
8
class Trie:
    """Prefix tree over ``Const.SEP``-separated scope names.

    Every node stands for one name component. A node at which an inserted
    name ends has ``has_data`` set. It records the call-count component of
    every name that ended there.
    """

    def __init__(self, type_name=None, has_data=False):
        # Type name of the node. Pure prefix nodes keep their component name.
        self.type_name = type_name
        # Call-count components (last part of each name inserted here).
        self.call_count_list = []
        # Component name -> child Trie.
        self.children = {}
        # True once at least one inserted name ends at this node.
        self.has_data = has_data
        # The ``word_type`` given to ``insert`` (e.g. "Cell"); None for pure prefixes.
        self.node_type = None

    def __repr__(self):
        return (f"Node(type_name={self.type_name}, "
                f"has_data={self.has_data}, call number={len(self.call_count_list)})")

    def insert(self, word, word_type="func"):
        """Insert a scope name laid out as ``<path...>.<type_name>.<call_count>``.

        Example::

            Cell.network_with_loss.language_model.encoder.layers.1.attention.out_proj.RowParallelLinear.1

        * path components ``Cell ... attention.out_proj`` (all but the last two)
          become the node path; the final node here is ``out_proj``;
        * ``RowParallelLinear`` becomes that node's ``type_name``;
        * ``1`` is appended to its ``call_count_list``.

        Args:
            word: the full scope name.
            word_type: stored on the final node as ``node_type``.

        Raises:
            CompareException: if ``word`` has fewer than two components.
        """
        parts = word.split(Const.SEP)
        if len(parts) < 2:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)

        type_name = parts[-2]
        call_count = parts[-1]

        node = self
        for name in parts[:-2]:
            if name not in node.children:
                node.children[name] = Trie()
            node = node.children[name]
            # New prefix nodes are named after their component until a type is set.
            if node.type_name is None:
                node.type_name = name

        node.type_name = type_name
        node.has_data = True
        node.call_count_list.append(call_count)
        node.node_type = word_type
49
-
50
-
51
class DFSConverter:
    """Walk a ``Trie`` depth-first and collect ``origin name -> mapped name`` pairs.

    ``mapping`` maps a node's ``type_name`` to a dict that renames that node's
    children (``child component -> replacement component``). Results accumulate
    in ``self.result`` across calls.
    """

    def __init__(self, mapping, max_depth=100):
        self.mapping = mapping
        self.max_depth = max_depth
        self.result = {}

    @staticmethod
    def _extend(prefix, component):
        """Append ``component`` to a dotted ``prefix`` (no leading dot when empty)."""
        return f"{prefix}.{component}" if prefix else component

    def traverse_and_collect(self, node, path="", mapping_path="", depth=0):
        """Recursively record names under ``node`` and return ``self.result``.

        Raises:
            CompareException: when the recursion exceeds ``max_depth``.
        """
        if depth > self.max_depth:
            logger.error("The converted data depth is too large, please check the data")
            raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
        if node is None:
            return self.result

        type_name = node.type_name
        if node.has_data:
            # "Cell" nodes are keyed without their type name; others include it.
            is_cell = node.node_type == "Cell"
            for count in node.call_count_list:
                if is_cell:
                    self.result[f"{path}.{count}"] = f"{mapping_path}.{count}"
                else:
                    self.result[f"{path}.{type_name}.{count}"] = f"{mapping_path}.{type_name}.{count}"

        child_renames = self.mapping.get(type_name, {})
        for child_name, child in node.children.items():
            self.traverse_and_collect(
                child,
                self._extend(path, child_name),
                self._extend(mapping_path, child_renames.get(child_name, child_name)),
                depth + 1,
            )
        return self.result
81
-
82
-
83
def get_mapping_list(ms_tree, mapping):
    """Convert a MindSpore scope ``Trie`` into ``(ms_name, pt_name)`` pairs.

    The mapped names come from ``DFSConverter``. A leading ``Cell`` is then
    rewritten to ``Module`` so they read as PyTorch names.
    """
    ms_to_pt = DFSConverter(mapping).traverse_and_collect(ms_tree)
    return [(ms_name, re.sub(r"^Cell", "Module", pt_name)) for ms_name, pt_name in ms_to_pt.items()]
91
-
92
-
93
def get_prefix_mapping(scope_list):
    """Map each Cell/Module layer's class-less name to its full scope name.

    For a key ``a.b.Class.N`` whose ``origin_data`` starts with ``Cell`` or
    ``Module``, the result contains ``a.b.N -> a.b.Class.N``. Other entries
    are skipped.

    Raises:
        CompareException: if a qualifying key has fewer than two components.
    """
    prefix_to_full = {}
    for full_name, info in scope_list.items():
        if not info.get("origin_data").startswith(("Cell", "Module")):
            continue
        pieces = full_name.split(Const.SEP)
        if len(pieces) < 2:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
        # Drop the class-name component (second to last) but keep the call count.
        prefix_to_full[Const.SEP.join(pieces[:-2] + pieces[-1:])] = full_name
    return prefix_to_full
108
-
109
-
110
def get_layer_mapping(ms_scope_list, pt_scope_list, mapping):
    """Pair MindSpore scope names with PyTorch scope names.

    Steps:
      1. index MindSpore layer prefixes -> full names, e.g.
         ``Cell.network_with_loss.language_model.embedding.3`` ->
         ``Cell.network_with_loss.language_model.embedding.Embedding.3``;
      2. build a ``Trie`` of the MindSpore scope names and convert it
         through ``mapping``;
      3. index PyTorch layer prefixes -> full names the same way.

    Returns:
        list of ``(ms_name, pt_name)`` tuples. ``pt_name`` may be ``None`` (or
        the falsy lookup result) when no PyTorch counterpart was found. Names
        with neither a layer prefix nor a scope entry are skipped.
    """
    sep = Const.SEP

    def _strip_direction(full_name):
        # Drop the final component (the forward/backward marker).
        return sep.join(full_name.split(sep)[:-1])

    ms_prefix2fullname = get_prefix_mapping(ms_scope_list)

    ms_tree = Trie(type_name="Cell")
    for scope_name, record in ms_scope_list.items():
        ms_tree.insert(scope_name, record.get('origin_data').split(sep)[0])
    msname2ptname = get_mapping_list(ms_tree, mapping)

    pt_prefix2fullname = get_prefix_mapping(pt_scope_list)

    final_mapping = []
    for ms_name, pt_name in msname2ptname:
        if ms_name in ms_prefix2fullname:
            # Cell/Module layer: translate both prefixes back to full names.
            pair = (ms_prefix2fullname.get(ms_name), pt_prefix2fullname.get(pt_name, None))
        elif ms_name in ms_scope_list:
            # Function/API entry: use origin_data without its direction suffix.
            ms_full = _strip_direction(ms_scope_list.get(ms_name)['origin_data'])
            pt_record = pt_scope_list.get(pt_name, None)
            pt_full = _strip_direction(pt_record['origin_data']) if pt_record else pt_record
            pair = (ms_full, pt_full)
        else:
            continue
        final_mapping.append(pair)
    return final_mapping
@@ -1,107 +0,0 @@
1
- from msprobe.core.common.const import Const
2
- from msprobe.core.common.log import logger
3
-
4
def find_regard_scope(lines, start_sign, end_sign):
    """Return ``(start_pos, end_pos)`` bounding a region of ``lines``.

    ``start_pos`` is the index of the last line containing ``start_sign`` seen
    before the scan stops. ``end_pos`` is the index of the first line
    containing ``end_sign``, counting only lines that do not also contain
    ``start_sign``; the scan stops there. Either value is ``-1`` when not
    found.

    NOTE(review): if ``end_sign`` appears before any ``start_sign`` the result
    is ``(-1, end_pos)``; confirm callers slicing ``lines[start:end]`` expect that.
    """
    start_pos = -1
    for position, line in enumerate(lines):
        if start_sign in line:
            start_pos = position
            continue
        if end_sign in line:
            return start_pos, position
    return start_pos, -1
14
-
15
-
16
def find_stack_func_list(lines):
    """Extract function names from comma-separated stack lines.

    Lines are skipped when their file field contains an entry of
    ``Const.FILE_SKIP_LIST`` or their function field contains an entry of
    ``Const.FUNC_SKIP_LIST``. Each kept line contributes the
    ``Const.STACK_FUNC_ELE_INDEX``-th whitespace-separated token of its
    function field. The names come back in reverse order of ``lines``.
    """
    collected = []
    for line in lines:
        fields = line.split(',')
        file_field = fields[Const.STACK_FILE_INDEX]
        if any(skip in file_field for skip in Const.FILE_SKIP_LIST):
            continue
        func_field = fields[Const.STACK_FUNC_INDEX]
        if any(skip in func_field for skip in Const.FUNC_SKIP_LIST):
            continue
        collected.append(func_field.split()[Const.STACK_FUNC_ELE_INDEX])
    collected.reverse()
    return collected
35
-
36
-
37
def get_duplicated_name(components):
    """Repeat the name component so it also serves as the type component.

    The component at ``Const.CONSTRUCT_NAME_INDEX`` is duplicated in place,
    e.g. ``Functional.add.0.forward`` -> ``Functional.add.add.0.forward``
    (assuming that index points at ``add``, i.e. -3; confirm against Const).
    A warning is logged and ``components`` is returned unchanged when there
    are fewer than 3 parts or that component is numeric.
    """
    name_idx = Const.CONSTRUCT_NAME_INDEX
    if len(components) >= 3 and not components[name_idx].isdigit():
        return components[:name_idx + 1] + components[name_idx:]
    logger.warning("key in construct.json is shorter than 3 parts or not name valid.")
    return components
45
-
46
-
47
def _remove_direction_suffix(name):
    """Return ``name`` without one trailing ``.forward`` / ``.backward``.

    ``str.strip`` treats its argument as a *set of characters*, not a suffix.
    So ``name.strip(".forward")`` also removes any leading or trailing run of
    ``. f o r w a d``, e.g. the end of a component such as ``...pad``.
    Removing the exact suffix avoids that.
    """
    for suffix in (".forward", ".backward"):
        if name.endswith(suffix):
            return name[:-len(suffix)]
    return name


def modify_mapping_with_stack(stack, construct):
    """Build a scope mapping enriched with call-stack function names.

    Args:
        stack: dict mapping a construct key to its stack lines (a sliceable
            sequence of comma-separated strings, as consumed by
            ``find_regard_scope`` / ``find_stack_func_list``).
        construct: dict mapping a data key to its parent scope name, or a
            falsy value when there is no parent.

    Returns:
        dict mapping the adjusted key to
        ``{Const.ORIGIN_DATA: original key, Const.SCOPE: parent scope or None,
        Const.STACK: SEP-joined function names or None}``. Returns an empty
        dict when either input is empty.
    """
    if not stack or not construct:
        return {}

    # Whether the data comes from MindSpore (its keys mention "Cell").
    is_ms = any("Cell" in ii for ii in construct)
    # Adjusted mapping being built.
    final_pres = {}
    # Resolve each key against its parent scope.
    for key in construct:
        key_components = key.split(Const.SEP)
        code_list = stack.get(key, None)
        parent_node = construct.get(key, None)
        # Keys that do not start with Module/Cell are rewritten into that form.
        if not key.startswith(("Module", "Cell")):
            # Without a parent scope, the top-level domain defaults to Module/Cell.
            if not parent_node:
                # Rename the first component to the standard Module/Cell prefix.
                key_components[0] = "Cell" if is_ms else "Module"
                # Repeat the node name as its type, e.g. add.add.
                duplicated_components = get_duplicated_name(key_components)
                modified_key = Const.SEP.join(duplicated_components)

                modified_key = modified_key.replace(".forward", "").replace(".backward", "")
                final_pres[modified_key] = {Const.ORIGIN_DATA: key, Const.SCOPE: None, Const.STACK: None}
                continue
            parent = parent_node.split(Const.SEP)
            if len(parent) < 4:
                logger.info("Parent name in construct.json is not valid")
                continue
            parent_idx = Const.NAME_FIRST_POSSIBLE_INDEX if not \
                parent[Const.NAME_FIRST_POSSIBLE_INDEX].isdigit() else Const.NAME_SECOND_POSSIBLE_INDEX
            parent_name = parent[parent_idx]

            if code_list:
                # Parent: {name}.Class.count_number.X ward or {name}.Class.count_number.X ward.ele_number
                # A plural parent name (e.g. "layers") is matched in singular form.
                if parent_name.endswith('s'):
                    parent_name = parent_name[:-1]
                if len(key_components) < 3:
                    logger.info("The length of key in construct is less than 3, please check")
                    continue
                # Key: {name}.count_number.X ward
                func_name = key_components[-3]
                start_pos, end_pos = find_regard_scope(code_list, func_name, parent_name)

                # Stack lines in the located range.
                regard_scope = code_list[start_pos:end_pos]

                func_stack_list = find_stack_func_list(regard_scope)
            else:
                func_stack_list = []
            # Parent components up to its node name, then the stack function names,
            # then the original key with its name component repeated.
            final_res_key = Const.SEP.join(parent[:parent_idx + 1] + func_stack_list +
                key_components[1:Const.CONSTRUCT_NAME_INDEX + 1] + key_components[Const.CONSTRUCT_NAME_INDEX:])
            final_res_key = _remove_direction_suffix(final_res_key)
        else:
            # Module/Cell key: drop the direction component, keep the call count.
            final_res_key = Const.SEP.join(key_components[:-2] + [key_components[-1]])
            func_stack_list = []
        final_pres[final_res_key] = {Const.ORIGIN_DATA: key, Const.SCOPE: parent_node,
            Const.STACK: Const.SEP.join(func_stack_list) if func_stack_list else None}
    return final_pres
@@ -1,57 +0,0 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from msprobe.mindspore.common.const import Const, FreeBenchmarkConst
17
- from msprobe.mindspore.free_benchmark.common.config import Config
18
- from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
19
- from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
20
- from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
21
-
22
-
23
class ForwardSelfChecker:
    """Forward self-check: run a perturbed and an original call of one API."""

    def __init__(self, api_name: str):
        self.api_name = api_name

    def _is_communication_api(self):
        return self.api_name in Const.COMMUNICATION_API_LIST

    def handle(self, params: HandlerParams):
        """Run the decorator's checking logic for one call.

        The perturbation is applied first. Then the original function is
        called with ``params.args`` / ``params.kwargs``. If the perturbation
        returned exactly ``False``, the original result is returned
        unchanged. Otherwise the results go to the handler.
        """
        params.fuzzed_result = PerturbationFactory.create(self.api_name).handle(params)
        params.original_result = params.original_func(*params.args, **params.kwargs)
        if params.fuzzed_result is False:
            return params.original_result
        return self.deal_fuzzed_and_original_result(params)

    def get_compare_data(self, params: HandlerParams):
        """For communication APIs, replace the compared values on ``params``.

        Non-communication APIs are left untouched.
        """
        if not self._is_communication_api():
            return
        params.fuzzed_result = params.fuzzed_value
        params.original_result = (params.args
                                  if Config.pert_type == FreeBenchmarkConst.IMPROVE_PRECISION
                                  else params.args[params.index])

    def deal_fuzzed_and_original_result(self, params: HandlerParams):
        """Let the handler compare the results and pick the value to return.

        Communication APIs always return the original call's output. All
        other APIs return whatever the handler returns.
        """
        saved_original = params.original_result
        self.get_compare_data(params)
        handler_result = HandlerFactory.create(self.api_name).handle(params)
        return saved_original if self._is_communication_api() else handler_result
@@ -1,122 +0,0 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import os
17
- import sys
18
- import traceback
19
- from functools import wraps
20
- from typing import Dict, List, Tuple
21
-
22
- from mindspore import ops
23
-
24
- from msprobe.mindspore.common.log import logger
25
- from msprobe.mindspore.free_benchmark.common.config import Config
26
- from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
27
- from msprobe.mindspore.free_benchmark.decorator.dec_forward import ForwardSelfChecker
28
- from msprobe.mindspore.runtime import Runtime
29
-
30
-
31
def decorate(original_func, decorate_func, api_name=None):
    """General decorator: route calls to ``original_func`` through ``decorate_func``.

    Args:
        original_func: the API being wrapped.
        decorate_func: callable that receives the ``HandlerParams`` built by
            ``data_pre_deal`` and returns the value the wrapped call should produce.
        api_name: name used in log messages and passed to ``data_pre_deal``.

    Returns:
        ``fuzz_wrapper``, a ``functools.wraps``-preserving replacement for
        ``original_func``. When ``need_wrapper_func()`` is false, or when anything in
        the checking path raises, the wrapper falls back to calling
        ``original_func(*args, **kwargs)`` directly.
    """
    @wraps(original_func)
    def fuzz_wrapper(*args, **kwargs):

        def __exec_decorate_func():
            params = data_pre_deal(api_name, original_func, *args, **kwargs)
            result = decorate_func(params)
            return result

        try:
            if Runtime.rank_id == -1:
                rank_id = os.environ.get("RANK_ID", -1)
                # Environment variables are strings, but need_wrapper_func tests
                # ``Runtime.rank_id not in Config.ranks``. NOTE(review): Config.ranks
                # is assumed to hold ints (from the validated JSON config) - confirm.
                # A string such as "0" would never match, so checking would never run
                # whenever a rank filter is configured. Normalise numeric ids to int.
                if isinstance(rank_id, str):
                    try:
                        rank_id = int(rank_id)
                    except ValueError:
                        # Non-numeric value: keep it unchanged, as before.
                        pass
                Runtime.rank_id = rank_id
            if need_wrapper_func():
                logger.info(f"[{api_name}] is checking.")
                return __exec_decorate_func()
        except Exception as e:
            # Best-effort: a failure in the checking path is logged, and the call
            # falls back to the untouched original below. NOTE(review): if
            # decorate_func had already invoked original_func before failing, the
            # original runs a second time here.
            logger.error(f"[{api_name}] Error: {str(e)}")
            logger.error(f"[{api_name}] Error detail: {traceback.format_exc()}")

        return original_func(*args, **kwargs)

    return fuzz_wrapper
56
-
57
-
58
def decorate_forward_function(func, api_name=None):
    """Forward decorator: wrap ``func`` with a ``ForwardSelfChecker``-based check.

    Args:
        func: the API to wrap.
        api_name: name used for checking; defaults to ``func.__name__`` when falsy.

    Returns:
        The wrapper produced by ``decorate``.
    """
    api_name = api_name or func.__name__

    def forward_func(params: HandlerParams):
        # A fresh checker is created for each call.
        return ForwardSelfChecker(api_name).handle(params)

    return decorate(func, forward_func, api_name)
72
-
73
-
74
def stack_depth_check() -> bool:
    """Return False when the call stack holds more than one ``fuzz_wrapper`` frame.

    The walk starts at the caller's frame and follows ``f_back`` to the bottom of the
    stack. This prevents nested self-checks when a wrapped API internally calls
    another wrapped API.
    """
    wrapper_frames = 0
    current = sys._getframe(1)
    while current is not None:
        if current.f_code.co_name == "fuzz_wrapper":
            wrapper_frames += 1
            if wrapper_frames > 1:
                return False
        current = current.f_back
    return True
84
-
85
-
86
def get_target_arg_index(args: Tuple) -> int:
    """Type check: return the index of the first argument eligible for perturbation.

    An argument qualifies if it is a floating-point tensor (per ``ops.is_tensor`` /
    ``ops.is_floating_point``) or a ``list``, ``tuple`` or ``dict``. Non-floating
    tensors and any other types are skipped.

    Returns:
        The index of the first qualifying argument, or -1 if there is none.
    """
    for i, arg in enumerate(args):
        if ops.is_tensor(arg):
            if ops.is_floating_point(arg):
                return i
            continue
        # Builtin types rather than the deprecated typing.List/Tuple/Dict aliases;
        # the isinstance semantics are identical.
        if isinstance(arg, (list, tuple, dict)):
            return i
    return -1
99
-
100
-
101
def data_pre_deal(api_name, func, *args, **kwargs):
    """Pack one call of ``func`` into a ``HandlerParams``.

    Stores ``args``, ``kwargs``, the original function and the index of the argument
    to perturb (see ``get_target_arg_index``).

    Raises:
        Exception: when no argument has a supported type.
    """
    params = HandlerParams()
    params.args, params.kwargs, params.original_func = args, kwargs, func
    params.index = get_target_arg_index(args)
    if params.index == -1:
        raise Exception(f"{api_name} has no supported input type")
    return params
111
-
112
-
113
def need_wrapper_func():
    """Decide whether the current call should go through the self-check.

    Requires all of the following:
    - ``Runtime.is_running`` and ``Config.is_enable`` are set.
    - The call is not nested inside another wrapper (``stack_depth_check``).
    - The current step is in ``Config.steps``, when a step filter is configured.
    - The current rank is in ``Config.ranks``, when a rank filter is configured and
      the rank is known (``Runtime.rank_id != -1``).
    """
    if not Runtime.is_running or not Config.is_enable:
        return False
    if not stack_depth_check():
        return False
    outside_steps = bool(Config.steps) and Runtime.step_count not in Config.steps
    if outside_steps:
        return False
    outside_ranks = (
        bool(Config.ranks)
        and Runtime.rank_id != -1
        and Runtime.rank_id not in Config.ranks
    )
    return not outside_ranks
@@ -1,84 +0,0 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import torch
17
- import torch.nn as nn
18
- from msprobe.core.common.const import Const
19
- from msprobe.core.common.exceptions import MsprobeException
20
- from msprobe.core.data_dump.scope import BaseScope
21
- from msprobe.pytorch.common.log import logger
22
- from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger
23
- from msprobe.pytorch.hook_module.api_registry import api_register
24
- from msprobe.pytorch.service import torch_version_above_or_equal_2
25
-
26
# Handles of every hook added by register_hook (appended/extended there).
# remove_hook() removes them and module_dump_end() then clears this list.
hook_handle_list = []
27
-
28
-
29
def module_dump(module, dump_name):
    """Start dumping a single module under the name ``dump_name``.

    The arguments are validated first. Then ``api_register.api_originality()`` is
    called and the module's dump hooks are registered. ``module_dump_end()`` undoes
    both steps.

    Raises:
        MsprobeException: ``INVALID_PARAM_ERROR`` when ``module`` is not an
            ``nn.Module`` or ``dump_name`` is not a ``str``.
    """
    validations = (
        (module, nn.Module, "The parameter module in module_dump must be a Module subclass."),
        (dump_name, str, "The parameter dump_name in module_dump must be a str type."),
    )
    for value, expected_type, message in validations:
        if not isinstance(value, expected_type):
            logger.error(message)
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

    api_register.api_originality()
    register_hook(module, dump_name)
39
-
40
-
41
def module_dump_end():
    """Finish the dump started by ``module_dump``.

    Calls ``api_register.api_modularity()``, the counterpart of the
    ``api_originality()`` call in ``module_dump``. Then removes every recorded hook
    and empties ``hook_handle_list``.
    """
    api_register.api_modularity()
    remove_hook()
    del hook_handle_list[:]
45
-
46
-
47
def register_hook(module, dump_name):
    """Register dump hooks on ``module`` and record their handles in ``hook_handle_list``.

    All hooks come from the active ``PrecisionDebugger`` service. They share the name
    prefix ``Module_Type_Module + SEP + dump_name + SEP + <class name> + SEP``.
    Registration order differs between the torch >= 2 and torch < 2 paths.
    NOTE(review): the statement order is kept deliberately, presumably because hook
    firing order follows registration order - confirm before reordering.

    Args:
        module: the ``nn.Module`` to instrument (type-checked by ``module_dump``).
        dump_name: user-supplied name embedded in the prefix.
    """
    prefix = BaseScope.Module_Type_Module + Const.SEP + dump_name + Const.SEP + module.__class__.__name__ + Const.SEP

    pdg = PrecisionDebugger()
    # build_hook returns four callables; the first one is not used here.
    _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = \
        pdg.service.build_hook(BaseScope.Module_Type_Module, prefix)

    if torch_version_above_or_equal_2:
        # torch >= 2: the data forward hook is registered with kwargs support.
        forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
        hook_handle_list.append(forward_hook_handle)
    else:
        # torch < 2: the backward STOP node hook is registered here, before the data
        # backward hook. On torch >= 2 it is registered at the end of this function.
        pdg.service.check_register_full_backward_hook(module)
        full_backward_hook_handle = module.register_full_backward_hook(
            pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
        forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
        hook_handle_list.extend([full_backward_hook_handle, forward_hook_handle])
    # Data backward hook, registered on both paths. It must sit at function level:
    # full_backward_hook_handle is consumed by the extend() call below.
    pdg.service.check_register_full_backward_hook(module)
    full_backward_hook_handle = module.register_full_backward_hook(backward_hook)

    # node_hook callbacks tagged Const.START / Const.STOP for the forward prefix.
    forward_pre_hook_handle = module.register_forward_pre_hook(
        pdg.service.module_processor.node_hook(prefix + Const.FORWARD, Const.START))
    forward_hook_handle = module.register_forward_hook(
        pdg.service.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
    hook_handle_list.extend([full_backward_hook_handle, forward_pre_hook_handle, forward_hook_handle])

    if torch_version_above_or_equal_2:
        # Backward START (pre-hook) and STOP node hooks. register_full_backward_pre_hook
        # is only used on the torch >= 2 path.
        backward_pre_hook_handle = module.register_full_backward_pre_hook(
            pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
        pdg.service.check_register_full_backward_hook(module)
        full_backward_hook_handle = module.register_full_backward_hook(
            pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
        hook_handle_list.extend([backward_pre_hook_handle, full_backward_hook_handle])
79
-
80
-
81
def remove_hook():
    """Remove every ``RemovableHandle`` recorded in ``hook_handle_list``.

    Entries of any other type are ignored. The list itself is not cleared here;
    ``module_dump_end`` does that.
    """
    removable = (
        handle for handle in hook_handle_list
        if isinstance(handle, torch.utils.hooks.RemovableHandle)
    )
    for handle in removable:
        handle.remove()