mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/README.md CHANGED
@@ -44,6 +44,7 @@ export MSPROBE_LOG_LEVEL={x}
44
44
 
45
45
  - msprobe支持AscendPyTorch 1.11.0或更高版本,支持的PyTorch和CANN以及PyTorch和python软件版本配套关系请参见《[Ascend Extension for PyTorch插件](https://gitee.com/ascend/pytorch)》。
46
46
  - msprobe支持MindSpore 2.4.0或更高版本,支持的MindSpore和CANN以及MindSpore和python软件版本配套关系请参见《[MindSpore版本发布列表](https://www.mindspore.cn/versions)》。
47
+ - msprobe支持MSAdapter 2.1.0。
47
48
  - msprobe支持的固件驱动版本与配套CANN软件支持的固件驱动版本相同,开发者可通过“[昇腾社区-固件与驱动](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fhardware%2Ffirmware-drivers%2Fcommunity%3Fproduct%3D2%26model%3D28%26cann%3D8.0.RC3.alpha003%26driver%3D1.0.25.alpha)”页面根据产品型号与CANN软件版本获取配套的固件与驱动。
48
49
 
49
50
 
@@ -53,7 +54,9 @@ export MSPROBE_LOG_LEVEL={x}
53
54
 
54
55
  **2. 工具读写的所有路径,如config_path、dump_path等,只允许包含大小写字母、数字、下划线、斜杠、点和短横线。**
55
56
 
56
- ## ⚙️ [安装](./docs/01.installation.md)
57
+ ## ⚙️ 安装
58
+
59
+ 请参见[安装指导说明](./docs/01.installation.md)。
57
60
 
58
61
  ## 🌟 新版本特性
59
62
 
@@ -69,35 +72,37 @@ export MSPROBE_LOG_LEVEL={x}
69
72
 
70
73
  ### 1 数据采集
71
74
 
72
- msprobe 通过在训练脚本中添加 PrecisionDebugger 接口的方式对 API 执行精度数据 dump 操作,对应 config.json 中的 task 为 statistics 或 tensor。
75
+ msprobe 通过在训练脚本中添加 PrecisionDebugger 接口的方式对 API 执行精度数据 dump 操作。对应 config.json 中的 "statistics""tensor" task
73
76
 
74
77
  [PyTorch 场景的数据采集](./docs/05.data_dump_PyTorch.md)
75
78
 
76
79
  [MindSpore 场景的数据采集](./docs/06.data_dump_MindSpore.md)
77
80
 
81
+ [MSAdapter 场景的数据采集](./docs/29.data_dump_MSAdapter.md)
82
+
78
83
  ### 2 精度预检
79
84
 
80
- 精度预检旨在昇腾 NPU 上扫描训练模型中的所有 API 进行 API 复现,给出精度情况的诊断和分析。对应 config.json 中的 task 为 run_ut
85
+ 精度预检旨在昇腾 NPU 上扫描训练模型中的所有 API 进行 API 复现,给出精度情况的诊断和分析。对应 config.json 中的 "run_ut" task。
81
86
 
82
87
  PyTorch 场景的[离线预检](./docs/07.accuracy_checker_PyTorch.md)和[在线预检](./docs/08.accuracy_checker_online_PyTorch.md)
83
88
 
84
89
  MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.md)
85
90
 
86
- ### 3 精度比对
91
+ ### 3 分级可视化构图比对
87
92
 
88
- 该功能进行 PyTorch 整网 API 粒度的数据 dump、精度比对,进而定位训练场景下的精度问题。
93
+ 该功能将msprobe工具dump的精度数据进行解析,还原模型图结构,实现模型各个层级的精度数据比对,方便用户理解模型结构、分析精度问题。
89
94
 
90
- [PyTorch 场景的精度比对](./docs/10.accuracy_compare_PyTorch.md)
95
+ [PyTorch 场景的分级可视化构图比对](./docs/21.visualization_PyTorch.md)
91
96
 
92
- [MindSpore 场景的精度比对](./docs/11.accuracy_compare_MindSpore.md)
97
+ [MindSpore 场景的分级可视化构图比对](./docs/22.visualization_MindSpore.md)
93
98
 
94
- ### 4 溢出检测与解析
99
+ ### 4 精度比对
95
100
 
96
- 溢出检测与解析是在执行精度数据 dump 时,判断是否存在输入正常但输出存在溢出的 API,从而判断是否为正常溢出。对应 config.json 中的 overflow_check。
101
+ 该功能进行 PyTorch 整网 API 粒度的数据 dump、精度比对,进而定位训练场景下的精度问题。
97
102
 
98
- [PyTorch 场景的溢出检测与解析](./docs/12.overflow_check_PyTorch.md)
103
+ [PyTorch 场景的精度比对](./docs/10.accuracy_compare_PyTorch.md)
99
104
 
100
- [MindSpore 场景的溢出检测与解析](./docs/13.overflow_check_MindSpore.md)
105
+ [MindSpore 场景的精度比对](./docs/11.accuracy_compare_MindSpore.md)
101
106
 
102
107
  ### 5 数据解析
103
108
 
@@ -129,26 +134,57 @@ MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.
129
134
 
130
135
  [兼容 PyTorch 和 MindSpore 框架的训练状态监控](./docs/19.monitor.md)
131
136
 
132
- ### 10 分级可视化构图比对
137
+ ### 10 单算子API自动生成脚本
133
138
 
134
- 该功能将msprobe工具dump的精度数据进行解析,还原模型图结构,实现模型各个层级的精度数据比对,方便用户理解模型结构、分析精度问题。
139
+ 该功能将msprobe工具dump的精度数据进行解析,自动生成单API脚本,用于复现整网中出现的算子问题,降低用户复现问题的成本,供开发分析算子问题。
135
140
 
136
- [PyTorch 场景的分级可视化构图比对](./docs/21.visualization_PyTorch.md)
141
+ [PyTorch 单算子API自动生成脚本](./docs/23.generate_operator_PyTorch.md)
137
142
 
138
- [MindSpore 场景的分级可视化构图比对](./docs/22.visualization_MindSpore.md)
143
+ [MindSpore 单算子API自动生成脚本](./docs/33.generate_operator_MindSpore.md)
139
144
 
145
+ ### 11 数码关联
140
146
 
141
- ### 11 单算子API自动生成脚本
147
+ 该功能只支持 MindSpore 静态图场景,用于将IR图与dump数据进行关联,获取dump数据和代码调用栈的关联关系。
142
148
 
143
- 该功能将msprobe工具dump的精度数据进行解析,自动生成单API脚本,用于复现整网中出现的算子问题,降低用户复现问题的成本,供开发分析算子问题。
149
+ [MindSpore 场景的数码关联](./docs/24.code_mapping_Mindspore.md)
144
150
 
145
- [PyTorch 单算子API自动生成脚本](./docs/23.generate_operator_PyTorch.md)
151
+ ### 12 溢出检测与解析
146
152
 
147
- ### 12 数码关联
153
+ 溢出检测用于采集溢出 API 或 模块的精度数据,而溢出解析则是通过对溢出数据的分析,进一步判断是否为正常溢出。对应 config.json 中的 "overflow_check" task。
154
+ 推荐直接使用[数据采集](#1-数据采集)功能采集统计量信息,检测溢出问题。
148
155
 
149
- 该功能只支持 MindSpore 静态图场景,用于将IR图与dump数据进行关联,获取dump数据和代码调用栈的关联关系。
156
+ [PyTorch 场景的溢出检测与解析](./docs/12.overflow_check_PyTorch.md)
157
+
158
+ [MindSpore 场景的溢出检测](./docs/13.overflow_check_MindSpore.md)
159
+
160
+ [MSAdapter 场景的溢出检测](./docs/30.overflow_check_MSAdapter.md)
161
+
162
+ ### 13 训练检查
163
+
164
+ 该工具主要包括:
165
+
166
+ 训练前或精度比对前,对比两个环境下可能影响训练精度的配置差异。
167
+
168
+ [PyTorch 训练前配置检查](./docs/31.config_check.md)
169
+
170
+ 训练过程中或结束后,比较两个不同的checkpoint,评估模型相似度。
171
+
172
+ [checkpoint比对](./docs/32.ckpt_compare.md)
173
+
174
+ ### 14 强化学习数据采集
175
+
176
+ 主要能力:
177
+
178
+ 灵活采集强化学习中重要关键过程数据,并支持比对。
179
+
180
+ [强化学习数据采集](./docs/34.RL_collect.md)
181
+
182
+ ### 15 整网首个溢出节点分析
183
+
184
+ 多rank场景下通过dump数据找到首个出现Nan或Inf的节点。
185
+
186
+ [PyTorch 场景整网首个溢出节点分析](./docs/35.nan_analyze.md)
150
187
 
151
- [MindSpore 场景的数码关联](./docs/24.code_mapping_Mindspore.md)
152
188
 
153
189
  ## 📑 补充材料
154
190
 
msprobe/core/__init__.py CHANGED
@@ -0,0 +1,17 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.core.single_save.single_saver import SingleSave
17
+ from msprobe.core.single_save.single_comparator import SingleComparator
msprobe/core/common/const.py CHANGED
@@ -51,7 +51,10 @@ class Const:
51
51
  FOUR_SEGMENT = 4
52
52
  SIX_SEGMENT = 6
53
53
  SEVEN_SEGMENT = 7
54
+
54
55
  MAX_DEPTH = 10
56
+ CPU_QUARTER = 4
57
+ DUMP_MAX_DEPTH = 50
55
58
 
56
59
  # dump mode
57
60
  ALL = "all"
@@ -67,7 +70,7 @@ class Const:
67
70
  SUMMARY = "summary"
68
71
  MD5 = "md5"
69
72
  VALUE = "value"
70
- SUMMARY_MODE = [ALL, SUMMARY, MD5]
73
+ SUMMARY_MODE = ["statistics", "md5"]
71
74
 
72
75
  WRITE_FLAGS = os.O_WRONLY | os.O_CREAT
73
76
  WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR
@@ -77,6 +80,8 @@ class Const:
77
80
  NUMPY_SUFFIX = ".npy"
78
81
  NUMPY_PATTERN = "*.npy"
79
82
  PT_SUFFIX = ".pt"
83
+ PY_SUFFIX = ".py"
84
+ INIT_PY = "init.py"
80
85
  ONE_GB = 1073741824 # 1 * 1024 * 1024 * 1024
81
86
  TEN_GB = 10737418240 # 10 * 1024 * 1024 * 1024
82
87
  ONE_MB = 1048576 # 1 * 1024 * 1024
@@ -92,6 +97,7 @@ class Const:
92
97
  GRAD_OUTPUT = 'grad_output'
93
98
  PARAMS = 'parameters'
94
99
  PARAMS_GRAD = 'parameters_grad'
100
+ DEBUG = 'debug'
95
101
  START = "start"
96
102
  STOP = "stop"
97
103
  ENV_ENABLE = "1"
@@ -129,6 +135,7 @@ class Const:
129
135
  NPU = 'NPU'
130
136
  NPU_LOWERCASE = 'npu'
131
137
  CPU_LOWERCASE = 'cpu'
138
+ GPU_LOWERCASE = 'gpu'
132
139
  CUDA_LOWERCASE = 'cuda'
133
140
  DEVICE = 'device'
134
141
  DISTRIBUTED = 'Distributed'
@@ -137,6 +144,10 @@ class Const:
137
144
  MODULE_PREFIX = ["Module", "Cell"]
138
145
  FORWARD_NAME_SUFFIX = ".forward"
139
146
 
147
+ DUMP_JSON_FILE = "dump_json_file"
148
+ DEBUG_JSON_FILE = "debug_json_file"
149
+ STACK_JSON_FILE = "stack_json_file"
150
+
140
151
  # struct json param
141
152
  ORIGIN_DATA = "origin_data"
142
153
  SCOPE = "scope"
@@ -167,6 +178,10 @@ class Const:
167
178
  TOP_LAYER = "TopLayer"
168
179
  CELL = "Cell"
169
180
  MODULE = "Module"
181
+ API = "api"
182
+ PYNATIVE_MODE = "pynative"
183
+ PYNATIVE_GRAPH_MODE = "pynative_graph"
184
+
170
185
  FRAME_FILE_LIST = ["site-packages/torch", "package/torch", "site-packages/mindspore", "package/mindspore"]
171
186
  INPLACE_LIST = [
172
187
  "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
@@ -188,7 +203,11 @@ class Const:
188
203
 
189
204
  FILL_CHAR_NUMS = 50
190
205
  TOOL_ENDS_SUCCESSFULLY = f"{TOOL_NAME} ends successfully."
206
+
191
207
  WITHOUT_CALL_STACK = "The call stack retrieval failed."
208
+ STACK_FILTER_KEYWORDS = ["msprobe/core", "msprobe/pytorch", "msprobe/mindspore"]
209
+ CALL_STACK_FLAG = "data_dump/api_registry"
210
+ NEW_STACK_FLAG = "0"
192
211
 
193
212
  STEP = "step"
194
213
  RANK = "rank"
@@ -206,12 +225,16 @@ class Const:
206
225
  TORCH_FLOAT32 = "torch.float32"
207
226
  TORCH_BFLOAT16 = "torch.bfloat16"
208
227
 
228
+ TYPE = 'type'
209
229
  DTYPE = 'dtype'
210
230
  SHAPE = 'shape'
231
+ STACK_INFO = 'stack_info'
211
232
  MAX = 'Max'
212
233
  MIN = 'Min'
213
234
  MEAN = 'Mean'
214
235
  NORM = 'Norm'
236
+ DATA_NAME = 'data_name'
237
+ TENSOR_STAT_INDEX = 'tensor_stat_index'
215
238
 
216
239
  CODE_STACK = 'Code Stack'
217
240
  OP_NAME = 'Op Name'
@@ -223,6 +246,10 @@ class Const:
223
246
  # 分隔符常量
224
247
  SCOPE_SEPARATOR = "/"
225
248
  REPLACEMENT_CHARACTER = "_"
249
+ PIPE_SEPARATOR = "|"
250
+
251
+ FORWARD_PATTERN = SEP + FORWARD + SEP
252
+ BACKWARD_PATTERN = SEP + BACKWARD + SEP
226
253
 
227
254
  OPTIMIZER = "optimizer"
228
255
  CLIP_GRAD = "clip_grad"
@@ -230,12 +257,136 @@ class Const:
230
257
 
231
258
  TENSOR_STAT_LEN = 2
232
259
 
260
+ SUPPORT_API_FILE_NAME = "support_wrap_ops.yaml"
261
+
262
+ PT_API_TYPE_FUNCTIONAL = "functional"
263
+ PT_API_TYPE_TENSOR = "tensor"
264
+ PT_API_TYPE_TORCH = "torch"
265
+ PT_API_TYPE_VF = "_VF"
266
+ PT_API_TYPE_NPU = "torch_npu"
267
+ PT_API_TYPE_ATEN = "aten"
268
+ PT_API_TYPE_DIST = "distributed"
269
+ PT_API_TYPE_NPU_DIST = "npu_distributed"
270
+ PT_API_TYPE_MINDSPEED = "mindspeed"
271
+
272
+ MS_API_TYPE_OPS = "ops"
273
+ MS_API_TYPE_TENSOR = "tensor"
274
+ MS_API_TYPE_STUB_TENSOR = "stubtensor"
275
+ MS_API_TYPE_MINT = "mint.ops"
276
+ MS_API_TYPE_MINT_FUNC = "mint.nn.functional"
277
+ MS_API_TYPE_COM = "communication.comm_func"
278
+ MS_API_TYPE_MINT_DIST = "mint.distributed"
279
+
280
+ FUNCTIONAL_API_TYPE_PREFIX = "Functional"
281
+ TENSOR_API_TYPE_PREFIX = "Tensor"
282
+ DIST_API_TYPE_PREFIX = "Distributed"
283
+
284
+ TORCH_API_TYPE_PREFIX = "Torch"
285
+ NPU_API_TYPE_PREFIX = "NPU"
286
+ ATEN_API_TYPE_PREFIX = "Aten"
287
+ VF_API_TYPE_PREFIX = "VF"
288
+ MINDSPEED_API_TYPE_PREFIX = "MindSpeed"
289
+
290
+ MINT_API_TYPE_PREFIX = "Mint"
291
+ MINT_FUNC_API_TYPE_PREFIX = "MintFunctional"
292
+ MINT_DIST_API_TYPE_PREFIX = "MintDistributed"
293
+
294
+ SUPPORT_API_DICT_KEY_MAP = {
295
+ PT_FRAMEWORK: {
296
+ PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL,
297
+ PT_API_TYPE_TENSOR: PT_API_TYPE_TENSOR,
298
+ PT_API_TYPE_TORCH: PT_API_TYPE_TORCH,
299
+ PT_API_TYPE_VF: PT_API_TYPE_VF,
300
+ PT_API_TYPE_NPU: PT_API_TYPE_NPU,
301
+ PT_API_TYPE_ATEN: PT_API_TYPE_ATEN,
302
+ PT_API_TYPE_DIST: PT_API_TYPE_DIST,
303
+ PT_API_TYPE_NPU_DIST: PT_API_TYPE_NPU_DIST,
304
+ PT_API_TYPE_MINDSPEED: PT_API_TYPE_MINDSPEED
305
+ },
306
+ MS_FRAMEWORK: {
307
+ MS_API_TYPE_OPS: MS_API_TYPE_OPS,
308
+ MS_API_TYPE_TENSOR: MS_API_TYPE_TENSOR,
309
+ MS_API_TYPE_STUB_TENSOR: MS_API_TYPE_TENSOR,
310
+ MS_API_TYPE_MINT: MS_API_TYPE_MINT,
311
+ MS_API_TYPE_MINT_FUNC: MS_API_TYPE_MINT_FUNC,
312
+ MS_API_TYPE_COM: MS_API_TYPE_COM,
313
+ MS_API_TYPE_MINT_DIST: MS_API_TYPE_MINT_DIST
314
+ },
315
+ MT_FRAMEWORK: {
316
+ PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL,
317
+ PT_API_TYPE_TENSOR: PT_API_TYPE_TENSOR,
318
+ PT_API_TYPE_TORCH: PT_API_TYPE_TORCH,
319
+ PT_API_TYPE_NPU: PT_API_TYPE_NPU,
320
+ PT_API_TYPE_DIST: PT_API_TYPE_DIST
321
+ }
322
+ }
323
+
324
+ API_DATA_PREFIX = {
325
+ PT_FRAMEWORK: {
326
+ PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX,
327
+ PT_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX,
328
+ PT_API_TYPE_TORCH: TORCH_API_TYPE_PREFIX,
329
+ PT_API_TYPE_VF: VF_API_TYPE_PREFIX,
330
+ PT_API_TYPE_NPU: NPU_API_TYPE_PREFIX,
331
+ PT_API_TYPE_ATEN: ATEN_API_TYPE_PREFIX,
332
+ PT_API_TYPE_DIST: DIST_API_TYPE_PREFIX,
333
+ PT_API_TYPE_NPU_DIST: DIST_API_TYPE_PREFIX,
334
+ PT_API_TYPE_MINDSPEED: MINDSPEED_API_TYPE_PREFIX
335
+ },
336
+ MS_FRAMEWORK: {
337
+ MS_API_TYPE_OPS: FUNCTIONAL_API_TYPE_PREFIX,
338
+ MS_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX,
339
+ MS_API_TYPE_STUB_TENSOR: TENSOR_API_TYPE_PREFIX,
340
+ MS_API_TYPE_MINT: MINT_API_TYPE_PREFIX,
341
+ MS_API_TYPE_MINT_FUNC: MINT_FUNC_API_TYPE_PREFIX,
342
+ MS_API_TYPE_COM: DIST_API_TYPE_PREFIX,
343
+ MS_API_TYPE_MINT_DIST: MINT_DIST_API_TYPE_PREFIX
344
+ },
345
+ MT_FRAMEWORK: {
346
+ PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX,
347
+ PT_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX,
348
+ PT_API_TYPE_TORCH: TORCH_API_TYPE_PREFIX,
349
+ PT_API_TYPE_NPU: NPU_API_TYPE_PREFIX,
350
+ PT_API_TYPE_DIST: DIST_API_TYPE_PREFIX
351
+ }
352
+ }
353
+
354
+ def _fused_adamw_(
355
+ self,
356
+ grads,
357
+ exp_avgs,
358
+ exp_avg_sqs,
359
+ max_exp_avg_sqs,
360
+ state_steps,
361
+ *,
362
+ lr,
363
+ beta1,
364
+ beta2,
365
+ weight_decay,
366
+ eps,
367
+ amsgrad,
368
+ maximize,
369
+ grad_scale=None,
370
+ found_inf=None
371
+ ):
372
+ pass
373
+
374
+ API_WITH_SELF_ARG = {
375
+ 'Torch._fused_adamw_': _fused_adamw_
376
+ }
377
+
378
+ ASCEND = "ASCEND"
379
+ MATCH_MODE_NAME = "pure name"
380
+ MATCH_MODE_MAPPING = "mapping"
381
+ MATCH_MODE_SIMILARITY = "similarity"
382
+
233
383
 
234
384
  class CompareConst:
235
385
  """
236
386
  Class for compare module const
237
387
  """
238
388
  SPACE = " "
389
+ NAME = "Name"
239
390
  # compare result column name
240
391
  NPU_NAME = "NPU Name"
241
392
  BENCH_NAME = "Bench Name"
@@ -256,6 +407,7 @@ class CompareConst:
256
407
  MEAN_DIFF = "Mean diff"
257
408
  NORM_DIFF = "L2norm diff"
258
409
  COSINE = "Cosine"
410
+ EUC_DIST = "EucDist"
259
411
  MAX_ABS_ERR = "MaxAbsErr"
260
412
  MAX_RELATIVE_ERR = "MaxRelativeErr"
261
413
  MIN_RELATIVE_ERR = "MinRelativeErr"
@@ -278,6 +430,7 @@ class CompareConst:
278
430
  OUTPUT_STRUCT = "output_struct"
279
431
  PARAMS_STRUCT = "params_struct"
280
432
  PARAMS_GRAD_STRUCT = "params_grad_struct"
433
+ DEBUG_STRUCT = "debug_struct"
281
434
  SUMMARY = "summary"
282
435
  COMPARE_RESULT = "compare_result"
283
436
  COMPARE_MESSAGE = "compare_message"
@@ -330,8 +483,8 @@ class CompareConst:
330
483
  ULP_ERR_STATUS = "ulp_err_status"
331
484
 
332
485
  COMPARE_RESULT_HEADER = [
333
- NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR,
334
- ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO,
486
+ NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, EUC_DIST,
487
+ MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO,
335
488
  NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, ACCURACY, ERROR_MESSAGE
336
489
  ]
337
490
 
@@ -357,18 +510,16 @@ class CompareConst:
357
510
  Const.MD5: MD5_COMPARE_RESULT_HEADER
358
511
  }
359
512
 
360
- ALL_COMPARE_INDEX = [COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO]
513
+ ALL_COMPARE_INDEX = [COSINE, EUC_DIST, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO,
514
+ FIVE_THOUSANDTHS_ERR_RATIO]
361
515
  SUMMARY_COMPARE_INDEX = [MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF,
362
516
  MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR]
363
517
 
364
518
  # dtype match
365
- MS_TYPE = [
366
- [Const.FLOAT16, Const.FLOAT32], [Const.FLOAT32, Const.FLOAT16],
367
- [Const.FLOAT16, Const.BFLOAT16], [Const.BFLOAT16, Const.FLOAT16]
368
- ]
369
- TORCH_TYPE = [
370
- [Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16],
371
- [Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16]
519
+
520
+ DTYPE_MATCH_GROUPS = [
521
+ {Const.FLOAT16, Const.FLOAT32, Const.BFLOAT16},
522
+ {Const.TORCH_FLOAT16, Const.TORCH_FLOAT32, Const.TORCH_BFLOAT16}
372
523
  ]
373
524
 
374
525
  # read_op
@@ -386,16 +537,10 @@ class CompareConst:
386
537
  Const.KWARGS: INPUT_STRUCT,
387
538
  Const.OUTPUT: OUTPUT_STRUCT,
388
539
  Const.PARAMS: PARAMS_STRUCT,
389
- Const.PARAMS_GRAD: PARAMS_GRAD_STRUCT
540
+ Const.PARAMS_GRAD: PARAMS_GRAD_STRUCT,
541
+ Const.DEBUG: DEBUG_STRUCT
390
542
  }
391
543
 
392
- STRUCT_COMPARE_KEY = [
393
- INPUT_STRUCT,
394
- OUTPUT_STRUCT,
395
- PARAMS_STRUCT,
396
- PARAMS_GRAD_STRUCT
397
- ]
398
-
399
544
  # compare standard
400
545
  HUNDRED_RATIO_THRESHOLD = 0.01
401
546
  THOUSAND_RATIO_THRESHOLD = 0.001
@@ -467,22 +612,42 @@ class CompareConst:
467
612
  BENCH_MEAN: None, BENCH_NORM: None, ACCURACY: '', ERROR_MESSAGE: ''
468
613
  }
469
614
  MS_GRAPH_NPY = {
470
- COSINE: None, MAX_ABS_ERR: None, MAX_RELATIVE_ERR: None, ONE_THOUSANDTH_ERR_RATIO: None,
615
+ COSINE: None, EUC_DIST: None, MAX_ABS_ERR: None, MAX_RELATIVE_ERR: None, ONE_THOUSANDTH_ERR_RATIO: None,
471
616
  FIVE_THOUSANDTHS_ERR_RATIO: None
472
617
  }
473
618
  MS_GRAPH_STATISTIC = {
474
619
  MAX_DIFF: None, MIN_DIFF: None, MEAN_DIFF: None, NORM_DIFF: None, MAX_RELATIVE_ERR: None,
475
620
  MIN_RELATIVE_ERR: None, MEAN_RELATIVE_ERR: None, NORM_RELATIVE_ERR: None
476
621
  }
622
+
623
+ API_MAPPING_KEYS_TO_COMPARE = [
624
+ ('ms_args', 'pt_args'),
625
+ ('ms_outputs', 'pt_outputs'),
626
+ ('ms_parameters', 'pt_parameters'),
627
+ ('ms_parameters_grad', 'pt_parameters_grad')
628
+ ]
629
+
477
630
  INPUT_PATTERN = Const.SEP + Const.INPUT + Const.SEP
478
631
  KWARGS_PATTERN = Const.SEP + Const.KWARGS + Const.SEP
479
632
  OUTPUT_PATTERN = Const.SEP + Const.OUTPUT + Const.SEP
480
633
  PARAMS_PATTERN = Const.SEP + Const.PARAMS + Const.SEP
481
634
  PARAMS_GRAD_PATTERN = Const.SEP + Const.PARAMS_GRAD + Const.SEP
482
- COMPARE_KEY = 'compare_key'
483
- COMPARE_SHAPE = 'compare_shape'
635
+
636
+ CMP_KEY = 'compare_key'
637
+ CMP_SHAPE = 'compare_shape'
638
+
639
+ OP_NAME_X = 'op_name_x'
640
+ MATCH_RESULT_COLUMNS = [
641
+ OP_NAME_X, 'dtype_x', 'shape_x', 'summary_x', 'stack_info_x', 'data_name_x',
642
+ CMP_KEY, CMP_SHAPE,
643
+ 'op_name_y', 'dtype_y', 'shape_y', 'summary_y', 'stack_info_y', 'data_name_y',
644
+ ]
645
+
484
646
  INTERNAL_API_MAPPING_FILE = 'ms_to_pt_api.yaml'
485
647
  UNREADABLE = 'unreadable data'
648
+ NPU_DUMP_DATA_DIR = 'npu_dump_data_dir'
649
+ BENCH_DUMP_DATA_DIR = 'bench_dump_data_dir'
650
+ NO_REAL_DATA_FLAG = '-1'
486
651
 
487
652
 
488
653
  class FileCheckConst:
@@ -504,6 +669,8 @@ class FileCheckConst:
504
669
  XLSX_SUFFIX = ".xlsx"
505
670
  YAML_SUFFIX = ".yaml"
506
671
  IR_SUFFIX = ".ir"
672
+ ZIP_SUFFIX = ".zip"
673
+ SHELL_SUFFIX = ".sh"
507
674
  MAX_PKL_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
508
675
  MAX_NUMPY_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
509
676
  MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
@@ -512,6 +679,8 @@ class FileCheckConst:
512
679
  MAX_XLSX_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
513
680
  MAX_YAML_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
514
681
  MAX_IR_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
682
+ MAX_ZIP_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
683
+ MAX_FILE_IN_ZIP_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
515
684
  COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024
516
685
  DIR = "dir"
517
686
  FILE = "file"
@@ -525,7 +694,8 @@ class FileCheckConst:
525
694
  CSV_SUFFIX: MAX_CSV_SIZE,
526
695
  XLSX_SUFFIX: MAX_XLSX_SIZE,
527
696
  YAML_SUFFIX: MAX_YAML_SIZE,
528
- IR_SUFFIX: MAX_IR_SIZE
697
+ IR_SUFFIX: MAX_IR_SIZE,
698
+ ZIP_SUFFIX: MAX_ZIP_SIZE
529
699
  }
