mindstudio-probe 8.3.3__py3-none-any.whl → 26.0.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (689)
  1. {mindstudio_probe-8.3.3.dist-info → mindstudio_probe-26.0.0a1.dist-info}/METADATA +26 -14
  2. mindstudio_probe-26.0.0a1.dist-info/RECORD +498 -0
  3. {mindstudio_probe-8.3.3.dist-info → mindstudio_probe-26.0.0a1.dist-info}/WHEEL +1 -1
  4. mindstudio_probe-26.0.0a1.dist-info/entry_points.txt +5 -0
  5. mindstudio_probe-26.0.0a1.dist-info/licenses/LICENSE +124 -0
  6. mindstudio_probe-26.0.0a1.dist-info/top_level.txt +2 -0
  7. msprobe/__init__.py +12 -13
  8. msprobe/config.json +9 -31
  9. msprobe/core/__init__.py +12 -11
  10. msprobe/core/acc_check/acc_check_cli.py +145 -0
  11. msprobe/core/common/const.py +97 -38
  12. msprobe/core/common/db_manager.py +133 -12
  13. msprobe/core/common/decorator.py +12 -11
  14. msprobe/core/common/exceptions.py +12 -11
  15. msprobe/core/common/file_utils.py +101 -25
  16. msprobe/core/common/framework_adapter.py +36 -25
  17. msprobe/core/common/global_lock.py +12 -11
  18. msprobe/core/common/inplace_op_checker.py +12 -11
  19. msprobe/core/common/log.py +22 -11
  20. msprobe/core/common/megatron_utils.py +566 -11
  21. msprobe/core/common/parallel_state.py +12 -11
  22. msprobe/core/common/runtime.py +12 -11
  23. msprobe/core/common/utils.py +41 -41
  24. msprobe/core/compare/acc_compare.py +361 -104
  25. msprobe/core/compare/atb_data_compare.py +422 -0
  26. msprobe/core/compare/auto_compare.py +134 -0
  27. msprobe/core/compare/check.py +14 -17
  28. msprobe/core/compare/compare_cli.py +72 -149
  29. msprobe/core/compare/config.py +12 -13
  30. msprobe/core/compare/diff_analyze/first_diff_analyze.py +28 -15
  31. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  32. msprobe/core/compare/find_first/analyzer.py +18 -18
  33. msprobe/core/compare/find_first/graph.py +12 -11
  34. msprobe/core/compare/find_first/utils.py +13 -12
  35. msprobe/core/compare/indicator_analysis/__init__.py +15 -0
  36. msprobe/core/compare/indicator_analysis/algorithm.py +363 -0
  37. msprobe/core/compare/indicator_analysis/api_data.py +141 -0
  38. msprobe/core/compare/indicator_analysis/calculator.py +181 -0
  39. msprobe/core/compare/indicator_analysis/utils.py +116 -0
  40. msprobe/core/compare/layer_mapping/__init__.py +12 -11
  41. msprobe/core/compare/layer_mapping/data_scope_parser.py +20 -11
  42. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -13
  43. msprobe/core/compare/layer_mapping/postprocess_pass.py +13 -11
  44. msprobe/core/compare/merge_result/merge_result.py +12 -11
  45. msprobe/core/compare/merge_result/merge_result_cli.py +12 -11
  46. msprobe/core/compare/merge_result/utils.py +12 -11
  47. msprobe/core/compare/multiprocessing_compute.py +13 -14
  48. msprobe/core/compare/npy_compare.py +13 -11
  49. msprobe/core/compare/offline_data_compare.py +160 -0
  50. msprobe/core/compare/stats_diff_calc.py +39 -0
  51. msprobe/core/compare/torchair_acc_cmp.py +764 -0
  52. msprobe/core/compare/torchair_cmp_utils.py +338 -0
  53. msprobe/core/compare/utils.py +140 -49
  54. msprobe/core/config_check/__init__.py +12 -11
  55. msprobe/core/config_check/checkers/__init__.py +12 -11
  56. msprobe/core/config_check/checkers/base_checker.py +15 -14
  57. msprobe/core/config_check/checkers/dataset_checker.py +13 -12
  58. msprobe/core/config_check/checkers/env_args_checker.py +13 -12
  59. msprobe/core/config_check/checkers/hyperparameter_checker.py +16 -15
  60. msprobe/core/config_check/checkers/pip_checker.py +15 -15
  61. msprobe/core/config_check/checkers/random_checker.py +13 -12
  62. msprobe/core/config_check/checkers/weights_checker.py +14 -12
  63. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +13 -17
  64. msprobe/core/config_check/ckpt_compare/megatron_loader.py +13 -12
  65. msprobe/core/config_check/ckpt_compare/metrics.py +12 -11
  66. msprobe/core/config_check/config_check_cli.py +18 -17
  67. msprobe/core/config_check/config_checker.py +16 -14
  68. msprobe/core/config_check/resource/dependency.yaml +15 -12
  69. msprobe/core/config_check/resource/env.yaml +12 -11
  70. msprobe/core/config_check/utils/hyperparameter_parser.py +12 -11
  71. msprobe/core/config_check/utils/utils.py +12 -11
  72. msprobe/core/{data_dump → dump/api_dump}/api_registry.py +12 -11
  73. msprobe/core/{common_config.py → dump/common_config.py} +13 -24
  74. msprobe/core/dump/data_dump/data_collector.py +257 -0
  75. msprobe/core/{data_dump → dump/data_dump}/data_processor/base.py +45 -36
  76. msprobe/core/{data_dump → dump/data_dump}/data_processor/factory.py +33 -25
  77. msprobe/core/{data_dump → dump/data_dump}/data_processor/mindspore_processor.py +37 -113
  78. msprobe/core/{data_dump → dump/data_dump}/data_processor/pytorch_processor.py +364 -131
  79. msprobe/core/{data_dump → dump/data_dump}/json_writer.py +24 -31
  80. msprobe/core/{data_dump → dump/data_dump}/scope.py +12 -13
  81. msprobe/core/{debugger → dump/debugger}/precision_debugger.py +15 -23
  82. msprobe/core/dump/dump2db/db_utils.py +215 -0
  83. msprobe/core/dump/dump2db/dump2db.py +409 -0
  84. msprobe/core/{hook_manager.py → dump/hook_manager.py} +38 -87
  85. msprobe/core/dump/kernel_dump/kernel_config.py +34 -0
  86. msprobe/core/{service.py → dump/service.py} +43 -27
  87. msprobe/core/install_deps/install_deps.py +51 -0
  88. msprobe/core/monitor/anomaly_processor.py +13 -11
  89. msprobe/core/monitor/csv2db.py +73 -93
  90. msprobe/core/monitor/db_utils.py +140 -205
  91. msprobe/core/monitor/utils.py +18 -17
  92. msprobe/core/monitor_v2/__init__.py +20 -0
  93. msprobe/core/monitor_v2/base.py +83 -0
  94. msprobe/core/monitor_v2/cc.py +287 -0
  95. msprobe/core/monitor_v2/factory.py +81 -0
  96. msprobe/core/monitor_v2/module.py +201 -0
  97. msprobe/core/monitor_v2/optimizer.py +245 -0
  98. msprobe/core/monitor_v2/param.py +154 -0
  99. msprobe/core/monitor_v2/trainer.py +326 -0
  100. msprobe/core/monitor_v2/utils.py +122 -0
  101. msprobe/core/monitor_v2/weight_grad.py +419 -0
  102. msprobe/core/monitor_v2/writer.py +162 -0
  103. msprobe/core/overflow_check/abnormal_scene.py +12 -11
  104. msprobe/core/overflow_check/api_info.py +12 -11
  105. msprobe/core/overflow_check/checker.py +12 -11
  106. msprobe/core/overflow_check/filter.py +13 -11
  107. msprobe/core/overflow_check/level.py +12 -11
  108. msprobe/core/overflow_check/utils.py +12 -11
  109. msprobe/core/single_save/single_comparator.py +12 -11
  110. msprobe/core/single_save/single_saver.py +12 -11
  111. msprobe/infer/__init__.py +16 -0
  112. msprobe/infer/offline/__init__.py +16 -0
  113. msprobe/infer/offline/compare/__init__.py +16 -0
  114. msprobe/infer/offline/compare/msquickcmp/__init__.py +16 -0
  115. msprobe/infer/offline/compare/msquickcmp/adapter_cli/__init__.py +16 -0
  116. msprobe/infer/offline/compare/msquickcmp/adapter_cli/args_adapter.py +46 -0
  117. msprobe/infer/offline/compare/msquickcmp/atc/__init__.py +16 -0
  118. msprobe/infer/offline/compare/msquickcmp/atc/atc_utils.py +98 -0
  119. msprobe/infer/offline/compare/msquickcmp/cmp_process.py +328 -0
  120. msprobe/infer/offline/compare/msquickcmp/common/__init__.py +16 -0
  121. msprobe/infer/offline/compare/msquickcmp/common/args_check.py +112 -0
  122. msprobe/infer/offline/compare/msquickcmp/common/convert.py +74 -0
  123. msprobe/infer/offline/compare/msquickcmp/common/dump_data.py +121 -0
  124. msprobe/infer/offline/compare/msquickcmp/common/dynamic_argument_bean.py +39 -0
  125. msprobe/infer/offline/compare/msquickcmp/common/utils.py +669 -0
  126. msprobe/infer/offline/compare/msquickcmp/config.ini +6 -0
  127. msprobe/infer/offline/compare/msquickcmp/dump/__init__.py +16 -0
  128. msprobe/infer/offline/compare/msquickcmp/dump/args_adapter.py +50 -0
  129. msprobe/infer/offline/compare/msquickcmp/dump/dump_process.py +91 -0
  130. msprobe/infer/offline/compare/msquickcmp/install_aclruntime_aisbench.sh +180 -0
  131. msprobe/infer/offline/compare/msquickcmp/main.py +199 -0
  132. msprobe/infer/offline/compare/msquickcmp/net_compare/__init__.py +16 -0
  133. msprobe/infer/offline/compare/msquickcmp/net_compare/net_compare.py +277 -0
  134. msprobe/infer/offline/compare/msquickcmp/npu/__init__.py +16 -0
  135. msprobe/infer/offline/compare/msquickcmp/npu/npu_dump_data.py +558 -0
  136. msprobe/infer/offline/compare/msquickcmp/npu/om_parser.py +416 -0
  137. msprobe/infer/offline/compare/msquickcmp/onnx_model/__init__.py +16 -0
  138. msprobe/infer/offline/compare/msquickcmp/onnx_model/onnx_dump_data.py +374 -0
  139. msprobe/infer/utils/__init__.py +15 -0
  140. msprobe/infer/utils/acc_cmp.py +94 -0
  141. msprobe/infer/utils/check/__init__.py +37 -0
  142. msprobe/infer/utils/check/args_checker.py +35 -0
  143. msprobe/infer/utils/check/checker.py +227 -0
  144. msprobe/infer/utils/check/dict_checker.py +78 -0
  145. msprobe/infer/utils/check/func_wrapper.py +96 -0
  146. msprobe/infer/utils/check/list_checker.py +56 -0
  147. msprobe/infer/utils/check/number_checker.py +64 -0
  148. msprobe/infer/utils/check/obj_checker.py +41 -0
  149. msprobe/infer/utils/check/path_checker.py +249 -0
  150. msprobe/infer/utils/check/rule.py +126 -0
  151. msprobe/infer/utils/check/string_checker.py +66 -0
  152. msprobe/infer/utils/cmp_algorithm.py +261 -0
  153. msprobe/infer/utils/constants.py +112 -0
  154. msprobe/infer/utils/file_open_check.py +337 -0
  155. msprobe/infer/utils/util.py +177 -0
  156. msprobe/mindspore/__init__.py +14 -13
  157. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +14 -13
  158. msprobe/mindspore/api_accuracy_checker/api_info.py +12 -11
  159. msprobe/mindspore/api_accuracy_checker/api_runner.py +12 -11
  160. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +12 -11
  161. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +12 -11
  162. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +12 -11
  163. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +12 -11
  164. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +15 -14
  165. msprobe/mindspore/api_accuracy_checker/compute_element.py +12 -11
  166. msprobe/mindspore/api_accuracy_checker/data_manager.py +13 -11
  167. msprobe/mindspore/api_accuracy_checker/main.py +12 -11
  168. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +14 -12
  169. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +13 -11
  170. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +12 -11
  171. msprobe/mindspore/api_accuracy_checker/type_mapping.py +12 -11
  172. msprobe/mindspore/api_accuracy_checker/utils.py +12 -11
  173. msprobe/mindspore/common/const.py +15 -74
  174. msprobe/mindspore/common/log.py +12 -11
  175. msprobe/mindspore/common/utils.py +30 -15
  176. msprobe/mindspore/compare/common_dir_compare.py +21 -23
  177. msprobe/mindspore/compare/distributed_compare.py +18 -16
  178. msprobe/mindspore/compare/ms_compare.py +14 -14
  179. msprobe/mindspore/compare/ms_graph_compare.py +26 -20
  180. msprobe/mindspore/compare/utils.py +14 -12
  181. msprobe/mindspore/{cell_processor.py → dump/cell_processor.py} +15 -14
  182. msprobe/mindspore/{debugger → dump/debugger}/debugger_config.py +12 -30
  183. msprobe/mindspore/{debugger → dump/debugger}/precision_debugger.py +43 -45
  184. msprobe/mindspore/dump/{cell_dump_process.py → dump_processor/cell_dump_process.py} +31 -17
  185. msprobe/mindspore/dump/{cell_dump_with_insert_gradient.py → dump_processor/cell_dump_with_insert_gradient.py} +18 -14
  186. msprobe/mindspore/dump/{dump_tool_factory.py → dump_processor/dump_tool_factory.py} +16 -15
  187. msprobe/mindspore/dump/{graph_mode_cell_dump.py → dump_processor/graph_mode_cell_dump.py} +16 -15
  188. msprobe/mindspore/dump/{graph_tensor_dump.py → dump_processor/graph_tensor_dump.py} +134 -133
  189. msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/api_register.py +15 -14
  190. msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/hook_cell.py +12 -11
  191. msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/ms_hook_manager.py +47 -20
  192. msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/primitive_hooks.py +14 -13
  193. msprobe/mindspore/dump/{hook_cell → dump_processor/hook_cell}/support_wrap_ops.yaml +13 -11
  194. msprobe/mindspore/dump/{jit_dump.py → dump_processor/jit_dump.py} +14 -13
  195. msprobe/mindspore/dump/{kernel_graph_dump.py → dump_processor/kernel_graph_dump.py} +13 -12
  196. msprobe/mindspore/dump/{kernel_kbyk_dump.py → dump_processor/kernel_kbyk_dump.py} +13 -12
  197. msprobe/mindspore/{exception_dump → dump/exception_dump}/exception_dump_tool_factory.py +14 -13
  198. msprobe/mindspore/{exception_dump → dump/exception_dump}/kernel_graph_exception_dump.py +13 -12
  199. msprobe/mindspore/{mindspore_service.py → dump/mindspore_service.py} +18 -17
  200. msprobe/mindspore/dump/mindtorch/__init__.py +19 -0
  201. msprobe/mindspore/dump/ms_config.py +105 -0
  202. msprobe/mindspore/{overflow_check → dump/overflow_check}/kernel_graph_overflow_check.py +13 -12
  203. msprobe/mindspore/{overflow_check → dump/overflow_check}/overflow_check_tool_factory.py +14 -13
  204. msprobe/mindspore/dump/task_handler_factory.py +43 -0
  205. msprobe/mindspore/monitor/common_func.py +12 -11
  206. msprobe/mindspore/monitor/data_writers.py +12 -11
  207. msprobe/mindspore/monitor/distributed/wrap_distributed.py +93 -39
  208. msprobe/mindspore/monitor/features.py +12 -11
  209. msprobe/mindspore/monitor/module_hook.py +19 -22
  210. msprobe/mindspore/monitor/optimizer_collect.py +29 -25
  211. msprobe/mindspore/monitor/utils.py +13 -11
  212. msprobe/msaccucmp/advisor/__init__.py +16 -0
  213. msprobe/msaccucmp/advisor/advisor_const.py +65 -0
  214. msprobe/msaccucmp/advisor/advisor_result.py +73 -0
  215. msprobe/msaccucmp/advisor/compare_advisor.py +99 -0
  216. msprobe/msaccucmp/advisor/input_advisor.py +66 -0
  217. msprobe/msaccucmp/advisor/node_advisor.py +68 -0
  218. msprobe/msaccucmp/advisor/overflow_advisor.py +58 -0
  219. msprobe/msaccucmp/algorithm_manager/__init__.py +16 -0
  220. msprobe/msaccucmp/algorithm_manager/algorithm_manager.py +464 -0
  221. msprobe/msaccucmp/algorithm_manager/algorithm_parameter.py +42 -0
  222. msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_AccumulatedRelativeError.py +46 -0
  223. msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_CosineSimilarity.py +58 -0
  224. msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_KullbackLeiblerDivergence.py +84 -0
  225. msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MaxAbsoluteError.py +41 -0
  226. msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MaxRelativeError.py +46 -0
  227. msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MeanAbsoluteError.py +41 -0
  228. msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_MeanRelativeError.py +46 -0
  229. msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_RelativeEuclideanDistance.py +46 -0
  230. msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_RootMeanSquareError.py +40 -0
  231. msprobe/msaccucmp/algorithm_manager/builtin_algorithm/alg_StandardDeviation.py +47 -0
  232. msprobe/msaccucmp/cmp_utils/__init__.py +16 -0
  233. msprobe/msaccucmp/cmp_utils/common.py +113 -0
  234. msprobe/msaccucmp/cmp_utils/constant/__init__.py +16 -0
  235. msprobe/msaccucmp/cmp_utils/constant/compare_error.py +81 -0
  236. msprobe/msaccucmp/cmp_utils/constant/const_manager.py +530 -0
  237. msprobe/msaccucmp/cmp_utils/file_utils.py +497 -0
  238. msprobe/msaccucmp/cmp_utils/log.py +257 -0
  239. msprobe/msaccucmp/cmp_utils/multi_process/__init__.py +16 -0
  240. msprobe/msaccucmp/cmp_utils/multi_process/multi_convert_process.py +140 -0
  241. msprobe/msaccucmp/cmp_utils/multi_process/progress.py +78 -0
  242. msprobe/msaccucmp/cmp_utils/path_check.py +274 -0
  243. msprobe/msaccucmp/cmp_utils/reg_manager.py +98 -0
  244. msprobe/msaccucmp/cmp_utils/tlv_parse.py +279 -0
  245. msprobe/msaccucmp/cmp_utils/utils.py +356 -0
  246. msprobe/msaccucmp/cmp_utils/utils_type.py +63 -0
  247. msprobe/msaccucmp/compare_vector.py +48 -0
  248. msprobe/msaccucmp/conversion/__init__.py +16 -0
  249. msprobe/msaccucmp/conversion/data_conversion.py +277 -0
  250. msprobe/msaccucmp/conversion/dtype_conversion.py +99 -0
  251. msprobe/msaccucmp/conversion/shape_format_conversion.py +477 -0
  252. msprobe/msaccucmp/conversion/tensor_conversion.py +369 -0
  253. msprobe/msaccucmp/dump_data_conversion.py +46 -0
  254. msprobe/msaccucmp/dump_parse/__init__.py +16 -0
  255. msprobe/msaccucmp/dump_parse/big_dump_data.py +317 -0
  256. msprobe/msaccucmp/dump_parse/dump.py +423 -0
  257. msprobe/msaccucmp/dump_parse/dump_data_object.py +322 -0
  258. msprobe/msaccucmp/dump_parse/dump_data_parser.py +436 -0
  259. msprobe/msaccucmp/dump_parse/dump_utils.py +246 -0
  260. msprobe/msaccucmp/dump_parse/ffts_parser.py +137 -0
  261. msprobe/msaccucmp/dump_parse/mapping.py +62 -0
  262. msprobe/msaccucmp/dump_parse/nano_dump_data.py +392 -0
  263. msprobe/msaccucmp/dump_parse/proto_dump_data.py +308 -0
  264. msprobe/msaccucmp/dump_parser.py +90 -0
  265. msprobe/msaccucmp/format_manager/__init__.py +16 -0
  266. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_NZ_to_NCHW.py +53 -0
  267. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_NZ_to_ND.py +52 -0
  268. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_NZ_to_NHWC.py +53 -0
  269. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_Z_to_HWCN.py +47 -0
  270. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_FRACTAL_Z_to_NCHW.py +47 -0
  271. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_HWCN_to_FRACTAL_Z.py +89 -0
  272. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_HWCN_to_NCHW.py +37 -0
  273. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_HWCN_to_NHWC.py +37 -0
  274. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NC1HWC0_to_HWCN.py +43 -0
  275. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NC1HWC0_to_NCHW.py +48 -0
  276. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NC1HWC0_to_NHWC.py +43 -0
  277. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NCHW_to_FRACTAL_Z.py +87 -0
  278. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NCHW_to_NHWC.py +37 -0
  279. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NDC1HWC0_to_NCDHW.py +48 -0
  280. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NDC1HWC0_to_ND.py +44 -0
  281. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NHWC_to_FRACTAL_Z.py +87 -0
  282. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NHWC_to_HWCN.py +37 -0
  283. msprobe/msaccucmp/format_manager/builtin_format_convert/convert_NHWC_to_NCHW.py +37 -0
  284. msprobe/msaccucmp/format_manager/format_manager.py +307 -0
  285. msprobe/msaccucmp/inplace_layer_process.py +186 -0
  286. msprobe/msaccucmp/msaccucmp.py +532 -0
  287. msprobe/msaccucmp/mscmp_advisor.py +128 -0
  288. msprobe/msaccucmp/overflow/__init__.py +16 -0
  289. msprobe/msaccucmp/overflow/overflow_analyse.py +305 -0
  290. msprobe/msaccucmp/overflow/overflow_detection.py +143 -0
  291. msprobe/msaccucmp/pytorch_cmp/__init__.py +16 -0
  292. msprobe/msaccucmp/pytorch_cmp/compare_pytorch.py +389 -0
  293. msprobe/msaccucmp/pytorch_cmp/hdf5_parser.py +377 -0
  294. msprobe/msaccucmp/pytorch_cmp/pytorch_dump_data.py +461 -0
  295. msprobe/msaccucmp/shape_conversion.py +41 -0
  296. msprobe/msaccucmp/vector_cmp/__init__.py +16 -0
  297. msprobe/msaccucmp/vector_cmp/batch_compare.py +197 -0
  298. msprobe/msaccucmp/vector_cmp/compare_detail/__init__.py +16 -0
  299. msprobe/msaccucmp/vector_cmp/compare_detail/compare_detail.py +245 -0
  300. msprobe/msaccucmp/vector_cmp/compare_detail/detail.py +182 -0
  301. msprobe/msaccucmp/vector_cmp/compare_detail/detail_writer.py +580 -0
  302. msprobe/msaccucmp/vector_cmp/fusion_manager/__init__.py +16 -0
  303. msprobe/msaccucmp/vector_cmp/fusion_manager/compare_fusion_op.py +588 -0
  304. msprobe/msaccucmp/vector_cmp/fusion_manager/compare_npu_vs_npu.py +339 -0
  305. msprobe/msaccucmp/vector_cmp/fusion_manager/compare_result.py +326 -0
  306. msprobe/msaccucmp/vector_cmp/fusion_manager/compare_rule.py +156 -0
  307. msprobe/msaccucmp/vector_cmp/fusion_manager/fusion_op.py +204 -0
  308. msprobe/msaccucmp/vector_cmp/fusion_manager/fusion_rule_parser.py +635 -0
  309. msprobe/msaccucmp/vector_cmp/fusion_manager/quant_filter.py +187 -0
  310. msprobe/msaccucmp/vector_cmp/range_manager/__init__.py +16 -0
  311. msprobe/msaccucmp/vector_cmp/range_manager/range_manager.py +100 -0
  312. msprobe/msaccucmp/vector_cmp/range_manager/range_mode.py +94 -0
  313. msprobe/msaccucmp/vector_cmp/range_manager/select_mode.py +86 -0
  314. msprobe/msaccucmp/vector_cmp/vector_comparison.py +535 -0
  315. msprobe/msprobe.py +101 -130
  316. msprobe/overflow_check/__init__.py +15 -0
  317. msprobe/{nan_analyze → overflow_check}/analyzer.py +38 -27
  318. msprobe/{nan_analyze → overflow_check}/graph.py +28 -27
  319. msprobe/{nan_analyze → overflow_check}/utils.py +15 -14
  320. msprobe/pytorch/__init__.py +20 -14
  321. msprobe/pytorch/aclgraph_dump/__init__.py +45 -0
  322. msprobe/pytorch/aclgraph_dump/_meta.py +26 -0
  323. msprobe/pytorch/api_accuracy_checker/{run_ut/run_ut.py → acc_check/acc_check.py} +50 -45
  324. msprobe/pytorch/api_accuracy_checker/{run_ut/run_ut_utils.py → acc_check/acc_check_utils.py} +201 -30
  325. msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/data_generate.py +56 -16
  326. msprobe/pytorch/api_accuracy_checker/{run_ut/multi_run_ut.py → acc_check/multi_acc_check.py} +32 -47
  327. msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/run_overflow_check.py +19 -18
  328. msprobe/pytorch/api_accuracy_checker/common/config.py +22 -20
  329. msprobe/pytorch/api_accuracy_checker/common/utils.py +72 -13
  330. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -11
  331. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +23 -14
  332. msprobe/pytorch/api_accuracy_checker/compare/compare.py +45 -32
  333. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +12 -11
  334. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +14 -12
  335. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +14 -12
  336. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +12 -11
  337. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +12 -11
  338. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +21 -19
  339. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +14 -13
  340. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +12 -11
  341. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +60 -11
  342. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +27 -16
  343. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +13 -11
  344. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +39 -18
  345. msprobe/pytorch/bench_functions/__init__.py +12 -11
  346. msprobe/pytorch/bench_functions/apply_adam.py +12 -11
  347. msprobe/pytorch/bench_functions/apply_adam_w.py +12 -11
  348. msprobe/pytorch/bench_functions/confusion_transpose.py +12 -11
  349. msprobe/pytorch/bench_functions/fast_gelu.py +12 -11
  350. msprobe/pytorch/bench_functions/group_norm_silu.py +12 -11
  351. msprobe/pytorch/bench_functions/layer_norm_eval.py +12 -11
  352. msprobe/pytorch/bench_functions/linear.py +12 -11
  353. msprobe/pytorch/bench_functions/matmul_backward.py +12 -11
  354. msprobe/pytorch/bench_functions/mish.py +12 -11
  355. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +12 -11
  356. msprobe/pytorch/bench_functions/npu_fusion_attention.py +12 -11
  357. msprobe/pytorch/bench_functions/rms_norm.py +12 -11
  358. msprobe/pytorch/bench_functions/rotary_mul.py +12 -11
  359. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +12 -11
  360. msprobe/pytorch/bench_functions/sort_v2.py +12 -11
  361. msprobe/pytorch/bench_functions/swiglu.py +12 -11
  362. msprobe/pytorch/common/__init__.py +12 -11
  363. msprobe/pytorch/common/log.py +12 -11
  364. msprobe/pytorch/common/parse_json.py +12 -11
  365. msprobe/pytorch/common/utils.py +52 -19
  366. msprobe/pytorch/compare/distributed_compare.py +13 -13
  367. msprobe/pytorch/compare/match.py +12 -11
  368. msprobe/pytorch/compare/pt_compare.py +14 -20
  369. msprobe/pytorch/compare/pt_diff_analyze.py +12 -11
  370. msprobe/pytorch/compare/utils.py +12 -11
  371. msprobe/pytorch/{hook_module → dump/api_dump}/api_register.py +18 -16
  372. msprobe/pytorch/{hook_module → dump/api_dump}/hook_module.py +14 -13
  373. msprobe/pytorch/{hook_module → dump/api_dump}/pt_hook_manager.py +68 -23
  374. msprobe/pytorch/{hook_module → dump/api_dump}/register_optimizer_hook.py +13 -11
  375. msprobe/pytorch/{hook_module → dump/api_dump}/script_wrapper.py +17 -14
  376. msprobe/pytorch/{hook_module → dump/api_dump}/utils.py +12 -11
  377. msprobe/pytorch/{debugger → dump/debugger}/debugger_config.py +23 -38
  378. msprobe/pytorch/dump/debugger/precision_debugger.py +130 -0
  379. msprobe/pytorch/{function_factory.py → dump/function_factory.py} +12 -11
  380. msprobe/pytorch/dump/module_dump/hook_wrapper.py +17 -13
  381. msprobe/pytorch/dump/module_dump/module_dump.py +16 -15
  382. msprobe/pytorch/dump/module_dump/{module_processer.py → module_processor.py} +54 -42
  383. msprobe/pytorch/dump/pt_config.py +128 -0
  384. msprobe/pytorch/{pytorch_service.py → dump/pytorch_service.py} +22 -21
  385. msprobe/pytorch/monitor/csv2tb.py +13 -11
  386. msprobe/pytorch/monitor/data_writers.py +13 -11
  387. msprobe/pytorch/monitor/distributed/wrap_distributed.py +13 -11
  388. msprobe/pytorch/monitor/features.py +12 -11
  389. msprobe/pytorch/monitor/module_hook.py +67 -59
  390. msprobe/pytorch/monitor/module_metric.py +13 -11
  391. msprobe/pytorch/monitor/optimizer_collect.py +37 -35
  392. msprobe/pytorch/monitor/utils.py +13 -11
  393. msprobe/pytorch/monitor/visualizer.py +12 -11
  394. msprobe/pytorch/torchair_dump/__init__.py +17 -0
  395. msprobe/pytorch/torchair_dump/torchair_dump.py +114 -0
  396. msprobe/scripts/atb/config_example.json +10 -0
  397. msprobe/scripts/atb/load_atb_probe.sh +101 -0
  398. msprobe/scripts/atb/unload_atb_probe.sh +27 -0
  399. msprobe/scripts/build_msaccucmp.sh +186 -0
  400. msprobe/scripts/conf/help.info +6 -0
  401. msprobe/scripts/conf/version.info +3 -0
  402. msprobe/scripts/run_script/common.sh +538 -0
  403. msprobe/scripts/run_script/main_msaccucmp.sh +232 -0
  404. msprobe/visualization/__init__.py +12 -11
  405. msprobe/visualization/builder/__init__.py +12 -11
  406. msprobe/visualization/builder/graph_builder.py +45 -30
  407. msprobe/visualization/builder/graph_merger.py +53 -32
  408. msprobe/visualization/builder/msprobe_adapter.py +34 -44
  409. msprobe/visualization/compare/__init__.py +12 -11
  410. msprobe/visualization/compare/graph_comparator.py +63 -51
  411. msprobe/visualization/compare/mode_adapter.py +28 -113
  412. msprobe/visualization/db_utils.py +133 -22
  413. msprobe/visualization/graph/__init__.py +12 -11
  414. msprobe/visualization/graph/base_node.py +15 -27
  415. msprobe/visualization/graph/distributed_analyzer.py +97 -40
  416. msprobe/visualization/graph/graph.py +14 -16
  417. msprobe/visualization/graph/node_colors.py +34 -31
  418. msprobe/visualization/graph/node_op.py +12 -11
  419. msprobe/visualization/graph_service.py +580 -205
  420. msprobe/visualization/utils.py +278 -31
  421. tb_graph_ascend/secure_build.py +175 -0
  422. tb_graph_ascend/server/__init__.py +15 -0
  423. tb_graph_ascend/server/app/__init__.py +15 -0
  424. tb_graph_ascend/server/app/model/__init__.py +15 -0
  425. tb_graph_ascend/server/app/model/hierarchy.py +348 -0
  426. tb_graph_ascend/server/app/model/layout_hierarchy_model.py +69 -0
  427. tb_graph_ascend/server/app/model/match_nodes_model.py +573 -0
  428. tb_graph_ascend/server/app/repositories/__init__.py +15 -0
  429. tb_graph_ascend/server/app/repositories/graph_repo_base.py +32 -0
  430. tb_graph_ascend/server/app/repositories/graph_repo_db.py +879 -0
  431. tb_graph_ascend/server/app/repositories/graph_repo_vis.py +83 -0
  432. tb_graph_ascend/server/app/service/__init__.py +18 -0
  433. tb_graph_ascend/server/app/service/graph_service_base.py +158 -0
  434. tb_graph_ascend/server/app/service/graph_service_db.py +438 -0
  435. tb_graph_ascend/server/app/service/graph_service_factory.py +54 -0
  436. tb_graph_ascend/server/app/service/graph_service_vis.py +480 -0
  437. tb_graph_ascend/server/app/utils/__init__.py +15 -0
  438. tb_graph_ascend/server/app/utils/constant.py +80 -0
  439. tb_graph_ascend/server/app/utils/file_check_wrapper.py +46 -0
  440. tb_graph_ascend/server/app/utils/global_state.py +95 -0
  441. tb_graph_ascend/server/app/utils/graph_utils.py +661 -0
  442. tb_graph_ascend/server/app/utils/i18n.py +153 -0
  443. tb_graph_ascend/server/app/utils/request_method.py +46 -0
  444. tb_graph_ascend/server/app/views/__init__.py +15 -0
  445. tb_graph_ascend/server/app/views/graph_views.py +304 -0
  446. tb_graph_ascend/server/plugin.py +108 -0
  447. tb_graph_ascend/server/static/index.html +9250 -0
  448. tb_graph_ascend/server/static/index.js +21 -0
  449. tb_graph_ascend/setup.py +57 -0
  450. mindstudio_probe-8.3.3.dist-info/LICENSE +0 -201
  451. mindstudio_probe-8.3.3.dist-info/RECORD +0 -491
  452. mindstudio_probe-8.3.3.dist-info/entry_points.txt +0 -2
  453. mindstudio_probe-8.3.3.dist-info/top_level.txt +0 -1
  454. msprobe/CMakeLists.txt +0 -5
  455. msprobe/README.md +0 -203
  456. msprobe/core/advisor/advisor.py +0 -129
  457. msprobe/core/advisor/advisor_const.py +0 -58
  458. msprobe/core/advisor/advisor_result.py +0 -58
  459. msprobe/core/compare/find_first/data_processor.py +0 -35
  460. msprobe/core/compare/highlight.py +0 -390
  461. msprobe/core/data_dump/data_collector.py +0 -356
  462. msprobe/core/grad_probe/constant.py +0 -90
  463. msprobe/core/grad_probe/grad_compare.py +0 -187
  464. msprobe/core/grad_probe/utils.py +0 -105
  465. msprobe/core/kernel_dump/kernel_config.py +0 -33
  466. msprobe/docs/01.installation.md +0 -250
  467. msprobe/docs/02.config_introduction.md +0 -221
  468. msprobe/docs/03.config_examples.md +0 -281
  469. msprobe/docs/04.kernel_dump_PyTorch.md +0 -73
  470. msprobe/docs/05.data_dump_PyTorch.md +0 -518
  471. msprobe/docs/06.data_dump_MindSpore.md +0 -618
  472. msprobe/docs/07.accuracy_checker_PyTorch.md +0 -310
  473. msprobe/docs/09.accuracy_checker_MindSpore.md +0 -120
  474. msprobe/docs/10.accuracy_compare_PyTorch.md +0 -637
  475. msprobe/docs/11.accuracy_compare_MindSpore.md +0 -769
  476. msprobe/docs/12.overflow_check_PyTorch.md +0 -82
  477. msprobe/docs/13.overflow_check_MindSpore.md +0 -33
  478. msprobe/docs/14.data_parse_PyTorch.md +0 -282
  479. msprobe/docs/15.free_benchmarking_PyTorch.md +0 -169
  480. msprobe/docs/16.free_benchmarking_MindSpore.md +0 -159
  481. msprobe/docs/17.grad_probe.md +0 -205
  482. msprobe/docs/18.online_dispatch.md +0 -89
  483. msprobe/docs/19.monitor.md +0 -753
  484. msprobe/docs/20.monitor_performance_baseline.md +0 -52
  485. msprobe/docs/21.visualization_PyTorch.md +0 -519
  486. msprobe/docs/22.visualization_MindSpore.md +0 -515
  487. msprobe/docs/23.generate_operator_PyTorch.md +0 -107
  488. msprobe/docs/24.code_mapping_Mindspore.md +0 -29
  489. msprobe/docs/25.tool_function_introduction.md +0 -29
  490. msprobe/docs/26.data_dump_PyTorch_baseline.md +0 -48
  491. msprobe/docs/27.dump_json_instruction.md +0 -795
  492. msprobe/docs/28.debugger_save_instruction.md +0 -288
  493. msprobe/docs/28.kernel_dump_MindSpore.md +0 -69
  494. msprobe/docs/29.data_dump_MSAdapter.md +0 -235
  495. msprobe/docs/30.overflow_check_MSAdapter.md +0 -31
  496. msprobe/docs/31.config_check.md +0 -107
  497. msprobe/docs/32.ckpt_compare.md +0 -69
  498. msprobe/docs/33.generate_operator_MindSpore.md +0 -181
  499. msprobe/docs/34.RL_collect.md +0 -101
  500. msprobe/docs/35.nan_analyze.md +0 -73
  501. msprobe/docs/36.calculation_result_change.md +0 -75
  502. msprobe/docs/FAQ.md +0 -232
  503. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +0 -146
  504. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +0 -14
  505. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +0 -33
  506. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +0 -217
  507. msprobe/docs/img/BLOOM-7B_1.png +0 -0
  508. msprobe/docs/img/BLOOM-7B_2.png +0 -0
  509. msprobe/docs/img/BLOOM-7B_3.png +0 -0
  510. msprobe/docs/img/BLOOM-7B_4.png +0 -0
  511. msprobe/docs/img/GPT-3_1.png +0 -0
  512. msprobe/docs/img/GPT-3_2.png +0 -0
  513. msprobe/docs/img/GPT-3_3.png +0 -0
  514. msprobe/docs/img/GPT-3_4.png +0 -0
  515. msprobe/docs/img/GPT-3_5.png +0 -0
  516. msprobe/docs/img/GPT-3_6.png +0 -0
  517. msprobe/docs/img/GPT-3_7.png +0 -0
  518. msprobe/docs/img/GPT-3_8.png +0 -0
  519. msprobe/docs/img/YOLOV5S_1.png +0 -0
  520. msprobe/docs/img/YOLOV5S_2.png +0 -0
  521. msprobe/docs/img/accuracy_checking_details.png +0 -0
  522. msprobe/docs/img/accuracy_checking_result.png +0 -0
  523. msprobe/docs/img/api_precision_compare_details.png +0 -0
  524. msprobe/docs/img/api_precision_compare_result.png +0 -0
  525. msprobe/docs/img/auto_analyze_log.png +0 -0
  526. msprobe/docs/img/compare_result.png +0 -0
  527. msprobe/docs/img/compare_result_pkl.png +0 -0
  528. msprobe/docs/img/compare_result_pkl_md5.png.png +0 -0
  529. msprobe/docs/img/cpu_info.png +0 -0
  530. msprobe/docs/img/free_benchmark.png +0 -0
  531. msprobe/docs/img/free_benchmark_framework.png +0 -0
  532. msprobe/docs/img/grad_probe_image-1.png +0 -0
  533. msprobe/docs/img/grad_probe_image-2.png +0 -0
  534. msprobe/docs/img/grad_probe_image-3.png +0 -0
  535. msprobe/docs/img/grad_probe_image-4.png +0 -0
  536. msprobe/docs/img/grad_probe_image.png +0 -0
  537. msprobe/docs/img/merge_result.png +0 -0
  538. msprobe/docs/img/module_compare.png +0 -0
  539. msprobe/docs/img/monitor/cpu_info.png +0 -0
  540. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  541. msprobe/docs/img/ms_dump.png +0 -0
  542. msprobe/docs/img/ms_layer.png +0 -0
  543. msprobe/docs/img/pt_dump.png +0 -0
  544. msprobe/docs/img/save_compare_result_sample.png +0 -0
  545. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  546. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  547. msprobe/docs/img/visualization/proxy.png +0 -0
  548. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  549. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  550. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  551. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  552. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  553. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  554. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  555. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  556. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  557. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  558. msprobe/docs/visualization/GPTModel.png +0 -0
  559. msprobe/docs/visualization/ParallelMLP.png +0 -0
  560. msprobe/docs/visualization/layer_mapping_example.md +0 -132
  561. msprobe/docs/visualization/mapping.png +0 -0
  562. msprobe/docs/visualization/mapping1.png +0 -0
  563. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  564. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  565. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  566. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  567. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  568. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  569. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  570. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +0 -59
  571. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  572. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  573. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +0 -80
  574. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  575. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  576. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +0 -330
  577. msprobe/docs/visualization/module_name.png +0 -0
  578. msprobe/docs/visualization/module_name1.png +0 -0
  579. msprobe/docs/visualization/no_mapping.png +0 -0
  580. msprobe/docs/visualization/no_mapping1.png +0 -0
  581. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  582. msprobe/docs/visualization/top_layer.png +0 -0
  583. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +0 -460
  584. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +0 -2081
  585. msprobe/mindspore/code_mapping/bind.py +0 -283
  586. msprobe/mindspore/code_mapping/cmd_parser.py +0 -40
  587. msprobe/mindspore/code_mapping/graph.py +0 -49
  588. msprobe/mindspore/code_mapping/graph_parser.py +0 -211
  589. msprobe/mindspore/code_mapping/main.py +0 -24
  590. msprobe/mindspore/code_mapping/processor.py +0 -34
  591. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +0 -111
  592. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -52
  593. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +0 -257
  594. msprobe/mindspore/free_benchmark/common/config.py +0 -27
  595. msprobe/mindspore/free_benchmark/common/handler_params.py +0 -31
  596. msprobe/mindspore/free_benchmark/common/utils.py +0 -100
  597. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -638
  598. msprobe/mindspore/free_benchmark/handler/base_handler.py +0 -105
  599. msprobe/mindspore/free_benchmark/handler/check_handler.py +0 -55
  600. msprobe/mindspore/free_benchmark/handler/fix_handler.py +0 -51
  601. msprobe/mindspore/free_benchmark/handler/handler_factory.py +0 -36
  602. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +0 -82
  603. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +0 -45
  604. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +0 -78
  605. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +0 -77
  606. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +0 -56
  607. msprobe/mindspore/free_benchmark/perturbation/no_change.py +0 -27
  608. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +0 -46
  609. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +0 -51
  610. msprobe/mindspore/grad_probe/global_context.py +0 -127
  611. msprobe/mindspore/grad_probe/grad_analyzer.py +0 -260
  612. msprobe/mindspore/grad_probe/grad_monitor.py +0 -42
  613. msprobe/mindspore/grad_probe/grad_stat_csv.py +0 -161
  614. msprobe/mindspore/grad_probe/hook.py +0 -115
  615. msprobe/mindspore/grad_probe/utils.py +0 -43
  616. msprobe/mindspore/mindtorch/__init__.py +0 -18
  617. msprobe/mindspore/ms_config.py +0 -153
  618. msprobe/mindspore/task_handler_factory.py +0 -44
  619. msprobe/nan_analyze/__init__.py +0 -14
  620. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +0 -9
  621. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +0 -480
  622. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +0 -567
  623. msprobe/pytorch/debugger/precision_debugger.py +0 -181
  624. msprobe/pytorch/free_benchmark/__init__.py +0 -23
  625. msprobe/pytorch/free_benchmark/common/constant.py +0 -85
  626. msprobe/pytorch/free_benchmark/common/counter.py +0 -87
  627. msprobe/pytorch/free_benchmark/common/enums.py +0 -80
  628. msprobe/pytorch/free_benchmark/common/params.py +0 -152
  629. msprobe/pytorch/free_benchmark/common/utils.py +0 -143
  630. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -215
  631. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +0 -121
  632. msprobe/pytorch/free_benchmark/main.py +0 -123
  633. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +0 -28
  634. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +0 -56
  635. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +0 -107
  636. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +0 -121
  637. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +0 -89
  638. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +0 -87
  639. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +0 -43
  640. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +0 -60
  641. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +0 -34
  642. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +0 -252
  643. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +0 -54
  644. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +0 -40
  645. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -45
  646. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -181
  647. msprobe/pytorch/grad_probe/__init__.py +0 -0
  648. msprobe/pytorch/grad_probe/grad_monitor.py +0 -108
  649. msprobe/pytorch/grad_probe/grad_stat_csv.py +0 -160
  650. msprobe/pytorch/hook_module/__init__.py +0 -16
  651. msprobe/pytorch/hook_module/wrap_aten.py +0 -111
  652. msprobe/pytorch/online_dispatch/__init__.py +0 -19
  653. msprobe/pytorch/online_dispatch/compare.py +0 -224
  654. msprobe/pytorch/online_dispatch/dispatch.py +0 -332
  655. msprobe/pytorch/online_dispatch/dump_compare.py +0 -179
  656. msprobe/pytorch/online_dispatch/single_compare.py +0 -412
  657. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +0 -58
  658. msprobe/pytorch/online_dispatch/utils.py +0 -158
  659. msprobe/pytorch/parse_tool/__init__.py +0 -0
  660. msprobe/pytorch/parse_tool/cli.py +0 -31
  661. msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
  662. msprobe/pytorch/parse_tool/lib/compare.py +0 -253
  663. msprobe/pytorch/parse_tool/lib/config.py +0 -50
  664. msprobe/pytorch/parse_tool/lib/file_desc.py +0 -45
  665. msprobe/pytorch/parse_tool/lib/interactive_cli.py +0 -97
  666. msprobe/pytorch/parse_tool/lib/parse_exception.py +0 -54
  667. msprobe/pytorch/parse_tool/lib/parse_tool.py +0 -161
  668. msprobe/pytorch/parse_tool/lib/utils.py +0 -299
  669. msprobe/pytorch/parse_tool/lib/visualization.py +0 -85
  670. msprobe/pytorch/pt_config.py +0 -299
  671. /msprobe/core/{grad_probe → dump}/__init__.py +0 -0
  672. /msprobe/{mindspore/code_mapping → core/dump/api_dump}/__init__.py +0 -0
  673. /msprobe/{mindspore/debugger → core/dump/data_dump}/__init__.py +0 -0
  674. /msprobe/{mindspore/exception_dump → core/dump/data_dump/data_processor}/__init__.py +0 -0
  675. /msprobe/{mindspore/free_benchmark → core/dump/debugger}/__init__.py +0 -0
  676. /msprobe/{mindspore/free_benchmark/common → core/dump/kernel_dump}/__init__.py +0 -0
  677. /msprobe/mindspore/{free_benchmark/handler → dump/debugger}/__init__.py +0 -0
  678. /msprobe/mindspore/{grad_probe → dump/dump_processor}/__init__.py +0 -0
  679. /msprobe/mindspore/{overflow_check → dump/exception_dump}/__init__.py +0 -0
  680. /msprobe/mindspore/{mindtorch → dump/mindtorch}/mindtorch_adaptor.py +0 -0
  681. /msprobe/{pytorch/api_accuracy_checker/run_ut → mindspore/dump/overflow_check}/__init__.py +0 -0
  682. /msprobe/{pytorch/debugger → mindspore/monitor}/__init__.py +0 -0
  683. /msprobe/{pytorch/free_benchmark/common → msaccucmp}/__init__.py +0 -0
  684. /msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/.keep +0 -0
  685. /msprobe/pytorch/{free_benchmark/perturbed_layers → api_accuracy_checker/acc_check}/__init__.py +0 -0
  686. /msprobe/pytorch/api_accuracy_checker/{run_ut → acc_check}/torch_ut_setting.json +0 -0
  687. /msprobe/pytorch/{free_benchmark/perturbed_layers/npu → dump/api_dump}/__init__.py +0 -0
  688. /msprobe/pytorch/{hook_module → dump/api_dump}/support_wrap_ops.yaml +0 -0
  689. /msprobe/pytorch/{free_benchmark/result_handlers → dump/debugger}/__init__.py +0 -0
