mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -0,0 +1,380 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import re
18
+ import multiprocessing
19
+ from functools import partial
20
+
21
+ import pandas as pd
22
+ from tqdm import tqdm
23
+
24
+ from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory
25
+ from msprobe.core.common.const import FileCheckConst, Const, CompareConst
26
+ from msprobe.core.common.utils import CompareException, add_time_with_xlsx
27
+ from msprobe.core.compare.utils import table_value_is_valid
28
+
29
+
30
def check_compare_result_name(file_name):
    """
    check whether the compare result name is as expected

    file_name: base name of a compare result file
    return: True for a multi-rank result named
            compare_result_rank<N>-rank<N>_<14 digits>.xlsx (both rank numbers must be equal),
            False (after a warning) for a single-rank result named
            compare_result_rank-rank_<14 digits>.xlsx
    raise: CompareException(MERGE_COMPARE_RESULT_ERROR) for any other name
    """
    # The dot in front of the extension is escaped so that only a literal ".xlsx" is accepted;
    # an unescaped "." would match any character.
    single_rank_pattern = r"^compare_result_rank-rank_\d{14}\.xlsx$"
    multi_ranks_pattern = r"^compare_result_rank(\d+)-rank\1_\d{14}\.xlsx$"
    if re.match(multi_ranks_pattern, file_name):
        return True
    if re.match(single_rank_pattern, file_name):
        logger.warning("Single rank compare result do not need to be merged.")
        return False
    logger.error(f"Wrong compare result name: {file_name}, please check!")
    raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
43
+
44
+
45
def reorder_path(compare_result_path_list):
    """
    return a new list holding the compare result paths in ascending order of rank number

    The rank number is the first number found after "compare_result_rank" in the base name
    of each path.
    """
    rank_regex = re.compile(r"compare_result_rank(\d+)-rank")

    def rank_of(path):
        base_name = os.path.basename(path)
        return int(rank_regex.search(base_name).group(1))

    ordered_paths = list(compare_result_path_list)
    ordered_paths.sort(key=rank_of)
    return ordered_paths
55
+
56
+
57
def get_result_path(input_dir):
    """
    collect the multi-rank compare result files of input_dir, ordered by rank number

    Only entries ending with the xlsx suffix are considered. Each one must pass
    check_compare_result_name and the readable-file check before it is kept.

    return: list of checked file paths sorted by rank
    raise: CompareException(MERGE_COMPARE_RESULT_ERROR) when fewer than two files remain
    """
    checked_path_list = []
    for entry in os.listdir(input_dir):
        if not entry.endswith(FileCheckConst.XLSX_SUFFIX):
            continue
        candidate_path = os.path.join(input_dir, entry)
        if not check_compare_result_name(os.path.basename(candidate_path)):
            continue
        checker = FileChecker(candidate_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE)
        checked_path_list.append(checker.common_check())

    # put the multi-rank compare results back in rank order
    ordered_path_list = reorder_path(checked_path_list)

    if len(ordered_path_list) < 2:
        # a single result needs no merging, so stop here
        logger.warning("Number of compare result is no more than 1, no need to merge.")
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
    return ordered_path_list
77
+
78
+
79
def get_dump_mode(result_df, rank_num):
    """
    identify the dump mode of a compare result table from its column header

    result_df: compare result table
    rank_num: rank number, only used in the warning message
    return: Const.ALL, Const.SUMMARY or Const.MD5 when the header equals one of the known
            headers of that mode, otherwise "" (after a warning)
    """
    columns = result_df.columns.tolist()

    all_mode_headers = [CompareConst.COMPARE_RESULT_HEADER + [CompareConst.DATA_NAME],
                        CompareConst.COMPARE_RESULT_HEADER_STACK + [CompareConst.DATA_NAME]]
    if columns in all_mode_headers:
        return Const.ALL

    if columns in [CompareConst.SUMMARY_COMPARE_RESULT_HEADER, CompareConst.SUMMARY_COMPARE_RESULT_HEADER_STACK]:
        return Const.SUMMARY

    if columns in [CompareConst.MD5_COMPARE_RESULT_HEADER, CompareConst.MD5_COMPARE_RESULT_HEADER_STACK]:
        return Const.MD5

    logger.warning(f"A valid dump task can not be identified from rank{rank_num} compare result, please check! "
                   f"The compare result will not be shown in merged result.")
    return ""