530
700
  CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]'
531
701
 
@@ -538,61 +708,6 @@ class OverflowConst:
538
708
  OVERFLOW_DEBUG_MODE = 1
539
709
 
540
710
 
541
- class MsCompareConst:
542
- # api_info field
543
- MINT = "Mint"
544
- MINT_FUNCTIONAL = "MintFunctional"
545
- TENSOR_API = "Tensor"
546
-
547
- API_NAME_STR_LENGTH = 4
548
- MAX_RECURSION_DEPTH = 20
549
-
550
- # Mindtorch api_info field
551
- MINDTORCH_TENSOR = "Tensor"
552
- MINDTORCH = "Torch"
553
- MINDTORCH_FUNC = "Functional"
554
- MINDTORCH_NPU = "NPU"
555
- MINDTORCH_DIST = "Distributed"
556
-
557
-
558
-
559
- MT_VALID_API_TYPES = [
560
- MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR
561
- ]
562
-
563
- TASK_FIELD = "task"
564
- STATISTICS_TASK = "statistics"
565
- FRAMEWORK = "framework"
566
- TENSOR_TASK = "tensor"
567
- DUMP_DATA_DIR_FIELD = "dump_data_dir"
568
- DATA_FIELD = "data"
569
-
570
- # supported api yaml
571
- SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
572
- SUPPORTED_TENSOR_LIST_KEY = "tensor"
573
-
574
- # detail_csv
575
- DETAIL_CSV_API_NAME = "API Name"
576
- DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
577
- DETAIL_CSV_TESTED_DTYPE = "Tested Dtype"
578
- DETAIL_CSV_SHAPE = "Shape"
579
- DETAIL_CSV_PASS_STATUS = "Status"
580
- DETAIL_CSV_MESSAGE = "Message"
581
- DETAIL_CSV_FILE_NAME = "accuracy_checking_details"
582
-
583
- # result_csv
584
- RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success"
585
- RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success"
586
- RESULT_CSV_FILE_NAME = "accuracy_checking_result"
587
-
588
- EPSILON = 1e-8
589
-
590
- class ProcessStatus:
591
- SUCCESS = "success"
592
- API_NOT_FOUND = "api_not_found"
593
- EXCEPTION_SKIP = "exception_skip"
594
-
595
-
596
711
  class MsgConst:
597
712
  """
598
713
  Class for log messages const
@@ -629,7 +744,16 @@ class MonitorConst:
629
744
  """
630
745
  Class for monitor const
631
746
  """
632
- OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean"]
747
+
748
+ # monitor config set default values
749
+ DEFAULT_GRAD_ACC_STEPS = 1
750
+ DEFAULT_START_ITERATION = 0
751
+ DEFAULT_START_STEP = 0
752
+ DEFAULT_MAX_COLLECT_TIMES = 1e8
753
+ DEFAULT_MIN_COLLECT_TIMES = 0
754
+ DEFAULT_STEP_INTERVAL = 1
755
+
756
+ OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean", "shape", "dtype"]
633
757
  MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR"
634
758
  DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output"
635
759
  DATABASE = "database"
@@ -641,7 +765,7 @@ class MonitorConst:
641
765
  "DeepSpeedZeroOptimizer_Stage3"
642
766
  )
643
767
  DEEPSPEED_ZERO_OPT_FILTER = "DeepSpeedZeroOptimizer"
644
- RULE_NAME = ['AnomalyTurbulence']
768
+ RULE_NAME = ['AnomalyTurbulence', 'AnomalyNan']
645
769
 
646
770
  SLICE_SIZE = 20480
647
771
  # used for name
@@ -658,15 +782,16 @@ class MonitorConst:
658
782
  ACTVGRAD = "actv_grad"
659
783
  POST_GRAD = "post_grad"
660
784
  PRE_GRAD = "pre_grad"
785
+ PRE_PARAM = "param_origin"
786
+ POST_PARAM = "param_updated"
661
787
  ACC_GRAD = "acc_grad"
662
788
  PREFIX_POST = "post"
663
789
  PREFIX_PRE = "pre"
664
790
  EXP_AVG = "exp_avg"
665
791
  EXP_AVG_SQ = "exp_avg_sq"
666
- PARAM = "param"
667
792
 
668
793
  CSV_HEADER = ["vpp_stage", "name", "step"]
669
- CSV_HEADER_XY = ["vpp_stage", "name", "step", "micro_step"]
794
+ CSV_HEADER_MICRO_STEP = ["vpp_stage", "name", "step", "micro_step"]
670
795
  OUTPUT_DIR_PATTERN = r"([\w-]{0,20})-rank(\d{1,5})-"
671
796
  ANOMALY_JSON = "anomaly.json"
672
797
  ANALYSE_JSON = "anomaly_analyse.json"
@@ -674,3 +799,20 @@ class MonitorConst:
674
799
  CSV = "csv"
675
800
  API = "api"
676
801
  HEADER_NAME = 'name'
802
+ MAX_NDIGITS = 20
803
+
804
+ DEFAULT_STAGE = -1
805
+ FORWARD_STAGE = 0
806
+ BACKWARD_STAGE = 1
807
+ OPTIMIZER_STAGE = 2
808
+ FORWARD_KEY = [ACTV]
809
+ BACKWARD_KEY = [ACTVGRAD, PRE_GRAD, POST_GRAD, ACC_GRAD]
810
+ OPTIMIZER_KEY = [EXP_AVG, EXP_AVG_SQ]
811
+
812
+ TRAIN_STAGE = {}
813
+ for key in FORWARD_KEY:
814
+ TRAIN_STAGE[key] = FORWARD_STAGE
815
+ for key in BACKWARD_KEY:
816
+ TRAIN_STAGE[key] = BACKWARD_STAGE
817
+ for key in OPTIMIZER_KEY:
818
+ TRAIN_STAGE[key] = OPTIMIZER_STAGE