@@ -1,36 +1,43 @@
1
- # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
1
+ # -------------------------------------------------------------------------
2
+ # This file is part of the MindStudio project.
3
+ # Copyright (c) 2025 Huawei Technologies Co.,Ltd.
3
4
  #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
5
+ # MindStudio is licensed under Mulan PSL v2.
6
+ # You can use this software according to the terms and conditions of the Mulan PSL v2.
7
+ # You may obtain a copy of Mulan PSL v2 at:
7
8
  #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
+ # http://license.coscl.org.cn/MulanPSL2
9
10
  #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
11
+ # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
12
+ # EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
13
+ # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
14
+ # See the Mulan PSL v2 for more details.
15
+ # -------------------------------------------------------------------------
15
16
 
16
17
  import os
17
18
  import time
19
+ import threading
18
20
  from copy import deepcopy
19
- from multiprocessing import cpu_count, Pool
21
+ from dataclasses import dataclass
22
+ from multiprocessing import cpu_count, Pool, Manager
23
+ from typing import Callable, Optional
24
+
25
+ from tqdm import tqdm
20
26
  from msprobe.core.common.file_utils import (check_file_type, create_directory, FileChecker,
21
27
  check_file_or_directory_path, load_json)
22
28
  from msprobe.core.common.const import FileCheckConst, Const