96
+
97
+
98
def check_index_dump_mode_consistent(dump_mode, rank_num):
    """
    check that the compare indexes to merge are consistent with the dump mode

    The requested compare indexes are read from the module-level share_compare_index_list.
    When that list is empty it is filled with all compare indexes of the dump mode.

    dump_mode: dump mode of the compare result (Const.ALL, Const.SUMMARY or Const.MD5)
    rank_num: rank number, only used in the warning messages
    return: the compare indexes to merge, or [] when the dump mode is md5 or when a requested
            compare index does not belong to the dump mode
    """
    if dump_mode == Const.MD5:
        logger.warning(f"Rank{rank_num} compare result is 'md5' dump task and does not support merging result, please "
                       f"check! The compare result will not be shown in merged result.")
        return []

    dump_mode_compare_index_map = {
        Const.ALL: CompareConst.ALL_COMPARE_INDEX,
        Const.SUMMARY: CompareConst.SUMMARY_COMPARE_INDEX
    }
    # An unmapped dump mode has no valid compare index; fall back to an empty list instead of None
    valid_compare_index = dump_mode_compare_index_map.get(dump_mode, [])

    share_list = list(share_compare_index_list)

    # No compare index was requested: use every compare index of the dump mode
    if not share_list:
        share_compare_index_list.extend(valid_compare_index)
        return list(share_compare_index_list)
    if set(share_list).issubset(valid_compare_index):
        return share_list

    # The invalid indexes are the requested ones that the dump mode does not provide
    invalid_compare_index = set(share_list) - set(valid_compare_index)
    logger.warning(f"Compare indexes in rank{rank_num} compare result are not consistent with "
                   f"those in other compare results, please check!")
    logger.warning("The compare result will not be shown in merged result.")
    logger.warning(f"The invalid compare indexes: {invalid_compare_index}")
    return []
129
+
130
+
131
def extract_api_full_name(api_list, result_df, rank_num):
    """
    find api full name from compare result according to api list

    api_list: api names to look for
    result_df: compare result table; its NPU name column is searched
    rank_num: rank number, only used in the warning message
    return: every NPU name that contains "<api><separator>" for some api of api_list, in
            api_list order; an api without any match is reported with a warning and skipped
    """
    api_full_name_list = []
    for api in api_list:
        api_pat = api + Const.SEP
        # regex=False makes this a plain substring test, so every character of the api name
        # (not only ".") is taken literally
        name_matches = result_df[CompareConst.NPU_NAME].str.contains(api_pat, regex=False, na=False)
        single_api_full_name_list = result_df.loc[name_matches, CompareConst.NPU_NAME].tolist()
        if not single_api_full_name_list:
            logger.warning(f"{api} not found in rank{rank_num} compare result.")
            continue
        api_full_name_list.extend(single_api_full_name_list)
    return api_full_name_list
146
+
147
+
148
def search_api_index_result(api_list, compare_index_list, result_df, rank_num, compare_index_dict):
    """
    fill compare_index_dict with the values of one rank and return it

    Resulting layout:
    {
        compare_index: {
            api_full_name: {rank_num: value},
            ...
        },
        ...
    }

    Every api full name and every value is passed through table_value_check before it is
    stored. The value is read from the first row whose NPU name equals the api full name.
    """
    matched_name_list = extract_api_full_name(api_list, result_df, rank_num)
    for index_name in compare_index_list:
        values_by_api = {}
        for full_name in matched_name_list:
            table_value_check(full_name)
            matching_rows = result_df.index[result_df[CompareConst.NPU_NAME] == full_name].tolist()
            cell_value = result_df.loc[matching_rows[0], index_name]
            table_value_check(cell_value)
            rank_values = values_by_api.setdefault(full_name, {})
            rank_values[rank_num] = cell_value
        compare_index_dict[index_name] = values_by_api
    return compare_index_dict