23
29
  from msprobe.core.common.utils import CompareException, get_dump_mode
24
30
  from msprobe.visualization.compare.graph_comparator import GraphComparator
25
31
  from msprobe.visualization.utils import GraphConst, check_directory_content, SerializableArgs, load_parallel_param, \
26
- sort_rank_number_strings, check_whether_parallel_merge, validate_parallel_param, get_step_or_rank_int
32
+ sort_rank_number_strings, validate_parallel_param, get_step_or_rank_int, \
33
+ monitor_progress, ProgressInfo, calculate_list, get_log_msg_wrapper
27
34
  from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig, GraphInfo, BuildGraphTaskInfo
28
- from msprobe.core.common.log import logger
35
+ from msprobe.core.common.log import logger, BaseLogger
29
36
  from msprobe.visualization.graph.node_colors import NodeColors
30
37
  from msprobe.core.compare.layer_mapping import generate_api_mapping_by_layer_mapping
31
38
  from msprobe.core.compare.utils import check_and_return_dir_contents
32
39
  from msprobe.core.common.utils import detect_framework_by_dump_json
33
- from msprobe.visualization.graph.distributed_analyzer import DistributedAnalyzer
40
+ from msprobe.visualization.graph.distributed_analyzer import distributed_analyse
34
41
  from msprobe.visualization.builder.graph_merger import GraphMerger