174
+
175
+
176
def table_value_check(value):
    """
    raise RuntimeError when value is rejected by table_value_is_valid, otherwise do nothing
    """
    if table_value_is_valid(value):
        return
    raise RuntimeError(
        f"Malicious value [{value}] is not allowed to be written into the merged xlsx.")
180
+
181
+
182
def result_process(compare_result_path_list, api_list):
    """
    process compare results into target intermediate dict list

    compare_result_path_list: compare result files handled by this call
    api_list: api names whose data is merged
    return: (compare_index_dict_list, rank_num_list, compare_index_list); the first two lists
            hold one entry per rank that was parsed successfully, the third is the compare
            index list used for those ranks ([] when no rank was parsed)

    A rank whose result is empty, whose dump mode can not be identified or whose compare
    indexes are not consistent is skipped. The other ranks handled by the same call are still
    parsed, so the outcome does not depend on how the files were grouped into chunks.
    """
    compare_index_dict_list = []
    rank_num_list = []
    compare_index_list = []
    rank_pattern = re.compile(r"compare_result_rank(\d+)-rank")

    for compare_result_path in compare_result_path_list:
        result_df = read_xlsx(compare_result_path)

        rank_num = int(rank_pattern.search(os.path.basename(compare_result_path)).group(1))
        logger.info(f"Parsing rank{rank_num} compare result...")
        if result_df.empty:
            logger.warning(f"Rank{rank_num} compare result is empty and will not shown in merged result.")
            continue

        dump_mode = get_dump_mode(result_df, rank_num)
        if dump_mode == "":
            continue
        # The compare indexes are fixed once they are known, and they determine the dump mode.
        # Checking them against the dump mode of each rank therefore guarantees that every
        # merged rank has the same dump mode.
        rank_compare_index_list = check_index_dump_mode_consistent(dump_mode, rank_num)
        if not rank_compare_index_list:
            continue

        compare_index_list = rank_compare_index_list
        compare_index_dict = search_api_index_result(api_list, compare_index_list, result_df, rank_num, {})
        compare_index_dict_list.append(compare_index_dict)
        rank_num_list.append(rank_num)

    return compare_index_dict_list, rank_num_list, compare_index_list
214
+
215
+
216
def handle_multi_process(func, func_args, lock):
    """
    parse the compare results in parallel with a process pool

    func: callable run in the worker processes as func(chunk, api_list); it must return a
          (compare_index_dict_list, rank_num_list, compare_index_list) triple
    func_args: (compare_result_path_list, api_list)
    lock: lock held while the progress bar is updated
    return: three lists holding, chunk by chunk and in submission order, the three items
            returned by func
    raise: CompareException(MERGE_COMPARE_RESULT_ERROR) when no chunk produced any data
    """
    compare_result_path_list, api_list = func_args

    result_num = len(compare_result_path_list)
    process_num = int((multiprocessing.cpu_count() + 1) / 2)
    if result_num <= process_num:
        process_num = result_num
        chunks = [[compare_result_path] for compare_result_path in compare_result_path_list]
    else:
        chunk_size = result_num // process_num
        chunks = [compare_result_path_list[i:i + chunk_size] for i in range(0, result_num, chunk_size)]

    pool = multiprocessing.Pool(process_num)

    def terminate_pool():
        try:
            pool.terminate()
        except OSError:
            logger.error("Pool terminate failed")

    def err_call(args):
        logger.error('Multiprocess merge result failed! Reason: {}'.format(args))
        terminate_pool()

    progress_bar = tqdm(total=result_num, desc="Compare Result Parsing Process", unit="num", ncols=100)

    def update_progress(size, progress_lock, extra_param=None):
        with progress_lock:
            progress_bar.update(size)

    all_compare_index_dict_list = []
    all_rank_num_list = []
    all_compare_index_list_list = []
    try:
        # pool.apply_async returns an ApplyResult object immediately, so results keeps the
        # submission order of the chunks
        results = [
            pool.apply_async(func,
                             args=(chunk, api_list),
                             error_callback=err_call,
                             callback=partial(update_progress, len(chunk), lock))
            for chunk in chunks
        ]

        for result in results:
            compare_index_dict, rank_num_list, compare_index_list = result.get()
            all_compare_index_dict_list.append(compare_index_dict)
            all_rank_num_list.append(rank_num_list)
            all_compare_index_list_list.append(compare_index_list)

        pool.close()
        pool.join()
    except Exception:
        # do not leave worker processes behind when a chunk failed
        terminate_pool()
        raise
    finally:
        # a manually updated progress bar has to be closed explicitly
        progress_bar.close()

    if not any(all_compare_index_dict_list):
        logger.warning("Nothing to merge.")
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)

    return all_compare_index_dict_list, all_rank_num_list, all_compare_index_list_list