35
42
  from msprobe.visualization.db_utils import post_process_db
36
43
 
@@ -39,12 +46,12 @@ build_output_db_name = f'build_{current_time}.vis.db'
39
46
  compare_output_db_name = f'compare_{current_time}.vis.db'
40
47
 
41
48
 
42
- def _compare_graph(graph_n: GraphInfo, graph_b: GraphInfo, input_param, args):
49
+ def _compare_graph(graph_n: GraphInfo, graph_b: GraphInfo, input_param, args, pbar_info=None):
43
50
  dump_path_param = {
44
- 'npu_json_path': graph_n.data_path,
45
- 'bench_json_path': graph_b.data_path,
46
- 'stack_json_path': graph_n.stack_path,
47
- 'is_print_compare_log': input_param.get("is_print_compare_log", True)
51
+ 'npu_path': graph_n.data_path,
52
+ 'bench_path': graph_b.data_path,
53
+ 'stack_path': graph_n.stack_path,
54
+ 'is_print_compare_log': input_param.get("is_print_compare_log", False)
48
55
  }
49
56
  mapping_dict = {}
50
57
  if args.layer_mapping:
@@ -61,19 +68,19 @@ def _compare_graph(graph_n: GraphInfo, graph_b: GraphInfo, input_param, args):
61
68
  raise CompareException(CompareException.CROSS_FRAME_ERROR)