270
+
271
+
272
def generate_result_df(api_index_dict, header):
    """
    build a DataFrame with one row per api full name

    api_index_dict:
    {
        api_full_name1: {rank1: value},
        api_full_name2: {rank1: value},
        ...
    }
    header: column names of the DataFrame
    return: DataFrame of dtype "object"; each row is the api full name followed by the values
            stored for it
    """
    rows = [[full_name, *rank_values.values()] for full_name, rank_values in api_index_dict.items()]
    return pd.DataFrame(rows, columns=header, dtype="object")
292
+
293
+
294
def generate_merge_result(all_compare_index_dict_list, all_rank_num_list, all_compare_index_list_list, output_dir):
    """
    write the merged xlsx from the intermediate dicts, one sheet per compare index

    all_compare_index_dict_list: per chunk, the list of per-rank intermediate dicts
    all_rank_num_list: per chunk, the list of rank numbers matching those dicts
    all_compare_index_list_list: per chunk, the compare index list; the first non-empty one is used
    output_dir: directory the merged file is written to
    raise: CompareException(MERGE_COMPARE_RESULT_ERROR) when no compare index list is available
    """
    output_path = os.path.join(output_dir, add_time_with_xlsx("multi_ranks_compare_merge"))

    compare_index_list = next((item for item in all_compare_index_list_list if item), None)
    if not compare_index_list:
        logger.error("No compare index recognized, please check!")
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)

    all_result_df_list = []
    for chunk_dict_list, chunk_rank_list in zip(all_compare_index_dict_list, all_rank_num_list):
        for compare_index_dict, rank_num in zip(chunk_dict_list, chunk_rank_list):
            header = [CompareConst.NPU_NAME, "rank" + str(rank_num)]
            rank_df_list = [generate_result_df(api_index_dict, header)
                            for api_index_dict in compare_index_dict.values()]
            all_result_df_list.append(rank_df_list)

    merged_df_list = df_merge(all_result_df_list)
    # the i-th merged DataFrame belongs to the i-th compare index
    sheet_list = [(merged_df, compare_index_list[position])
                  for position, merged_df in enumerate(merged_df_list)]
    save_excel(output_path, sheet_list)
    logger.info(f"The compare results of the multi-ranks are merged and saved in: {output_path}.")
328
+
329
+
330
def df_merge(all_result_df_list):
    """
    merge the per-rank DataFrame lists position by position

    all_result_df_list: one list of DataFrames per rank
    return: the first rank's list, updated in place; every entry is outer-merged on the NPU
            name column with the entries at the same position of the other ranks, with the
            NPU name column placed first
    raise: CompareException(MERGE_COMPARE_RESULT_ERROR) when all_result_df_list is empty
    """
    rank_count = len(all_result_df_list)
    if rank_count == 0:
        logger.warning("Nothing to merge.")
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
    if rank_count == 1:
        logger.info("Only one compare result gets merge data.")

    merged_list = all_result_df_list[0]
    for rank_df_list in all_result_df_list[1:]:
        for position, rank_df in enumerate(rank_df_list):
            merged_list[position] = pd.merge(merged_list[position], rank_df, on=CompareConst.NPU_NAME, how='outer')

    for position, merged_df in enumerate(merged_list):
        other_columns = [col for col in merged_df.columns if col != CompareConst.NPU_NAME]
        merged_list[position] = merged_df.reindex(columns=[CompareConst.NPU_NAME] + other_columns)
    return merged_list
347
+
348
+
349
# Compare indexes to merge; replaced by a manager-backed list in initialize_compare_index
share_compare_index_list = []