62
69
 
63
70
  graph_comparator = GraphComparator([graph_n.graph, graph_b.graph], dump_path_param, args, is_cross_framework,
64
- mapping_dict=mapping_dict)
71
+ mapping_dict=mapping_dict, pbar_info=pbar_info)
65
72
  graph_comparator.compare()
66
73
  return graph_comparator
67
74
 
68
75
 
69
- def _compare_graph_result(input_param, args):
76
+ def _compare_graph_result(input_param, args, pbar_info=None):
70
77
  logger.info('Start building model graphs...')
71
78
  # 对两个数据进行构图
72
- graph_n = _build_graph_info(input_param.get('npu_path'), args)
73
- graph_b = _build_graph_info(input_param.get('bench_path'), args)
79
+ graph_n = _build_graph_info(input_param.get('npu_path'), args, pbar_info=pbar_info)
80
+ graph_b = _build_graph_info(input_param.get('bench_path'), args, pbar_info=pbar_info)
74
81
  logger.info('Model graphs built successfully, start comparing graphs...')
75
82
  # 基于graph、stack和data进行比较
76
- graph_comparator = _compare_graph(graph_n, graph_b, input_param, args)
83
+ graph_comparator = _compare_graph(graph_n, graph_b, input_param, args, pbar_info=pbar_info)
77
84
  # 增加micro step标记
78
85
  micro_steps = graph_n.graph.paging_by_micro_step(graph_b.graph)
79
86
  # 开启溢出检测
@@ -84,7 +91,7 @@ def _compare_graph_result(input_param, args):
84
91
  return CompareGraphResult(graph_n.graph, graph_b.graph, graph_comparator, micro_steps)
85
92
 
86
93
 
87
- def _export_compare_graph_result(args, result):
94
+ def _export_compare_graph_result(args, result, pbar_info=None):
88
95
  graphs = [result.graph_n, result.graph_b]
89
96
  graph_comparator = result.graph_comparator
90
97
  micro_steps = result.micro_steps
@@ -97,7 +104,7 @@ def _export_compare_graph_result(args, result):
97
104
  args.step_list if hasattr(args, 'step_list') else [0],
98
105
  args.rank_list if hasattr(args, 'rank_list') else [0])
99
106
  try:
100
- GraphBuilder.to_db(output_db_path, export_config)
107
+ GraphBuilder.to_db(output_db_path, export_config, pbar_info=pbar_info)
101
108
  logger.info(f'Exporting compare graph result successfully, the result file is saved in {output_db_path}')
102
109
  return ''
103
110
  except RuntimeError as e:
@@ -105,7 +112,7 @@ def _export_compare_graph_result(args, result):
105
112
  return compare_output_db_name
106
113
 
107
114
 
108
- def _build_graph_info(dump_path, args, graph=None):
115
+ def _build_graph_info(dump_path, args, graph=None, pbar_info=None):
109
116
  construct_path = FileChecker(os.path.join(dump_path, GraphConst.CONSTRUCT_FILE), FileCheckConst.FILE,
110
117
  FileCheckConst.READ_ABLE).common_check()
111
118
  data_path = FileChecker(os.path.join(dump_path, GraphConst.DUMP_FILE), FileCheckConst.FILE,
@@ -113,13 +120,13 @@ def _build_graph_info(dump_path, args, graph=None):
113
120
  stack_path = FileChecker(os.path.join(dump_path, GraphConst.STACK_FILE), FileCheckConst.FILE,
114
121
  FileCheckConst.READ_ABLE).common_check()
115
122
  if not graph:
116
- graph = GraphBuilder.build(construct_path, data_path, stack_path)
123
+ graph = GraphBuilder.build(construct_path, data_path, stack_path, pbar_info=pbar_info)
117
124
  return GraphInfo(graph, construct_path, data_path, stack_path)
118
125
 
119
126
 
120
- def _build_graph_result(dump_path, args):
127
+ def _build_graph_result(dump_path, args, pbar_info=None):
121
128
  logger.info('Start building model graphs...')
122
- graph = _build_graph_info(dump_path, args).graph
129
+ graph = _build_graph_info(dump_path, args, pbar_info=pbar_info).graph
123
130
  # 增加micro step标记
124
131
  micro_steps = graph.paging_by_micro_step()
125
132
  # 开启溢出检测
@@ -128,30 +135,39 @@ def _build_graph_result(dump_path, args):
128
135
  return BuildGraphResult(graph, micro_steps)
129
136
 
130
137
 
131
- def _run_build_graph_compare(input_param, args, nr, br):
138
+ def _run_build_graph_compare(input_param, args, nr, br, pbar_info=None):
132
139
  logger.info(f'Start building graph for {nr}...')
133
- graph_n = _build_graph_info(input_param.get('npu_path'), args)
134
- graph_b = _build_graph_info(input_param.get('bench_path'), args)
140
+ graph_n = _build_graph_info(input_param.get('npu_path'), args, pbar_info=pbar_info)
141
+ graph_b = _build_graph_info(input_param.get('bench_path'), args, pbar_info=pbar_info)
135
142
  logger.info(f'Building graph for {nr} finished.')
136
143
  return BuildGraphTaskInfo(graph_n, graph_b, nr, br, current_time)
137
144
 
138
145
 
139
- def _run_build_graph_single(dump_ranks_path, rank, step, args):
146
+ def _run_build_graph_single(dump_ranks_path, rank, step, args, pbar_info=None):
140
147
  logger.info(f'Start building graph for {rank}...')
141
148
  dump_path = os.path.join(dump_ranks_path, rank)
142
- result = _build_graph_result(dump_path, args)
149
+ result = _build_graph_result(dump_path, args, pbar_info=pbar_info)
143
150
  if rank != Const.RANK:
144
151
  result.rank = get_step_or_rank_int(rank, True)
145
152
  logger.info(f'Building graph for step: {step}, rank: {rank} finished.')
146
153
  return result
147
154
 
148
155
 
149
- def _run_graph_compare(graph_task_info, input_param, args):
156
+ def _run_build_graph_and_export(dump_ranks_path, rank, step, args, pbar_info=None):
157
+ result = _run_build_graph_single(dump_ranks_path, rank, step, args, pbar_info)
158
+ if step is not None:
159
+ result.step = get_step_or_rank_int(step)
160
+ create_directory(args.output_path)
161
+
162
+ return _export_build_graph_result(args, result, pbar_info)
163
+
164
+
165
+ def _run_graph_compare(graph_task_info, input_param, args, pbar_info=None):
150
166
  logger.info(f'Start comparing data for {graph_task_info.npu_rank}...')
151
167
  graph_n = graph_task_info.graph_info_n
152
168
  graph_b = graph_task_info.graph_info_b
153
169
  nr = graph_task_info.npu_rank
154
- graph_comparator = _compare_graph(graph_n, graph_b, input_param, args)
170
+ graph_comparator = _compare_graph(graph_n, graph_b, input_param, args, pbar_info=pbar_info)
155
171
  micro_steps = graph_n.graph.paging_by_micro_step(graph_b.graph)
156
172
  # 开启溢出检测
157
173
  if args.overflow_check:
@@ -164,7 +180,7 @@ def _run_graph_compare(graph_task_info, input_param, args):
164
180
  return graph_result
165
181
 
166
182
 
167
- def _export_build_graph_result(args, result):
183
+ def _export_build_graph_result(args, result, pbar_info=None):
168
184
  out_path = args.output_path
169
185
  graph = result.graph
170
186
  micro_steps = result.micro_steps
@@ -175,7 +191,7 @@ def _export_build_graph_result(args, result):
175
191
  step=result.step, rank_list=args.rank_list if hasattr(args, 'rank_list') else [0],
176
192
  step_list=args.step_list if hasattr(args, 'step_list') else [0])
177
193
  try:
178
- GraphBuilder.to_db(output_db_path, config)
194
+ GraphBuilder.to_db(output_db_path, config, pbar_info=pbar_info)
179
195
  logger.info(f'Model graph exported successfully, the result file is saved in {output_db_path}')
180
196
  return None
181
197
  except RuntimeError as e:
@@ -189,57 +205,30 @@ def is_real_data_compare(input_param, npu_ranks, bench_ranks):
189
205
  has_real_data = False
190
206
  for nr, br in zip(npu_ranks, bench_ranks):
191
207
  dump_path_param = {
192
- 'npu_json_path': FileChecker(os.path.join(dump_rank_n, nr, GraphConst.DUMP_FILE), FileCheckConst.FILE,
193
- FileCheckConst.READ_ABLE).common_check(),
194
- 'bench_json_path': FileChecker(os.path.join(dump_rank_b, br, GraphConst.DUMP_FILE), FileCheckConst.FILE,
195
- FileCheckConst.READ_ABLE).common_check()
208
+ 'npu_path': FileChecker(os.path.join(dump_rank_n, nr, GraphConst.DUMP_FILE), FileCheckConst.FILE,
209
+ FileCheckConst.READ_ABLE).common_check(),
210
+ 'bench_path': FileChecker(os.path.join(dump_rank_b, br, GraphConst.DUMP_FILE), FileCheckConst.FILE,
211
+ FileCheckConst.READ_ABLE).common_check()
196
212
  }
197
213
  has_real_data |= get_dump_mode(dump_path_param) == Const.ALL
198
214
  return has_real_data
199
215
 
200
216
 
201
- def _mp_compare(input_param, serializable_args, nr, br):
202
- graph_task_info = _run_build_graph_compare(input_param, serializable_args, nr, br)
203
- return _run_graph_compare(graph_task_info, input_param, serializable_args)
204
-
205
-
206
- def _compare_graph_ranks(input_param, args, step=None):
207
- with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool:
208
- def err_call(err):
209
- logger.error(f'Error occurred while comparing graph ranks: {err}')
210
- try:
211
- pool.close()
212
- except OSError as e:
213
- logger.error(f'Error occurred while terminating the pool: {e}')
214
-
215
- serializable_args = SerializableArgs(args)
216
- # 暂存所有rank的graph,用于匹配rank间的分布式节点
217
- compare_graph_results = _get_compare_graph_results(input_param, serializable_args, step, pool, err_call)
217
+ def _mp_compare(input_param, serializable_args, nr, br, pbar_info=None):
218
+ graph_task_info = _run_build_graph_compare(input_param, serializable_args, nr, br, pbar_info=pbar_info)
219
+ return _run_graph_compare(graph_task_info, input_param, serializable_args, pbar_info=pbar_info)
218
220
 
219
- serializable_args.rank_list = [result.rank for result in compare_graph_results]
220
221
 
221
- # 匹配rank间的分布式节点
222
- if len(compare_graph_results) > 1:
223
- DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results},
224
- args.overflow_check).distributed_match()
225
- DistributedAnalyzer({obj.rank: obj.graph_b for obj in compare_graph_results},
226
- args.overflow_check).distributed_match()
222
+ def _mp_compare_and_export(input_param, args, rank, step, pbar_info=None):
223
+ graph_result = _mp_compare(input_param, args, rank, rank, pbar_info=pbar_info)
224
+ if step is not None:
225
+ graph_result.step = get_step_or_rank_int(step)
226
+ create_directory(args.output_path)
227
227
 
228
- export_res_task_list = []
229
- create_directory(args.output_path)
230
- for result in compare_graph_results:
231
- export_res_task_list.append(pool.apply_async(_export_compare_graph_result,
232
- args=(serializable_args, result),
233
- error_callback=err_call))
234
- export_res_list = [res.get() for res in export_res_task_list]
235
- if any(export_res_list):
236
- failed_names = list(filter(lambda x: x, export_res_list))
237
- logger.error(f'Unable to export compare graph results: {", ".join(failed_names)}.')
238
- else:
239
- logger.info('Successfully exported compare graph results.')
228
+ return _export_compare_graph_result(args, graph_result, pbar_info=pbar_info)
240
229
 
241
230
 
242
- def _get_compare_graph_results(input_param, serializable_args, step, pool, err_call):
231
+ def _compare_graph_ranks(input_param, args, step=None, pbar_info=None):
243
232
  dump_rank_n = input_param.get('npu_path')
244
233
  dump_rank_b = input_param.get('bench_path')
245
234
  npu_ranks = sort_rank_number_strings(check_and_return_dir_contents(dump_rank_n, Const.RANK))