def initialize_compare_index(config):
    """
    create the shared global list share_compare_index_list from config["compare_index"]

    config: loaded merge configuration (a dict)

    A missing or empty (None) "compare_index" entry gives an empty list, and a single string
    is treated as a one-element list instead of being split into characters.
    """
    global share_compare_index_list
    compare_index = config.get("compare_index")
    if compare_index is None:
        compare_index = []
    elif isinstance(compare_index, str):
        compare_index = [compare_index]
    manager = multiprocessing.Manager()
    share_compare_index_list = manager.list(compare_index)  # create the shared global list
356
+
357
+
358
def merge_result(input_dir, output_dir, config_path):
    """
    merge the multi-rank compare results of input_dir into one xlsx file in output_dir

    input_dir: directory holding the compare result files; checked as a readable directory
    output_dir: directory for the merged result; created through create_directory
    config_path: yaml file providing the "api" list and, optionally, "compare_index"
    raise: CompareException(MERGE_COMPARE_RESULT_ERROR) when the config is empty or has no api
    """
    checked_input_dir = FileChecker(input_dir, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
    create_directory(output_dir)

    # full paths of all compare results in the input directory; fewer than two of them
    # makes get_result_path report it and stop
    result_path_list = get_result_path(checked_input_dir)

    config = load_yaml(config_path)
    if not config:
        logger.error('config.yaml is empty, please check.')
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)

    api_list = config.get('api')
    if not api_list:
        logger.error('The APIs required to merge data were not found')
        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)

    # initialise the shared global variable share_compare_index_list
    initialize_compare_index(config)

    progress_lock = multiprocessing.Manager().RLock()
    parsed = handle_multi_process(result_process, (result_path_list, api_list), progress_lock)
    index_dict_lists, rank_num_lists, index_list_lists = parsed

    generate_merge_result(index_dict_lists, rank_num_lists, index_list_lists, output_dir)
@@ -0,0 +1,31 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.core.compare.merge_result.merge_result import merge_result
17
+
18
+
19
def _merge_result_parser(parser):
    """Register the merge-result command line options on ``parser``.

    Adds three required string options: ``-i/--input_dir``,
    ``-o/--output_dir`` and ``-config/--config-path`` (stored as
    ``config_path``).

    Args:
        parser: An argparse-style parser exposing ``add_argument``.
    """
    # (option strings, destination attribute, help text)
    option_specs = (
        (("-i", "--input_dir"), "input_dir",
         "<Required> The compare result path, a dir."),
        (("-o", "--output_dir"), "output_dir",
         "<Required> The result merge output path, a dir."),
        (("-config", "--config-path"), "config_path",
         "<Required> Yaml path containing distribute APIs and compare indexes for merging data "
         "from compare results."),
    )
    for flags, dest, help_text in option_specs:
        parser.add_argument(*flags, dest=dest, type=str, help=help_text, required=True)
28
+
29
+
30
def merge_result_cli(args):
    """Run ``merge_result`` with the options parsed from the command line.

    Args:
        args: Parsed arguments exposing ``input_dir``, ``output_dir`` and
            ``config_path`` attributes.
    """
    input_dir, output_dir, config_path = args.input_dir, args.output_dir, args.config_path
    merge_result(input_dir, output_dir, config_path)
@@ -23,7 +23,7 @@ from msprobe.core.common.const import CompareConst
23
23
 
24
24
 
25
25
  def _handle_multi_process(func, input_parma, result_df, lock):
26
- process_num = int((multiprocessing.cpu_count() + 1) / 2)
26
+ process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
27
27
  op_name_mapping_dict = read_dump_data(result_df)
28
28
 
29
29
  df_chunk_size = len(result_df) // process_num
@@ -63,7 +63,7 @@ def _handle_multi_process(func, input_parma, result_df, lock):
63
63
 
64
64
 
65
65
  def _ms_graph_handle_multi_process(func, result_df, mode):
66
- process_num = int((multiprocessing.cpu_count() + 1) // 4)
66
+ process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
67
67
  df_chunk_size = len(result_df) // process_num
68
68
  if df_chunk_size > 0:
69
69
  df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]