@@ -251,64 +240,86 @@ def _get_compare_graph_results(input_param, serializable_args, step, pool, err_c
251
240
  raise CompareException(CompareException.INVALID_PATH_ERROR)
252
241
  npu_ranks = intersection_ranks
253
242
  bench_ranks = intersection_ranks
254
- compare_graph_results = []
255
- if is_real_data_compare(input_param, npu_ranks, bench_ranks):
256
- mp_task_dict = {}
257
- for nr, br in zip(npu_ranks, bench_ranks):
258
- input_param['npu_path'] = os.path.join(dump_rank_n, nr)
259
- input_param['bench_path'] = os.path.join(dump_rank_b, br)
260
- build_key = f'{step}_{nr}' if step else f'{nr}'
261
- input_param_copy = deepcopy(input_param)
262
- mp_task_dict[build_key] = pool.apply_async(_run_build_graph_compare,
263
- args=(input_param_copy, serializable_args, nr, br),
264
- error_callback=err_call)
265
-
266
- mp_res_dict = {k: v.get() for k, v in mp_task_dict.items()}
267
- for mp_res in mp_res_dict.values():
268
- compare_graph_results.append(_run_graph_compare(mp_res, input_param, serializable_args))
269
- else:
270
- compare_graph_tasks = []
271
- for nr, br in zip(npu_ranks, bench_ranks):
272
- input_param['npu_path'] = os.path.join(dump_rank_n, nr)
273
- input_param['bench_path'] = os.path.join(dump_rank_b, br)
274
- input_param_copy = deepcopy(input_param)
275
- compare_graph_tasks.append(pool.apply_async(_mp_compare,
276
- args=(input_param_copy, serializable_args, nr, br),
277
- error_callback=err_call))
278
- compare_graph_results = [task.get() for task in compare_graph_tasks]
279
- if step is not None:
280
- for result in compare_graph_results:
281
- result.step = get_step_or_rank_int(step)
282
- return compare_graph_results
243
+ args.rank_list = [get_step_or_rank_int(rank, True) for rank in npu_ranks]
244
+ serializable_args = SerializableArgs(args)
283
245
 
246
+ with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool:
247
+ def err_call(err):
248
+ logger.error(f'Error occurred while comparing graph ranks: {err}')
284
249
 
285
- def _compare_graph_steps(input_param, args):
286
- dump_step_n = input_param.get('npu_path')
287
- dump_step_b = input_param.get('bench_path')
250
+ if is_real_data_compare(input_param, npu_ranks, bench_ranks):
251
+ # 真实数据模式,考虑到tensor比对过程会使用进程池启用多进程,为了避免嵌套进程池,graph比对使用串行
252
+ compare_graph_results = []
253
+ mp_task_dict = {}
254
+ for nr, br in zip(npu_ranks, bench_ranks):
255
+ input_param['npu_path'] = os.path.join(dump_rank_n, nr)
256
+ input_param['bench_path'] = os.path.join(dump_rank_b, br)
257
+ build_key = f'{step}_{nr}' if step else f'{nr}'
258
+ input_param_copy = deepcopy(input_param)
259
+ pbar_info_copy = PbarInfo.update_task_id(pbar_info, nr)
260
+ mp_task_dict[build_key] = pool.apply_async(_run_build_graph_compare,
261
+ args=(input_param_copy, serializable_args, nr, br,
262
+ pbar_info_copy),
263
+ error_callback=err_call)
264
+ mp_res_dict = {k: v.get() for k, v in mp_task_dict.items()}
265
+ for build_key, mp_res in mp_res_dict.items():
266
+ if pbar_info:
267
+ if Const.REPLACEMENT_CHARACTER in build_key:
268
+ build_key = build_key.split(Const.REPLACEMENT_CHARACTER)[-1]
269
+ pbar_info.task_id = build_key
270
+ compare_graph_results.append(_run_graph_compare(mp_res, input_param, serializable_args, pbar_info))
271
+ if step is not None:
272
+ for result in compare_graph_results:
273
+ result.step = get_step_or_rank_int(step)
274
+
275
+ export_res_task_list = []
276
+ create_directory(args.output_path)
277
+ for result in compare_graph_results:
278
+ export_res_task_list.append(pool.apply_async(_export_compare_graph_result,
279
+ args=(serializable_args, result),
280
+ error_callback=err_call))
281
+ export_res_list = [res.get() for res in export_res_task_list]
282
+ else:
283
+ compare_graph_tasks = []
284
+ for nr, br in zip(npu_ranks, bench_ranks):
285
+ input_param['npu_path'] = os.path.join(dump_rank_n, nr)
286
+ input_param['bench_path'] = os.path.join(dump_rank_b, br)
287
+ input_param_copy = deepcopy(input_param)
288
+ pbar_info_copy = PbarInfo.update_task_id(pbar_info, nr)
289
+ compare_graph_tasks.append(pool.apply_async(_mp_compare_and_export,
290
+ args=(input_param_copy, serializable_args, nr, step,
291
+ pbar_info_copy),
292
+ error_callback=err_call))
293
+ export_res_list = [res.get() for res in compare_graph_tasks]
294
+ if any(export_res_list):
295
+ failed_names = list(filter(lambda x: x, export_res_list))
296
+ logger.error(f'Unable to export compare graph results: {", ".join(failed_names)}.')
297
+ else:
298
+ logger.info('Successfully exported compare graph results.')
288
299
 
289
- npu_steps = check_and_return_dir_contents(dump_step_n, Const.STEP)
290
- bench_steps = check_and_return_dir_contents(dump_step_b, Const.STEP)
291
300
 
292
- if npu_steps != bench_steps:
293
- intersection_steps = sort_rank_number_strings(list(set(npu_steps) & set(bench_steps)))
301
+ def _compare_graph_steps(input_param, args, pbar_info=None):
302
+ dump_step_n = input_param.get('npu_path')
303
+ dump_step_b = input_param.get('bench_path')
294
304
 
295
- if not intersection_steps:
296
- logger.error('The steps in the two runs are completely different. Unable to match the steps.')
297
- raise CompareException(CompareException.INVALID_PATH_ERROR)
298
- npu_steps = intersection_steps
305
+ npu_steps = calculate_list(dump_step_n, dump_step_b, Const.STEP)
299
306
 
300
307
  args.step_list = sorted([get_step_or_rank_int(step) for step in npu_steps])
301
308
 
302
- for folder_step in npu_steps:
309
+ for i, folder_step in enumerate(npu_steps):
303
310
  logger.info(f'Start processing data for {folder_step}...')
304
311
  input_param['npu_path'] = os.path.join(dump_step_n, folder_step)
305
312
  input_param['bench_path'] = os.path.join(dump_step_b, folder_step)
306
313
 
307
- _compare_graph_ranks(input_param, args, step=folder_step) if not args.parallel_merge \
308
- else _compare_graph_ranks_parallel(input_param, args, step=folder_step)
314
+ if pbar_info:
315
+ pbar_info.step = i
309
316
 
317
+ _compare_graph_ranks(input_param, args, step=folder_step, pbar_info=pbar_info) if not args.parallel_merge \
318
+ else _compare_graph_ranks_parallel(input_param, args, step=folder_step, pbar_info=pbar_info)
310
319
 
311
- def _build_graph_ranks(dump_ranks_path, args, step=None):
320
+
321
+ def _build_graph_ranks_parallel(args, step=None, pbar_info=None):
322
+ dump_ranks_path = os.path.join(args.target_path, step) if step is not None else args.target_path
312
323
  ranks = sort_rank_number_strings(check_and_return_dir_contents(dump_ranks_path, Const.RANK))
313
324
  serializable_args = SerializableArgs(args)
314
325
  with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool:
@@ -320,10 +331,13 @@ def _build_graph_ranks(dump_ranks_path, args, step=None):
320
331
  logger.error(f'Error occurred while terminating the pool: {e}')
321
332
 
322
333
  build_graph_tasks = []
334
+ if pbar_info and pbar_info.step:
335
+ PbarInfo.reset_progress_and_current_stage(pbar_info, ranks)
323
336
  for rank in ranks:
337
+ pbar_info_copy = PbarInfo.update_task_id(pbar_info, rank)
324
338
  build_graph_tasks.append(pool.apply_async(_run_build_graph_single,
325
- args=(dump_ranks_path, rank, step, serializable_args),
326
- error_callback=err_call))
339
+ args=(dump_ranks_path, rank, step, serializable_args,
340
+ pbar_info_copy), error_callback=err_call))
327
341
  build_graph_results = [task.get() for task in build_graph_tasks]
328
342
 
329
343
  if step is not None:
@@ -332,18 +346,19 @@ def _build_graph_ranks(dump_ranks_path, args, step=None):
332
346
 
333
347
  if args.parallel_params:
334
348
  validate_parallel_param(args.parallel_params[0], dump_ranks_path)
335
- build_graph_results = GraphMerger(build_graph_results, args.parallel_params[0]).merge_graph()
336
-
337
- if len(build_graph_results) > 1 and not args.parallel_merge:
338
- DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results},
339
- args.overflow_check).distributed_match()
349
+ build_graph_results = GraphMerger(build_graph_results, args.parallel_params[0],
350
+ pbar_info=pbar_info).merge_graph()
351
+ if pbar_info:
352
+ PbarInfo.del_progress_dict_item(pbar_info, ranks,
353
+ [f'{Const.RANK}{result.rank}' for result in build_graph_results])
340
354
 
341
355
  create_directory(args.output_path)
342
356
  export_build_graph_tasks = []
343
357
  serializable_args.rank_list = [result.rank for result in build_graph_results]
344
358
  for result in build_graph_results:
359
+ pbar_info_copy = PbarInfo.update_task_id(pbar_info, f'{Const.RANK}{result.rank}')
345
360
  export_build_graph_tasks.append(pool.apply_async(_export_build_graph_result,
346
- args=(serializable_args, result),
361
+ args=(serializable_args, result, pbar_info_copy),
347
362
  error_callback=err_call))
348
363
  export_build_graph_result = [task.get() for task in export_build_graph_tasks]
349
364
  if any(export_build_graph_result):
@@ -353,30 +368,55 @@ def _build_graph_ranks(dump_ranks_path, args, step=None):
353
368
  logger.info(f'Successfully exported build graph results.')
354
369
 
355
370
 
356
- def _build_graph_steps(dump_steps_path, args):
357
- steps = sorted(check_and_return_dir_contents(dump_steps_path, Const.STEP))
371
+ def _build_graph_ranks(args, step=None, pbar_info=None):
372
+ dump_ranks_path = os.path.join(args.target_path, step) if step is not None else args.target_path
373
+ ranks = sort_rank_number_strings(check_and_return_dir_contents(dump_ranks_path, Const.RANK))
374
+ args.rank_list = [get_step_or_rank_int(rank, True) for rank in ranks]
375
+ serializable_args = SerializableArgs(args)
376
+ with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool:
377
+ def err_call(err):
378
+ logger.error(f'Error occurred while comparing graph ranks: {err}')
379
+
380
+ tasks = []
381
+ for rank in ranks:
382
+ pbar_info_copy = PbarInfo.update_task_id(pbar_info, rank)
383
+ tasks.append(pool.apply_async(_run_build_graph_and_export,
384
+ args=(dump_ranks_path, rank, step, serializable_args, pbar_info_copy),
385
+ error_callback=err_call))
386
+ results = [task.get() for task in tasks]
387
+ if any(results):
388
+ failed_names = list(filter(lambda x: x, results))
389
+ logger.error(f'Unable to export build graph results: {failed_names}.')
390
+ else:
391
+ logger.info(f'Successfully exported build graph results.')
392
+
393
+
394
+ def _build_graph_steps(args, pbar_info=None):
395
+ steps = sorted(check_and_return_dir_contents(args.target_path, Const.STEP))
358
396
  args.step_list = sorted([get_step_or_rank_int(step) for step in steps])
359
397
 
360
- for step in steps:
398
+ for i, step in enumerate(steps):
361
399
  logger.info(f'Start processing data for {step}...')
362
- dump_ranks_path = os.path.join(dump_steps_path, step)
363
- _build_graph_ranks(dump_ranks_path, args, step)
400
+ if pbar_info:
401
+ pbar_info.step = i
402
+ _build_graph_ranks(args, step, pbar_info=pbar_info) if not args.parallel_merge \
403
+ else _build_graph_ranks_parallel(args, step, pbar_info=pbar_info)
364
404
 
365
405
 
366
- def _compare_and_export_graph(graph_task_info, input_param, args, step=None):
367
- result = _run_graph_compare(graph_task_info, input_param, args)
406
+ def _compare_and_export_graph(graph_task_info, input_param, args, step=None, pbar_info=None):
407
+ result = _run_graph_compare(graph_task_info, input_param, args, pbar_info=pbar_info)
368
408
  if step is not None:
369
409
  result.step = get_step_or_rank_int(step)
370
- return _export_compare_graph_result(args, result)
410
+ return _export_compare_graph_result(args, result, pbar_info=pbar_info)
371
411
 
372
412
 
373
- def _compare_graph_ranks_parallel(input_param, args, step=None):
413
+ def _compare_graph_ranks_parallel(input_param, args, step=None, pbar_info=None):
374
414
  args.fuzzy_match = True
375
415
  npu_path = input_param.get('npu_path')
376
416
  bench_path = input_param.get('bench_path')
377
417
  ranks_n = sort_rank_number_strings(check_and_return_dir_contents(npu_path, Const.RANK))
378
418
  ranks_b = sort_rank_number_strings(check_and_return_dir_contents(bench_path, Const.RANK))
379
- parallel_params = load_parallel_param(input_param)
419
+ parallel_params = args.parallel_params
380
420
  if len(parallel_params) != 2:
381
421
  raise RuntimeError('Parallel params error in compare graph!')
382
422
  validate_parallel_param(parallel_params[0], npu_path)
@@ -394,24 +434,33 @@ def _compare_graph_ranks_parallel(input_param, args, step=None):
394
434
  # 1.并行构图
395
435
  build_graph_tasks_n = []
396
436
  build_graph_tasks_b = []
437
+ if pbar_info and pbar_info.step:
438
+ PbarInfo.reset_progress_and_current_stage(pbar_info, list(set(ranks_n) | set(ranks_b)))
397
439
  for rank in ranks_n:
440
+ pbar_info_copy = PbarInfo.update_task_id(pbar_info, rank)
398
441
  build_graph_tasks_n.append(pool.apply_async(_run_build_graph_single,
399
- args=(npu_path, rank, step, serializable_args),
442
+ args=(npu_path, rank, step, serializable_args, pbar_info_copy),
400
443
  error_callback=err_call))
401
444
  for rank in ranks_b:
445
+ pbar_info_copy = PbarInfo.update_task_id(pbar_info, rank)
402
446
  build_graph_tasks_b.append(pool.apply_async(_run_build_graph_single,
403
- args=(bench_path, rank, step, serializable_args),
404
- error_callback=err_call))
447
+ args=(bench_path, rank, step, serializable_args,
448
+ pbar_info_copy), error_callback=err_call))
405
449
  graph_results_n = [task.get() for task in build_graph_tasks_n]
406
450
  graph_results_b = [task.get() for task in build_graph_tasks_b]
407
451
 
408
452
  # 2.图合并
409
- build_graph_results_n = GraphMerger(graph_results_n, parallel_params[0]).merge_graph()
410
- build_graph_results_b = GraphMerger(graph_results_b, parallel_params[1], True).merge_graph()
453
+ build_graph_results_n = GraphMerger(graph_results_n, parallel_params[0], pbar_info=pbar_info).merge_graph()
454
+ build_graph_results_b = GraphMerger(graph_results_b, parallel_params[1], True,
455
+ pbar_info=pbar_info).merge_graph()
456
+
411
457
  if len(build_graph_results_n) != len(build_graph_results_b):
412
458
  raise RuntimeError(f'Parallel merge failed because the dp of npu: {len(build_graph_results_n)} '
413
459
  f'is inconsistent with that of bench: {len(build_graph_results_b)}!')
414
460
  serializable_args.rank_list = [result.rank for result in build_graph_results_n]
461
+ if pbar_info:
462
+ PbarInfo.del_progress_dict_item(pbar_info, list(set(ranks_n) | set(ranks_b)),
463
+ [f'{Const.RANK}{result.rank}' for result in build_graph_results_n])
415
464
  # 3.并行图比对和输出
416
465
  export_res_task_list = []
417
466
  create_directory(args.output_path)
@@ -422,9 +471,10 @@ def _compare_graph_ranks_parallel(input_param, args, step=None):
422
471
  _build_graph_info(os.path.join(npu_path, f'rank{graph_n.root.rank}'), args, graph_n),
423
472
  _build_graph_info(os.path.join(bench_path, f'rank{graph_b.root.rank}'), args, graph_b),
424
473
  f'rank{graph_n.root.rank}', f'rank{graph_b.root.rank}', current_time)
474
+ pbar_info_copy = PbarInfo.update_task_id(pbar_info, f'{Const.RANK}{result_n.rank}')
425
475
  export_res_task_list.append(pool.apply_async(_compare_and_export_graph,
426
- args=(graph_task_info, input_param, serializable_args, step),
427
- error_callback=err_call))
476
+ args=(graph_task_info, input_param, serializable_args, step,
477
+ pbar_info_copy), error_callback=err_call))
428
478
  export_res_list = [res.get() for res in export_res_task_list]
429
479
  if any(export_res_list):
430
480
  failed_names = list(filter(lambda x: x, export_res_list))
@@ -434,80 +484,336 @@ def _compare_graph_ranks_parallel(input_param, args, step=None):
434
484
 
435
485
 
436
486
  def _graph_service_parser(parser):
437
- parser.add_argument("-i", "--input_path", dest="input_path", type=str,
438
- help="<Required> The compare input path, a dict json.", required=True)
487
+ # -------------------------- 基础必填参数 --------------------------
488
+ parser.add_argument("-tp", "--target_path", dest="target_path", type=str,
489
+ help="<Required> The target path.", required=True)
439
490
  parser.add_argument("-o", "--output_path", dest="output_path", type=str,
440
- help="<Required> The compare task result out path.", required=True)
491
+ help="<Required> The visualization task result out path.", required=True)
492
+ # -------------------------- 基础可选参数 --------------------------
493
+ parser.add_argument("-gp", "--golden_path", dest="golden_path", type=str,
494
+ help="<Optional> The golden path.", required=False)
441
495
  parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True,
442
496
  help="<Optional> The layer mapping file path.", required=False)
443
497
  parser.add_argument("-oc", "--overflow_check", dest="overflow_check", action="store_true",
444
498
  help="<Optional> whether open overflow_check for graph.", required=False)
445
- parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true",
446
- help="<Optional> Whether to perform a fuzzy match on the api name.", required=False)
499
+ parser.add_argument("-fm", "--fuzzy_match", dest="fuzzy_match", action="store_true",
500
+ help="<Optional> whether to perform a fuzzy match on the api name.", required=False)
501
+ parser.add_argument("-tensor_log", "--is_print_compare_log", dest="is_print_compare_log", action="store_true",
502
+ help="<Optional> whether print tensor compare log for visualization task.", required=False)
503
+ parser.add_argument("-progress_log", "--is_print_progress_log", dest="is_print_progress_log", action="store_true",
504
+ help="<Optional> whether print progress log for visualization task.", required=False)
505
+
506
+ # -------------------------- 不同并行切分策略合并可选参数 --------------------------
507
+ group_n = parser.add_argument_group("Parallel Parameters, "
508
+ "used for graph merging under different parallel partitioning strategies")
509
+
510
+ group_n.add_argument("--rank_size", type=int, nargs='+', help="<Optional> The rank size of dump path.",
511
+ required=False)
512
+ group_n.add_argument("--tp", type=int, nargs='+',
513
+ help="<Optional, but required if rank_size is not empty> The tp size of dump path.",
514
+ required=False)
515
+ group_n.add_argument("--pp", type=int, nargs='+',
516
+ help="<Optional, but required if rank_size is not empty> The pp size of dump path.",
517
+ required=False)
518
+ group_n.add_argument("--vpp", type=int, nargs='+', default=[1], help="<Optional> The vpp size of dump path.",
519
+ required=False)
520
+ group_n.add_argument("--order", type=str, nargs='+', default=['tp-cp-ep-dp-pp'],
521
+ help="<Optional> The order of dump path.", required=False)
447
522
 
448
523
 
449
524
  def _graph_service_command(args):
450
- input_param = load_json(args.input_path)
451
- npu_path = input_param.get("npu_path")
452
- bench_path = input_param.get("bench_path")
453
- args.parallel_merge = check_whether_parallel_merge(input_param)
454
- args.parallel_params = load_parallel_param(input_param) if args.parallel_merge else None
455
- check_file_or_directory_path(npu_path, isdir=True)
456
- if bench_path:
457
- check_file_or_directory_path(bench_path, isdir=True)
458
- if check_file_type(npu_path) == FileCheckConst.DIR and not bench_path:
459
- content = check_directory_content(npu_path)
460
- output_db_path = os.path.join(args.output_path, build_output_db_name)
461
- if content == GraphConst.RANKS:
462
- _build_graph_ranks(npu_path, args)
463
- elif content == GraphConst.STEPS:
464
- _build_graph_steps(npu_path, args)
525
+ try:
526
+ if args.is_print_progress_log:
527
+ # 往ProgressInfo中记录error日志,用于前端展示
528
+ BaseLogger.error = get_log_msg_wrapper(BaseLogger.error)
529
+ npu_path = args.target_path
530
+ bench_path = args.golden_path
531
+ ProgressInfo.print_progress_log = args.is_print_progress_log
532
+ args.parallel_merge = True if args.rank_size else False
533
+ args.parallel_params = load_parallel_param(args) if args.parallel_merge else None
534
+ check_file_or_directory_path(npu_path, isdir=True)
535
+ if bench_path:
536
+ check_file_or_directory_path(bench_path, isdir=True)
537
+ if check_file_type(npu_path) == FileCheckConst.DIR and not bench_path:
538
+ content = check_directory_content(npu_path)
539
+ if content == GraphConst.RANKS:
540
+ _build_graph_ranks_with_pbar(args)
541
+ elif content == GraphConst.STEPS:
542
+ _build_graph_steps_with_pbar(args)
543
+ else:
544
+ _build_graph_with_pbar(npu_path, args)
545
+ elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
546
+ content_n = check_directory_content(npu_path)
547
+ content_b = check_directory_content(bench_path)
548
+ if content_n != content_b:
549
+ raise ValueError('The directory structures of npu_path and bench_path are inconsistent.')
550
+ input_param = {
551
+ 'npu_path': args.target_path,
552
+ 'bench_path': args.golden_path,
553
+ 'is_print_compare_log': args.is_print_compare_log
554
+ }
555
+ if content_n == GraphConst.RANKS:
556
+ _compare_graph_ranks_with_pbar(input_param, args)
557
+ elif content_n == GraphConst.STEPS:
558
+ _compare_graph_steps_with_pbar(input_param, args)
559
+ else:
560
+ _compare_graph_with_pbar(input_param, args)
465
561
  else:
466
- result = _build_graph_result(npu_path, args)
467
- create_directory(args.output_path)
468
- file_name = _export_build_graph_result(args, result)
469
- if file_name:
470
- logger.error('Failed to export model build graph.')
471
- elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
472
- content_n = check_directory_content(npu_path)
473
- content_b = check_directory_content(bench_path)
474
- output_db_path = os.path.join(args.output_path, compare_output_db_name)
475
- if content_n != content_b:
476
- raise ValueError('The directory structures of npu_path and bench_path are inconsistent.')
477
- if content_n == GraphConst.RANKS:
478
- if args.parallel_merge:
479
- _compare_graph_ranks_parallel(input_param, args)
562
+ logger.error("The npu_path or bench_path should be a folder.")
563
+ raise CompareException(CompareException.INVALID_COMPARE_MODE)
564
+ except KeyboardInterrupt:
565
+ logger.warning("Interrupted by user, terminating processes and cleaning up...")
566
+ except Exception as e:
567
+ logger.error(f"An unexpected error occurred: {e}")
568
+ raise e
569
+ finally:
570
+ ProgressInfo.update_process_running(False)
571
+
572
+
573
+ @dataclass
574
+ class ProgressConfig:
575
+ core_func: Callable
576
+ get_ranks: Callable
577
+ db_name: str
578
+ pbar_info_kwargs: dict = None
579
+ use_monitor_thread: bool = True
580
+ tqdm_total: Optional[int] = None
581
+
582
+
583
+ def _run_with_progress(param, args, config: ProgressConfig):
584
+ """通用进度条处理"""
585
+
586
+ monitor_thread = None
587
+ pbar_info = None
588
+ ranks = None
589
+
590
+ try:
591
+ if config.use_monitor_thread:
592
+ manager = Manager()
593
+ progress_dict = manager.dict()
594
+ pbar_info = PbarInfo(progress_dict=progress_dict, **config.pbar_info_kwargs)
595
+ ranks = config.get_ranks(args)
596
+ else:
597
+ pbar_info = PbarInfo(**config.pbar_info_kwargs)
598
+
599
+ tqdm_args = {
600
+ "desc": GraphConst.PBAR_DESC_PREFIX,
601
+ "total": config.tqdm_total if config.tqdm_total is not None else pbar_info.total,
602
+ "bar_format": GraphConst.BAR_FORMAT
603
+ }
604
+
605
+ with tqdm(**tqdm_args) as pbar:
606
+ # 单进程场景直接更新pbar,多进程场景需要通过monitor thread从共享dict中获取进度更新pbar
607
+ if config.use_monitor_thread:
608
+ monitor_thread = threading.Thread(target=monitor_progress,
609
+ args=(pbar_info, pbar, ranks, args.parallel_merge))
610
+ monitor_thread.start()
480
611
  else:
481
- _compare_graph_ranks(input_param, args)
482
- elif content_n == GraphConst.STEPS:
483
- _compare_graph_steps(input_param, args)
612
+ pbar_info.pbar = pbar
613
+
614
+ if param:
615
+ config.core_func(param, args, pbar_info=pbar_info)
616
+ else:
617
+ config.core_func(args, pbar_info=pbar_info)
618
+
619
+ db_path = os.path.join(args.output_path, config.db_name)
620
+ post_process_db(db_path, pbar_info=pbar_info, is_parallel_merge=args.parallel_merge)
621
+
622
+ if not args.parallel_merge and config.use_monitor_thread:
623
+ distributed_analyse(db_path, args.overflow_check, pbar_info=pbar_info)
624
+
625
+ if config.use_monitor_thread and monitor_thread:
626
+ monitor_thread.join(timeout=5)
627
+
628
+ except KeyboardInterrupt:
629
+ logger.warning("Interrupted by user, terminating processes and cleaning up...")
630
+ except Exception as e:
631
+ logger.error(f"An unexpected error occurred: {e}")
632
+ raise e
633
+ finally:
634
+ ProgressInfo.update_process_running(False)
635
+ if config.use_monitor_thread and pbar_info:
636
+ pbar_info.stop_monitor = True
637
+
638
+
639
+ def _build_graph_ranks_with_pbar(args):
640
+ def core_func(args, pbar_info):
641
+ if args.parallel_merge:
642
+ _build_graph_ranks_parallel(args, pbar_info=pbar_info)
484
643
  else:
485
- result = _compare_graph_result(input_param, args)
486
- create_directory(args.output_path)
487
- file_name = _export_compare_graph_result(args, result)
488
- if file_name:
489
- logger.error('Failed to export model compare graph.')
490
- else:
491
- logger.error("The npu_path or bench_path should be a folder.")
492
- raise CompareException(CompareException.INVALID_COMPARE_MODE)
493
- # 所有数据输出db结束后,添加索引,修改权限
494
- post_process_db(output_db_path)
644
+ _build_graph_ranks(args, pbar_info=pbar_info)
495
645
 
646
+ def get_ranks(args):
647
+ return check_and_return_dir_contents(args.target_path, Const.RANK)
496
648
 
497
- def _pt_graph_service_parser(parser):
498
- _graph_service_parser(parser)
649
+ stage_total = _get_parallel_stage_total(args) if args.parallel_merge else GraphConst.BUILD_STAGES_TOTAL
499
650
 
651
+ _run_with_progress(
652
+ param=None,
653
+ args=args,
654
+ config=ProgressConfig(
655
+ core_func=core_func,
656
+ get_ranks=get_ranks,
657
+ pbar_info_kwargs={"stage_total": stage_total},
658
+ db_name=build_output_db_name,
659
+ )
660
+ )
500
661
 
501
- def _pt_graph_service_command(args):
502
- _graph_service_command(args)
503
662
 
663
+ def _build_graph_steps_with_pbar(args):
664
+ steps = check_and_return_dir_contents(args.target_path, Const.STEP)
504
665
 
505
- def _ms_graph_service_parser(parser):
506
- _graph_service_parser(parser)
666
+ def get_ranks(args):
667
+ return check_and_return_dir_contents(os.path.join(args.target_path, steps[0]), Const.RANK)
507
668
 
669
+ stage_total = _get_parallel_stage_total(args, steps) if args.parallel_merge else GraphConst.BUILD_STAGES_TOTAL
508
670
 
509
- def _ms_graph_service_command(args):
510
- _graph_service_command(args)
671
+ _run_with_progress(
672
+ param=None,
673
+ args=args,
674
+ config=ProgressConfig(
675
+ core_func=_build_graph_steps,
676
+ get_ranks=get_ranks,
677
+ pbar_info_kwargs={"step_total": len(steps), "stage_total": stage_total},
678
+ db_name=build_output_db_name,
679
+ )
680
+ )
681
+
682
+
683
+ def _build_graph_with_pbar(npu_path, args):
684
+ def core_func(param, args, pbar_info):
685
+ result = _build_graph_result(param, args, pbar_info)
686
+ create_directory(args.output_path)
687
+ file_name = _export_build_graph_result(args, result, pbar_info)
688
+ if file_name:
689
+ logger.error('Failed to export model build graph.')
690
+
691
+ _run_with_progress(
692
+ param=npu_path,
693
+ args=args,
694
+ config=ProgressConfig(
695
+ core_func=core_func,
696
+ get_ranks=lambda x: None,
697
+ pbar_info_kwargs={},
698
+ db_name=build_output_db_name,
699
+ use_monitor_thread=False,
700
+ tqdm_total=GraphConst.PBAR_TOTAL
701
+ )
702
+ )
703
+
704
+
705
+ def _compare_graph_ranks_with_pbar(input_param, args):
706
+ def core_func(param, args, pbar_info):
707
+ if args.parallel_merge:
708
+ _compare_graph_ranks_parallel(param, args, pbar_info=pbar_info)
709
+ else:
710
+ _compare_graph_ranks(param, args, pbar_info=pbar_info)
711
+
712
+ def get_ranks(args):
713
+ if args.parallel_merge:
714
+ return calculate_list(args.target_path, args.golden_path, mode=GraphConst.UNION)
715
+ return calculate_list(args.target_path, args.golden_path)
716
+
717
+ stage_total = _get_parallel_stage_total(args, is_compare=True) if args.parallel_merge \
718
+ else GraphConst.COMPARE_STAGES_TOTAL
719
+
720
+ _run_with_progress(
721
+ param=input_param,
722
+ args=args,
723
+ config=ProgressConfig(
724
+ core_func=core_func,
725
+ get_ranks=get_ranks,
726
+ pbar_info_kwargs={"stage_total": stage_total},
727
+ db_name=compare_output_db_name
728
+ )
729
+ )
730
+
731
+
732
+ def _compare_graph_steps_with_pbar(input_param, args):
733
+ steps = calculate_list(args.target_path, args.golden_path, Const.STEP)
734
+
735
+ def get_ranks(args):
736
+ rank_path_t = os.path.join(args.target_path, steps[0])
737
+ rank_path_g = os.path.join(args.golden_path, steps[0])
738
+ if args.parallel_merge:
739
+ return calculate_list(rank_path_t, rank_path_g, mode=GraphConst.UNION)
740
+ return calculate_list(rank_path_t, rank_path_g)
741
+
742
+ stage_total = _get_parallel_stage_total(args, steps, is_compare=True) if args.parallel_merge \
743
+ else GraphConst.COMPARE_STAGES_TOTAL
744
+
745
+ _run_with_progress(
746
+ param=input_param,
747
+ args=args,
748
+ config=ProgressConfig(
749
+ core_func=_compare_graph_steps,
750
+ get_ranks=get_ranks,
751
+ pbar_info_kwargs={"stage_total": stage_total, "step_total": len(steps)},
752
+ db_name=compare_output_db_name
753
+ )
754
+ )
755
+
756
+
757
+ def _compare_graph_with_pbar(input_param, args):
758
+ def core_func(param, args, pbar_info):
759
+ result = _compare_graph_result(param, args, pbar_info=pbar_info)
760
+ create_directory(args.output_path)
761
+ file_name = _export_compare_graph_result(args, result, pbar_info=pbar_info)
762
+ if file_name:
763
+ logger.error('Failed to export model compare graph.')
764
+
765
+ _run_with_progress(
766
+ param=input_param,
767
+ args=args,
768
+ config=ProgressConfig(
769
+ core_func=core_func,
770
+ get_ranks=lambda x: None,
771
+ pbar_info_kwargs={"pbar": None, "stage_total": GraphConst.COMPARE_STAGES_TOTAL},
772
+ db_name=compare_output_db_name,
773
+ use_monitor_thread=False,
774
+ tqdm_total=GraphConst.PBAR_TOTAL
775
+ )
776
+ )
777
+
778
+
779
+ def _get_parallel_stage_total(args, steps=None, is_compare=False):
780
+ """
781
+ 获取不同并行切分策略的任务阶段数
782
+ """
783
+ parallel_params = args.parallel_params
784
+ if not is_compare and (not parallel_params or len(parallel_params) != 1):
785
+ raise RuntimeError('Parallel params error in build graph!')
786
+ if is_compare and (not parallel_params or len(parallel_params) != 2):
787
+ raise RuntimeError('Parallel params error in compare graph!')
788
+ target_path = os.path.join(args.target_path, steps[0]) if steps else args.target_path
789
+ validate_parallel_param(parallel_params[0], target_path)
790
+
791
+ if is_compare:
792
+ golden_path = os.path.join(args.golden_path, steps[0]) if steps else args.golden_path
793
+ validate_parallel_param(parallel_params[1], golden_path, '[Bench]')
794
+
795
+ stage_count_map = {
796
+ "TPMerger": lambda param: param.rank_size // param.tp,
797
+ "PPMerger": lambda param: param.rank_size // param.pp,
798
+ "VPPMerger": lambda param: param.rank_size // param.pp,
799
+ "TPPPMerger": lambda param: param.rank_size // param.pp + param.rank_size // param.pp // param.tp,
800
+ "FullMerger": lambda param: param.rank_size // param.pp + param.rank_size // param.pp // param.tp,
801
+ "NoParallelMerger": 0
802
+ }
803
+
804
+ def _get_stage_count(parallel_param, merger_name: str) -> int:
805
+ rule = stage_count_map.get(merger_name, 0)
806
+ return rule(parallel_param) if callable(rule) else rule
807
+
808
+ merger_name_t = GraphMerger([], parallel_params[0]).strategy.__class__.__name__
809
+ stage_count_target = _get_stage_count(parallel_params[0], merger_name_t)
810
+
811
+ if is_compare:
812
+ merger_name_g = GraphMerger([], parallel_params[1]).strategy.__class__.__name__
813
+ stage_count_golden = _get_stage_count(parallel_params[1], merger_name_g)
814
+ return GraphConst.COMPARE_STAGES_TOTAL + stage_count_target + stage_count_golden
815
+
816
+ return GraphConst.BUILD_STAGES_TOTAL + stage_count_target
511
817
 
512
818
 
513
819
  class CompareGraphResult:
@@ -526,3 +832,72 @@ class BuildGraphResult:
526
832
  self.micro_steps = micro_steps
527
833
  self.rank = rank
528
834
  self.step = step
835
+
836
+
837
+ class PbarInfo:
838
+ def __init__(self, pbar=None, progress_dict=None, task_id=None, step=0, step_total=1,
839
+ stage_total=GraphConst.BUILD_STAGES_TOTAL):
840
+ self.pbar = pbar
841
+ self.progress_dict = progress_dict
842
+ self.task_id = task_id
843
+ self.step = step
844
+ self.step_total = step_total
845
+ self.total = GraphConst.PBAR_TOTAL * step_total
846
+ self.stage_total = stage_total * step_total # 有几个阶段
847
+ self.current_stage_dict = Manager().dict() # 当前阶段,进程共享
848
+ self.stage_progress = round(self.total / self.stage_total, 2) # 每个阶段的最大进度
849
+ self.stop_monitor = False
850
+ self.wait_monitor = False
851
+ self.continue_monitor = True
852
+
853
+ def __deepcopy__(self, memo):
854
+ new_obj = PbarInfo()
855
+ new_obj.progress_dict = self.progress_dict
856
+ new_obj.task_id = self.task_id
857
+ new_obj.step = self.step
858
+ new_obj.step_total = self.step_total
859
+ new_obj.stage_total = self.stage_total
860
+ new_obj.current_stage_dict = self.current_stage_dict
861
+ new_obj.stage_progress = self.stage_progress
862
+ new_obj.total = self.total
863
+ new_obj.stop_monitor = self.stop_monitor
864
+ new_obj.wait_monitor = self.wait_monitor
865
+ new_obj.continue_monitor = self.continue_monitor
866
+ return new_obj
867
+
868
+ @staticmethod
869
+ def update_task_id(pbar_info, task_id):
870
+ """
871
+ 在进程池中,实例作为入参,修改实例属性,需要深拷贝实例使修改生效
872
+ """
873
+ if pbar_info:
874
+ pbar_info.task_id = task_id
875
+ return deepcopy(pbar_info)
876
+ return pbar_info
877
+
878
+ @staticmethod
879
+ def del_progress_dict_item(pbar_info, origin_ranks, merged_ranks):
880
+ """
881
+ 不同并行切分策略的图合并场景下,graph合并到一些rank中,剩余的rank作为task_id不再需要
882
+ """
883
+ diff_ranks = list(set(origin_ranks) - set(merged_ranks))
884
+ for rank in diff_ranks:
885
+ if rank in pbar_info.progress_dict:
886
+ del pbar_info.progress_dict[rank]
887
+
888
+ @staticmethod
889
+ def reset_progress_and_current_stage(pbar_info, task_ids):
890
+ """
891
+ 不同并行切分策略的图合并场景下,每个step需要重置进度信息
892
+ """
893
+ for task_id in task_ids:
894
+ pbar_info.progress_dict[task_id] = GraphConst.PBAR_TOTAL * pbar_info.step
895
+ pbar_info.current_stage_dict[task_id] = pbar_info.stage_total // pbar_info.step_total * pbar_info.step
896
+
897
+ def set_continue_monitor(self, value: bool):
898
+ self.continue_monitor = value
899
+ self.wait_monitor = not value
900
+
901
+ def set_wait_monitor(self, value: bool):
902
+ self.wait_monitor = value
903
+ self.continue_monitor = not value