mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
  3. msprobe/README.md +27 -22
  4. msprobe/core/common/const.py +129 -60
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/inplace_ops.yaml +1 -0
  9. msprobe/core/common/utils.py +43 -33
  10. msprobe/core/compare/acc_compare.py +43 -74
  11. msprobe/core/compare/check.py +2 -6
  12. msprobe/core/compare/highlight.py +2 -0
  13. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  14. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  15. msprobe/core/compare/merge_result/merge_result.py +16 -9
  16. msprobe/core/compare/merge_result/utils.py +81 -0
  17. msprobe/core/compare/multiprocessing_compute.py +19 -12
  18. msprobe/core/compare/npy_compare.py +30 -12
  19. msprobe/core/compare/utils.py +30 -10
  20. msprobe/core/data_dump/api_registry.py +176 -0
  21. msprobe/core/data_dump/data_collector.py +58 -13
  22. msprobe/core/data_dump/data_processor/base.py +94 -10
  23. msprobe/core/data_dump/data_processor/factory.py +3 -0
  24. msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
  25. msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
  26. msprobe/core/data_dump/json_writer.py +61 -40
  27. msprobe/core/grad_probe/constant.py +1 -0
  28. msprobe/core/grad_probe/grad_compare.py +1 -1
  29. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  30. msprobe/docs/01.installation.md +27 -1
  31. msprobe/docs/02.config_introduction.md +27 -23
  32. msprobe/docs/03.config_examples.md +24 -0
  33. msprobe/docs/05.data_dump_PyTorch.md +103 -16
  34. msprobe/docs/06.data_dump_MindSpore.md +76 -32
  35. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  36. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  37. msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
  38. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  39. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  40. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  41. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  42. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  43. msprobe/docs/18.online_dispatch.md +1 -1
  44. msprobe/docs/19.monitor.md +332 -273
  45. msprobe/docs/21.visualization_PyTorch.md +42 -13
  46. msprobe/docs/22.visualization_MindSpore.md +43 -13
  47. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  48. msprobe/docs/27.dump_json_instruction.md +301 -27
  49. msprobe/docs/28.debugger_save_instruction.md +94 -0
  50. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  51. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  52. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  53. msprobe/docs/FAQ.md +3 -11
  54. msprobe/docs/img/compare_result.png +0 -0
  55. msprobe/docs/img/merge_result.png +0 -0
  56. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  57. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  58. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  59. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  60. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  61. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  63. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  64. msprobe/mindspore/__init__.py +4 -2
  65. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
  66. msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
  67. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  68. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  69. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  70. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  71. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  72. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  73. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
  74. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  75. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  76. msprobe/mindspore/common/const.py +61 -0
  77. msprobe/mindspore/common/utils.py +48 -18
  78. msprobe/mindspore/compare/ms_compare.py +27 -19
  79. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  80. msprobe/mindspore/debugger/debugger_config.py +31 -6
  81. msprobe/mindspore/debugger/precision_debugger.py +45 -14
  82. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  83. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  86. msprobe/mindspore/dump/jit_dump.py +21 -15
  87. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  88. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  89. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  90. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  91. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  92. msprobe/mindspore/grad_probe/global_context.py +2 -0
  93. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  94. msprobe/mindspore/grad_probe/hook.py +2 -4
  95. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  96. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  97. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  98. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  99. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  100. msprobe/mindspore/monitor/features.py +63 -0
  101. msprobe/mindspore/monitor/module_hook.py +873 -0
  102. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  103. msprobe/mindspore/monitor/utils.py +309 -0
  104. msprobe/mindspore/ms_config.py +8 -2
  105. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  106. msprobe/mindspore/service.py +114 -34
  107. msprobe/pytorch/__init__.py +0 -1
  108. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  109. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
  110. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  111. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  112. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  116. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  117. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  118. msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
  119. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
  120. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  121. msprobe/pytorch/common/utils.py +97 -4
  122. msprobe/pytorch/debugger/debugger_config.py +19 -9
  123. msprobe/pytorch/debugger/precision_debugger.py +24 -1
  124. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  125. msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
  126. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  127. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  132. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  133. msprobe/pytorch/function_factory.py +8 -2
  134. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  135. msprobe/pytorch/hook_module/api_register.py +131 -0
  136. msprobe/pytorch/hook_module/hook_module.py +19 -14
  137. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  138. msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
  139. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  140. msprobe/pytorch/monitor/csv2tb.py +18 -14
  141. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  142. msprobe/pytorch/monitor/module_hook.py +238 -193
  143. msprobe/pytorch/monitor/module_metric.py +9 -6
  144. msprobe/pytorch/monitor/optimizer_collect.py +100 -67
  145. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  146. msprobe/pytorch/monitor/utils.py +76 -44
  147. msprobe/pytorch/online_dispatch/compare.py +0 -2
  148. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  149. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  150. msprobe/pytorch/online_dispatch/utils.py +3 -0
  151. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  152. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  153. msprobe/pytorch/pt_config.py +30 -29
  154. msprobe/pytorch/service.py +114 -32
  155. msprobe/visualization/builder/graph_builder.py +75 -10
  156. msprobe/visualization/builder/msprobe_adapter.py +7 -6
  157. msprobe/visualization/compare/graph_comparator.py +42 -38
  158. msprobe/visualization/compare/mode_adapter.py +0 -19
  159. msprobe/visualization/graph/base_node.py +11 -3
  160. msprobe/visualization/graph/distributed_analyzer.py +71 -3
  161. msprobe/visualization/graph/graph.py +0 -11
  162. msprobe/visualization/graph/node_op.py +4 -3
  163. msprobe/visualization/graph_service.py +4 -5
  164. msprobe/visualization/utils.py +12 -35
  165. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
  166. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  167. msprobe/pytorch/hook_module/api_registry.py +0 -166
  168. msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
  169. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  171. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  172. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  173. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  174. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  175. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  176. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  177. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/README.md CHANGED
@@ -15,7 +15,7 @@ debugger = PrecisionDebugger(config_path='./config.json')
15
15
  ...
16
16
  debugger.start() # 一般在训练循环开头启动工具
17
17
  ... # 循环体
18
- debugger.stop() # 一般在训练循环末尾结束工具
18
+ debugger.stop() # 一般在训练循环末尾结束工具。必须调用,否则可能导致精度数据落盘不全
19
19
  debugger.step() # 在训练循环的最后需要重置工具,非循环场景不需要
20
20
  ```
21
21
 
@@ -44,6 +44,7 @@ export MSPROBE_LOG_LEVEL={x}
44
44
 
45
45
  - msprobe支持AscendPyTorch 1.11.0或更高版本,支持的PyTorch和CANN以及PyTorch和python软件版本配套关系请参见《[Ascend Extension for PyTorch插件](https://gitee.com/ascend/pytorch)》。
46
46
  - msprobe支持MindSpore 2.4.0或更高版本,支持的MindSpore和CANN以及MindSpore和python软件版本配套关系请参见《[MindSpore版本发布列表](https://www.mindspore.cn/versions)》。
47
+ - msprobe支持MSAdapter 2.1.0。
47
48
  - msprobe支持的固件驱动版本与配套CANN软件支持的固件驱动版本相同,开发者可通过“[昇腾社区-固件与驱动](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fhardware%2Ffirmware-drivers%2Fcommunity%3Fproduct%3D2%26model%3D28%26cann%3D8.0.RC3.alpha003%26driver%3D1.0.25.alpha)”页面根据产品型号与CANN软件版本获取配套的固件与驱动。
48
49
 
49
50
 
@@ -69,35 +70,37 @@ export MSPROBE_LOG_LEVEL={x}
69
70
 
70
71
  ### 1 数据采集
71
72
 
72
- msprobe 通过在训练脚本中添加 PrecisionDebugger 接口的方式对 API 执行精度数据 dump 操作,对应 config.json 中的 task 为 statistics 或 tensor。
73
+ msprobe 通过在训练脚本中添加 PrecisionDebugger 接口的方式对 API 执行精度数据 dump 操作。对应 config.json 中的 "statistics" 或 "tensor" task。
73
74
 
74
75
  [PyTorch 场景的数据采集](./docs/05.data_dump_PyTorch.md)
75
76
 
76
77
  [MindSpore 场景的数据采集](./docs/06.data_dump_MindSpore.md)
77
78
 
79
+ [MSAdapter 场景的数据采集](./docs/29.data_dump_MSAdapter.md)
80
+
78
81
  ### 2 精度预检
79
82
 
80
- 精度预检旨在昇腾 NPU 上扫描训练模型中的所有 API 进行 API 复现,给出精度情况的诊断和分析。对应 config.json 中的 task 为 run_ut
83
+ 精度预检旨在昇腾 NPU 上扫描训练模型中的所有 API 进行 API 复现,给出精度情况的诊断和分析。对应 config.json 中的 "run_ut" task。
81
84
 
82
85
  PyTorch 场景的[离线预检](./docs/07.accuracy_checker_PyTorch.md)和[在线预检](./docs/08.accuracy_checker_online_PyTorch.md)
83
86
 
84
87
  MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.md)
85
88
 
86
- ### 3 精度比对
89
+ ### 3 分级可视化构图比对
87
90
 
88
- 该功能进行 PyTorch 整网 API 粒度的数据 dump、精度比对,进而定位训练场景下的精度问题。
91
+ 该功能将msprobe工具dump的精度数据进行解析,还原模型图结构,实现模型各个层级的精度数据比对,方便用户理解模型结构、分析精度问题。
89
92
 
90
- [PyTorch 场景的精度比对](./docs/10.accuracy_compare_PyTorch.md)
93
+ [PyTorch 场景的分级可视化构图比对](./docs/21.visualization_PyTorch.md)
91
94
 
92
- [MindSpore 场景的精度比对](./docs/11.accuracy_compare_MindSpore.md)
95
+ [MindSpore 场景的分级可视化构图比对](./docs/22.visualization_MindSpore.md)
93
96
 
94
- ### 4 溢出检测与解析
97
+ ### 4 精度比对
95
98
 
96
- 溢出检测与解析是在执行精度数据 dump 时,判断是否存在输入正常但输出存在溢出的 API,从而判断是否为正常溢出。对应 config.json 中的 overflow_check。
99
+ 该功能进行 PyTorch 整网 API 粒度的数据 dump、精度比对,进而定位训练场景下的精度问题。
97
100
 
98
- [PyTorch 场景的溢出检测与解析](./docs/12.overflow_check_PyTorch.md)
101
+ [PyTorch 场景的精度比对](./docs/10.accuracy_compare_PyTorch.md)
99
102
 
100
- [MindSpore 场景的溢出检测与解析](./docs/13.overflow_check_MindSpore.md)
103
+ [MindSpore 场景的精度比对](./docs/11.accuracy_compare_MindSpore.md)
101
104
 
102
105
  ### 5 数据解析
103
106
 
@@ -127,28 +130,30 @@ MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.
127
130
 
128
131
  该功能收集和聚合模型训练过程中的网络层,优化器, 通信算子的中间值,帮助诊断模型训练过程中计算, 通信,优化器各部分出现的异常情况。
129
132
 
130
- [PyTorch 场景的训练状态监控](./docs/19.monitor.md)
133
+ [兼容 PyTorch 和 MindSpore 框架的训练状态监控](./docs/19.monitor.md)
131
134
 
132
- ### 10 分级可视化构图比对
135
+ ### 10 单算子API自动生成脚本
133
136
 
134
- 该功能将msprobe工具dump的精度数据进行解析,还原模型图结构,实现模型各个层级的精度数据比对,方便用户理解模型结构、分析精度问题。
137
+ 该功能将msprobe工具dump的精度数据进行解析,自动生成单API脚本,用于复现整网中出现的算子问题,降低用户复现问题的成本,供开发分析算子问题。
135
138
 
136
- [PyTorch 场景的分级可视化构图比对](./docs/21.visualization_PyTorch.md)
139
+ [PyTorch 单算子API自动生成脚本](./docs/23.generate_operator_PyTorch.md)
137
140
 
138
- [MindSpore 场景的分级可视化构图比对](./docs/22.visualization_MindSpore.md)
141
+ ### 11 数码关联
139
142
 
143
+ 该功能只支持 MindSpore 静态图场景,用于将IR图与dump数据进行关联,获取dump数据和代码调用栈的关联关系。
140
144
 
141
- ### 11 单算子API自动生成脚本
145
+ [MindSpore 场景的数码关联](./docs/24.code_mapping_Mindspore.md)
142
146
 
143
- 该功能将msprobe工具dump的精度数据进行解析,自动生成单API脚本,用于复现整网中出现的算子问题,降低用户复现问题的成本,供开发分析算子问题。
147
+ ### 12 溢出检测与解析
144
148
 
145
- [PyTorch 单算子API自动生成脚本](./docs/23.generate_operator_PyTorch.md)
149
+ 溢出检测用于采集溢出 API 或 模块的精度数据,而溢出解析则是通过对溢出数据的分析,进一步判断是否为正常溢出。对应 config.json 中的 "overflow_check" task。
150
+ 推荐直接使用[数据采集](#1-数据采集)功能采集统计量信息,检测溢出问题。
146
151
 
147
- ### 12 数码关联
152
+ [PyTorch 场景的溢出检测与解析](./docs/12.overflow_check_PyTorch.md)
148
153
 
149
- 该功能只支持 MindSpore 静态图场景,用于将IR图与dump数据进行关联,获取dump数据和代码调用栈的关联关系。
154
+ [MindSpore 场景的溢出检测](./docs/13.overflow_check_MindSpore.md)
150
155
 
151
- [MindSpore 场景的数码关联](./docs/24.code_mapping_Mindspore.md)
156
+ [MSAdapter 场景的溢出检测](./docs/30.overflow_check_MSAdapter.md)
152
157
 
153
158
  ## 📑 补充材料
154
159
 
@@ -51,7 +51,10 @@ class Const:
51
51
  FOUR_SEGMENT = 4
52
52
  SIX_SEGMENT = 6
53
53
  SEVEN_SEGMENT = 7
54
+
54
55
  MAX_DEPTH = 10
56
+ CPU_QUARTER = 4
57
+ DUMP_MAX_DEPTH = 50
55
58
 
56
59
  # dump mode
57
60
  ALL = "all"
@@ -103,14 +106,16 @@ class Const:
103
106
  FREE_BENCHMARK = "free_benchmark"
104
107
  RUN_UT = "run_ut"
105
108
  GRAD_PROBE = "grad_probe"
106
- TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE]
107
- DUMP_DATA_COLLECTION_LIST = [STATISTICS, TENSOR]
109
+ STRUCTURE = "structure"
110
+ TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE, STRUCTURE]
111
+ DUMP_DATA_COLLECTION_LIST = [STATISTICS, TENSOR, STRUCTURE]
108
112
  DUMP_DATA_MODE_LIST = [ALL, INPUT, OUTPUT, FORWARD, BACKWARD]
109
113
  LEVEL_L0 = "L0"
110
114
  LEVEL_L1 = "L1"
111
115
  LEVEL_L2 = "L2"
112
116
  LEVEL_MIX = "mix"
113
- LEVEL_LIST = [LEVEL_L0, LEVEL_L1, LEVEL_L2, LEVEL_MIX]
117
+ LEVEL_DEBUG = "debug"
118
+ LEVEL_LIST = [LEVEL_L0, LEVEL_L1, LEVEL_L2, LEVEL_MIX, LEVEL_DEBUG]
114
119
  ATTR_NAME_PREFIX = "wrap_"
115
120
  ATTR_NAME_PREFIX_LEN = len(ATTR_NAME_PREFIX)
116
121
  KERNEL_DUMP = "kernel_dump"
@@ -228,6 +233,92 @@ class Const:
228
233
 
229
234
  TENSOR_STAT_LEN = 2
230
235
 
236
+ SUPPORT_API_FILE_NAME = "support_wrap_ops.yaml"
237
+
238
+ PT_API_TYPE_FUNCTIONAL = "functional"
239
+ PT_API_TYPE_TENSOR = "tensor"
240
+ PT_API_TYPE_TORCH = "torch"
241
+ PT_API_TYPE_VF = "_VF"
242
+ PT_API_TYPE_NPU = "torch_npu"
243
+ PT_API_TYPE_ATEN = "aten"
244
+ PT_API_TYPE_DIST = "distributed"
245
+ PT_API_TYPE_NPU_DIST = "npu_distributed"
246
+
247
+ MS_API_TYPE_OPS = "ops"
248
+ MS_API_TYPE_TENSOR = "tensor"
249
+ MS_API_TYPE_STUB_TENSOR = "stubtensor"
250
+ MS_API_TYPE_MINT = "mint.ops"
251
+ MS_API_TYPE_MINT_FUNC = "mint.nn.functional"
252
+ MS_API_TYPE_COM = "communication.comm_func"
253
+
254
+ FUNCTIONAL_API_TYPE_PREFIX = "Functional"
255
+ TENSOR_API_TYPE_PREFIX = "Tensor"
256
+ DIST_API_TYPE_PREFIX = "Distributed"
257
+
258
+ TORCH_API_TYPE_PREFIX = "Torch"
259
+ NPU_API_TYPE_PREFIX = "NPU"
260
+ ATEN_API_TYPE_PREFIX = "Aten"
261
+ VF_API_TYPE_PREFIX = "VF"
262
+
263
+ MINT_API_TYPE_PREFIX = "Mint"
264
+ MINT_FUNC_API_TYPE_PREFIX = "MintFunctional"
265
+
266
+ SUPPORT_API_DICT_KEY_MAP = {
267
+ PT_FRAMEWORK: {
268
+ PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL,
269
+ PT_API_TYPE_TENSOR: PT_API_TYPE_TENSOR,
270
+ PT_API_TYPE_TORCH: PT_API_TYPE_TORCH,
271
+ PT_API_TYPE_VF: PT_API_TYPE_VF,
272
+ PT_API_TYPE_NPU: PT_API_TYPE_NPU,
273
+ PT_API_TYPE_ATEN: PT_API_TYPE_ATEN,
274
+ PT_API_TYPE_DIST: PT_API_TYPE_DIST,
275
+ PT_API_TYPE_NPU_DIST: PT_API_TYPE_NPU_DIST
276
+ },
277
+ MS_FRAMEWORK: {
278
+ MS_API_TYPE_OPS: MS_API_TYPE_OPS,
279
+ MS_API_TYPE_TENSOR: MS_API_TYPE_TENSOR,
280
+ MS_API_TYPE_STUB_TENSOR: MS_API_TYPE_TENSOR,
281
+ MS_API_TYPE_MINT: MS_API_TYPE_MINT,
282
+ MS_API_TYPE_MINT_FUNC: MS_API_TYPE_MINT_FUNC,
283
+ MS_API_TYPE_COM: MS_API_TYPE_COM
284
+ },
285
+ MT_FRAMEWORK: {
286
+ PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL,
287
+ PT_API_TYPE_TENSOR: PT_API_TYPE_TENSOR,
288
+ PT_API_TYPE_TORCH: PT_API_TYPE_TORCH,
289
+ PT_API_TYPE_NPU: PT_API_TYPE_NPU,
290
+ PT_API_TYPE_DIST: PT_API_TYPE_DIST
291
+ }
292
+ }
293
+
294
+ API_DATA_PREFIX = {
295
+ PT_FRAMEWORK: {
296
+ PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX,
297
+ PT_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX,
298
+ PT_API_TYPE_TORCH: TORCH_API_TYPE_PREFIX,
299
+ PT_API_TYPE_VF: VF_API_TYPE_PREFIX,
300
+ PT_API_TYPE_NPU: NPU_API_TYPE_PREFIX,
301
+ PT_API_TYPE_ATEN: ATEN_API_TYPE_PREFIX,
302
+ PT_API_TYPE_DIST: DIST_API_TYPE_PREFIX,
303
+ PT_API_TYPE_NPU_DIST: DIST_API_TYPE_PREFIX
304
+ },
305
+ MS_FRAMEWORK: {
306
+ MS_API_TYPE_OPS: FUNCTIONAL_API_TYPE_PREFIX,
307
+ MS_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX,
308
+ MS_API_TYPE_STUB_TENSOR: TENSOR_API_TYPE_PREFIX,
309
+ MS_API_TYPE_MINT: MINT_API_TYPE_PREFIX,
310
+ MS_API_TYPE_MINT_FUNC: MINT_FUNC_API_TYPE_PREFIX,
311
+ MS_API_TYPE_COM: DIST_API_TYPE_PREFIX
312
+ },
313
+ MT_FRAMEWORK: {
314
+ PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX,
315
+ PT_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX,
316
+ PT_API_TYPE_TORCH: TORCH_API_TYPE_PREFIX,
317
+ PT_API_TYPE_NPU: NPU_API_TYPE_PREFIX,
318
+ PT_API_TYPE_DIST: DIST_API_TYPE_PREFIX
319
+ }
320
+ }
321
+
231
322
 
232
323
  class CompareConst:
233
324
  """
@@ -254,6 +345,7 @@ class CompareConst:
254
345
  MEAN_DIFF = "Mean diff"
255
346
  NORM_DIFF = "L2norm diff"
256
347
  COSINE = "Cosine"
348
+ EUC_DIST = "EucDist"
257
349
  MAX_ABS_ERR = "MaxAbsErr"
258
350
  MAX_RELATIVE_ERR = "MaxRelativeErr"
259
351
  MIN_RELATIVE_ERR = "MinRelativeErr"
@@ -328,8 +420,8 @@ class CompareConst:
328
420
  ULP_ERR_STATUS = "ulp_err_status"
329
421
 
330
422
  COMPARE_RESULT_HEADER = [
331
- NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR,
332
- ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO,
423
+ NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, EUC_DIST,
424
+ MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO,
333
425
  NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, ACCURACY, ERROR_MESSAGE
334
426
  ]
335
427
 
@@ -355,18 +447,16 @@ class CompareConst:
355
447
  Const.MD5: MD5_COMPARE_RESULT_HEADER
356
448
  }
357
449
 
358
- ALL_COMPARE_INDEX = [COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO]
450
+ ALL_COMPARE_INDEX = [COSINE, EUC_DIST, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO,
451
+ FIVE_THOUSANDTHS_ERR_RATIO]
359
452
  SUMMARY_COMPARE_INDEX = [MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF,
360
453
  MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR]
361
454
 
362
455
  # dtype match
363
- MS_TYPE = [
364
- [Const.FLOAT16, Const.FLOAT32], [Const.FLOAT32, Const.FLOAT16],
365
- [Const.FLOAT16, Const.BFLOAT16], [Const.BFLOAT16, Const.FLOAT16]
366
- ]
367
- TORCH_TYPE = [
368
- [Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16],
369
- [Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16]
456
+
457
+ DTYPE_MATCH_GROUPS = [
458
+ {Const.FLOAT16, Const.FLOAT32, Const.BFLOAT16},
459
+ {Const.TORCH_FLOAT16, Const.TORCH_FLOAT32, Const.TORCH_BFLOAT16}
370
460
  ]
371
461
 
372
462
  # read_op
@@ -465,7 +555,7 @@ class CompareConst:
465
555
  BENCH_MEAN: None, BENCH_NORM: None, ACCURACY: '', ERROR_MESSAGE: ''
466
556
  }
467
557
  MS_GRAPH_NPY = {
468
- COSINE: None, MAX_ABS_ERR: None, MAX_RELATIVE_ERR: None, ONE_THOUSANDTH_ERR_RATIO: None,
558
+ COSINE: None, EUC_DIST: None, MAX_ABS_ERR: None, MAX_RELATIVE_ERR: None, ONE_THOUSANDTH_ERR_RATIO: None,
469
559
  FIVE_THOUSANDTHS_ERR_RATIO: None
470
560
  }
471
561
  MS_GRAPH_STATISTIC = {
@@ -536,46 +626,6 @@ class OverflowConst:
536
626
  OVERFLOW_DEBUG_MODE = 1
537
627
 
538
628
 
539
- class MsCompareConst:
540
- # api_info field
541
- MINT = "Mint"
542
- MINT_FUNCTIONAL = "MintFunctional"
543
- TENSOR_API = "Tensor"
544
-
545
- API_NAME_STR_LENGTH = 4
546
-
547
- TASK_FIELD = "task"
548
- STATISTICS_TASK = "statistics"
549
- TENSOR_TASK = "tensor"
550
- DUMP_DATA_DIR_FIELD = "dump_data_dir"
551
- DATA_FIELD = "data"
552
-
553
- # supported api yaml
554
- SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
555
- SUPPORTED_TENSOR_LIST_KEY = "tensor"
556
-
557
- # detail_csv
558
- DETAIL_CSV_API_NAME = "API Name"
559
- DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
560
- DETAIL_CSV_TESTED_DTYPE = "Tested Dtype"
561
- DETAIL_CSV_SHAPE = "Shape"
562
- DETAIL_CSV_PASS_STATUS = "Status"
563
- DETAIL_CSV_MESSAGE = "Message"
564
- DETAIL_CSV_FILE_NAME = "accuracy_checking_details"
565
-
566
- # result_csv
567
- RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success"
568
- RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success"
569
- RESULT_CSV_FILE_NAME = "accuracy_checking_result"
570
-
571
- EPSILON = 1e-8
572
-
573
- class ProcessStatus:
574
- SUCCESS = "success"
575
- API_NOT_FOUND = "api_not_found"
576
- EXCEPTION_SKIP = "exception_skip"
577
-
578
-
579
629
  class MsgConst:
580
630
  """
581
631
  Class for log messages const
@@ -612,6 +662,15 @@ class MonitorConst:
612
662
  """
613
663
  Class for monitor const
614
664
  """
665
+
666
+ # monitor config set default values
667
+ DEFAULT_GRAD_ACC_STEPS = 1
668
+ DEFAULT_START_ITERATION = 0
669
+ DEFAULT_START_STEP = 0
670
+ DEFAULT_MAX_COLLECT_TIMES = 1e8
671
+ DEFAULT_MIN_COLLECT_TIMES = 0
672
+ DEFAULT_STEP_INTERVAL = 1
673
+
615
674
  OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean"]
616
675
  MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR"
617
676
  DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output"
@@ -623,29 +682,39 @@ class MonitorConst:
623
682
  "DeepSpeedZeroOptimizer_Stage1_or_2",
624
683
  "DeepSpeedZeroOptimizer_Stage3"
625
684
  )
685
+ DEEPSPEED_ZERO_OPT_FILTER = "DeepSpeedZeroOptimizer"
626
686
  RULE_NAME = ['AnomalyTurbulence']
627
687
 
628
688
  SLICE_SIZE = 20480
689
+ # used for name
629
690
  DOT = "."
630
- VPP_SEP = ":"
691
+ NAME_SEP = ":"
692
+ INPUT_GRAD = "input_grad"
693
+ OUTPUT_GRAD = "output_grad"
631
694
  ACTV_IN = "input"
632
695
  ACTV_OUT = "output"
633
696
  ACTVGRAD_IN = "input_grad"
634
697
  ACTVGRAD_OUT = "output_grad"
698
+ # used for tasks
699
+ ACTV = "actv"
700
+ ACTVGRAD = "actv_grad"
635
701
  POST_GRAD = "post_grad"
636
702
  PRE_GRAD = "pre_grad"
637
703
  ACC_GRAD = "acc_grad"
638
704
  PREFIX_POST = "post"
639
705
  PREFIX_PRE = "pre"
640
- OUTPUT_DIR_PATTERN = r"([\w-]{0,20})-rank(\d{1,5})-"
641
-
642
706
  EXP_AVG = "exp_avg"
643
- EFXP_AVG_SQ = "efxp_avg_sq"
707
+ EXP_AVG_SQ = "exp_avg_sq"
708
+ PARAM = "param"
644
709
 
710
+ CSV_HEADER = ["vpp_stage", "name", "step"]
711
+ CSV_HEADER_XY = ["vpp_stage", "name", "step", "micro_step"]
712
+ OUTPUT_DIR_PATTERN = r"([\w-]{0,20})-rank(\d{1,5})-"
645
713
  ANOMALY_JSON = "anomaly.json"
646
714
  ANALYSE_JSON = "anomaly_analyse.json"
647
715
  TENSORBOARD = "tensorboard"
648
716
  CSV = "csv"
649
717
  API = "api"
650
- OPS_START_INDEX = 3
651
- HEADER_NAME_INDEX = 1
718
+ HEADER_NAME = 'name'
719
+
720
+ MAX_NDIGITS = 20
@@ -0,0 +1,50 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import defaultdict
17
+ from functools import wraps
18
+
19
+ from msprobe.core.common.const import Const
20
+ from msprobe.core.common.exceptions import MsprobeException
21
+ from msprobe.core.common.log import logger
22
+
23
+ # 记录工具函数递归的深度
24
+ recursion_depth = defaultdict(int)
25
+
26
+
27
def recursion_depth_decorator(func_info, max_depth=Const.MAX_DEPTH):
    """Guard a function against runaway recursion.

    Each call of the decorated function increments a per-function counter in the
    module-level ``recursion_depth`` dict (keyed by ``id(func)``). The counter is
    decremented when that call finishes, whether it returns normally or raises.
    When the counter exceeds ``max_depth``, the error is reported through
    ``logger.error_log_with_exp`` together with a ``MsprobeException`` of code
    ``RECURSION_LIMIT_ERROR``; that call is intended to raise the exception.

    :param func_info: human-readable identifier of the decorated function,
        embedded in the error message.
    :param max_depth: maximum allowed nesting depth of calls to the decorated
        function (defaults to ``Const.MAX_DEPTH``).
    :return: a decorator that wraps the target function.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            func_id = id(func)
            recursion_depth[func_id] += 1
            # The limit check sits inside ``try`` so the increment above is rolled
            # back even when the check raises. Otherwise every limit violation
            # would leave the counter permanently one too high and silently
            # shrink the usable depth of all subsequent calls.
            try:
                if recursion_depth[func_id] > max_depth:
                    msg = f"call {func_info} exceeds the recursion limit."
                    logger.error_log_with_exp(
                        msg,
                        MsprobeException(
                            MsprobeException.RECURSION_LIMIT_ERROR, msg
                        ),
                    )
                return func(*args, **kwargs)
            finally:
                recursion_depth[func_id] -= 1

        return wrapper

    return decorator
@@ -28,12 +28,14 @@ class MsprobeException(CodedException):
28
28
  OVERFLOW_NUMS_ERROR = 1
29
29
  RECURSION_LIMIT_ERROR = 2
30
30
  INTERFACE_USAGE_ERROR = 3
31
+ UNSUPPORTED_TYPE_ERROR = 4
31
32
 
32
33
  err_strs = {
33
34
  INVALID_PARAM_ERROR: "[msprobe] 无效参数:",
34
35
  OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:",
35
36
  RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:",
36
- INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: "
37
+ INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: ",
38
+ UNSUPPORTED_TYPE_ERROR: "[msprobe] Unsupported type: "
37
39
  }
38
40
 
39
41
 
@@ -26,6 +26,7 @@ import yaml
26
26
  import numpy as np
27
27
  import pandas as pd
28
28
 
29
+ from msprobe.core.common.decorator import recursion_depth_decorator
29
30
  from msprobe.core.common.log import logger
30
31
  from msprobe.core.common.exceptions import FileCheckException
31
32
  from msprobe.core.common.const import FileCheckConst
@@ -266,6 +267,7 @@ def make_dir(dir_path):
266
267
  file_check.common_check()
267
268
 
268
269
 
270
+ @recursion_depth_decorator('msprobe.core.common.file_utils.create_directory', max_depth=16)
269
271
  def create_directory(dir_path):
270
272
  """
271
273
  Function Description:
@@ -332,6 +334,23 @@ def change_mode(path, mode):
332
334
  'Failed to change {} authority. {}'.format(path, str(ex))) from ex
333
335
 
334
336
 
337
@recursion_depth_decorator('msprobe.core.common.file_utils.recursive_chmod')
def recursive_chmod(path):
    """
    Recursively change permissions of everything below ``path``.

    Files are set to ``FileCheckConst.DATA_FILE_AUTHORITY`` (640) and directories
    to ``FileCheckConst.DATA_DIR_AUTHORITY`` (750). The top-level ``path`` itself
    is left unchanged.

    :param path: root directory whose contents should be re-permissioned.
    """
    # os.walk already descends into every sub-directory (top-down, without
    # following symlinked directories), so no explicit recursion is needed.
    # Each entry must be joined with the directory it was found in (``root``),
    # not the top-level ``path``, or nested entries resolve to wrong paths.
    for root, dirs, files in os.walk(path):
        for file_name in files:
            change_mode(os.path.join(root, file_name), FileCheckConst.DATA_FILE_AUTHORITY)
        for dir_name in dirs:
            change_mode(os.path.join(root, dir_name), FileCheckConst.DATA_DIR_AUTHORITY)
352
+
353
+
335
354
  def path_len_exceeds_limit(file_path):
336
355
  return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \
337
356
  len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH
@@ -632,7 +651,7 @@ def os_walk_for_files(path, depth):
632
651
  return res
633
652
 
634
653
 
635
- def check_crt_valid(pem_path):
654
+ def check_crt_valid(pem_path, is_public_key=False):
636
655
  """
637
656
  Check the validity of the SSL certificate.
638
657
 
@@ -641,6 +660,7 @@ def check_crt_valid(pem_path):
641
660
 
642
661
  Parameters:
643
662
  pem_path (str): The file path of the SSL certificate.
663
+ is_public_key (bool): The file is public key or not.
644
664
 
645
665
  Raises:
646
666
  RuntimeError: If the SSL certificate is invalid or expired.
@@ -649,7 +669,10 @@ def check_crt_valid(pem_path):
649
669
  try:
650
670
  with FileOpen(pem_path, "r") as f:
651
671
  pem_data = f.read()
652
- cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem_data)
672
+ if is_public_key:
673
+ cert = OpenSSL.crypto.load_publickey(OpenSSL.crypto.FILETYPE_PEM, pem_data)
674
+ else:
675
+ cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem_data)
653
676
  pem_start = parser.parse(cert.get_notBefore().decode("UTF-8"))
654
677
  pem_end = parser.parse(cert.get_notAfter().decode("UTF-8"))
655
678
  logger.info(f"The SSL certificate passes the verification and the validity period "
@@ -250,5 +250,6 @@ inplace_distributed_op:
250
250
  - all_to_all
251
251
  - all_gather_into_tensor
252
252
  - reduce_scatter_tensor
253
+ - batch_isend_irecv
253
254
 
254
255
 
@@ -18,9 +18,7 @@ import os
18
18
  import re
19
19
  import subprocess
20
20
  import time
21
- from collections import defaultdict
22
21
  from datetime import datetime, timezone
23
- from functools import wraps
24
22
 
25
23
  import numpy as np
26
24
 
@@ -75,6 +73,7 @@ class MsprobeBaseException(Exception):
75
73
  MERGE_COMPARE_RESULT_ERROR = 33
76
74
  NAMES_STRUCTS_MATCH_ERROR = 34
77
75
  INVALID_STATE_ERROR = 35
76
+ INVALID_API_NAME_ERROR = 36
78
77
 
79
78
  def __init__(self, code, error_info: str = ""):
80
79
  super(MsprobeBaseException, self).__init__()
@@ -239,12 +238,18 @@ def md5_find(data):
239
238
  for data_detail in data[key_op][api_info]:
240
239
  if data_detail and 'md5' in data_detail:
241
240
  return True
241
+ if isinstance(data[key_op][api_info], bool):
242
+ continue
242
243
  elif data[key_op][api_info] and 'md5' in data[key_op][api_info]:
243
244
  return True
244
245
  return False
245
246
 
246
247
 
247
248
  def detect_framework_by_dump_json(file_path):
249
+ json_data = load_json(file_path)
250
+ framework = json_data.get("framework", None)
251
+ if framework in [Const.PT_FRAMEWORK, Const.MS_FRAMEWORK]:
252
+ return framework
248
253
  pattern_ms = r'"type":\s*"mindspore'
249
254
  pattern_pt = r'"type":\s*"torch'
250
255
  with FileOpen(file_path, 'r') as file:
@@ -277,7 +282,7 @@ def set_dump_path(input_param):
277
282
  npu_path_valid = npu_path is not None and npu_path.endswith("dump.json")
278
283
  bench_path_valid = bench_path is not None and bench_path.endswith("dump.json")
279
284
  if not npu_path_valid or not bench_path_valid:
280
- logger.error(f"Please check the json path is valid. npu_path: {npu_path}, bench_path: {bench_path}")
285
+ logger.error(f"Please check the json path is valid and ensure that neither npu_path nor bench_path is None.")
281
286
  raise CompareException(CompareException.INVALID_PATH_ERROR)
282
287
  input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
283
288
  input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
@@ -303,6 +308,9 @@ def get_dump_mode(input_param):
303
308
  if npu_task == Const.TENSOR:
304
309
  return Const.ALL
305
310
 
311
+ if npu_task == Const.STRUCTURE:
312
+ return Const.STRUCTURE
313
+
306
314
  if npu_task == Const.STATISTICS:
307
315
  npu_md5_compare = md5_find(npu_json_data['data'])
308
316
  bench_md5_compare = md5_find(bench_json_data['data'])
@@ -419,6 +427,15 @@ def get_real_step_or_rank(step_or_rank_input, obj):
419
427
  return real_step_or_rank
420
428
 
421
429
 
430
+ def check_init_step(step):
431
+ if not is_int(step):
432
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
433
+ f"{step} must be an integer")
434
+ if not step >= 0:
435
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
436
+ f"{step} must be greater than or equal to 0")
437
+
438
+
422
439
  def check_seed_all(seed, mode, rm_dropout):
423
440
  if is_int(seed):
424
441
  if seed < 0 or seed > Const.MAX_SEED_VALUE:
@@ -462,37 +479,30 @@ def safe_get_value(container, index, container_name, key=None):
462
479
  raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR) from e
463
480
 
464
481
 
465
- # 记录工具函数递归的深度
466
- recursion_depth = defaultdict(int)
467
-
468
-
469
- # 装饰一个函数,当函数递归调用超过限制时,抛出异常并打印函数信息。
470
- def recursion_depth_decorator(func_info):
471
- def decorator(func):
472
- @wraps(func)
473
- def wrapper(*args, **kwargs):
474
- func_id = id(func)
475
- recursion_depth[func_id] += 1
476
- if recursion_depth[func_id] > Const.MAX_DEPTH:
477
- msg = f"call {func_info} exceeds the recursion limit."
478
- logger.error_log_with_exp(
479
- msg,
480
- MsprobeException(
481
- MsprobeException.RECURSION_LIMIT_ERROR, msg
482
- ),
483
- )
484
- try:
485
- result = func(*args, **kwargs)
486
- finally:
487
- recursion_depth[func_id] -= 1
488
- return result
489
-
490
- return wrapper
491
-
492
- return decorator
493
-
494
-
495
482
  def check_str_param(param):
496
483
  if not re.match(Const.REGEX_PREFIX_PATTERN, param):
497
484
  logger.error('The parameter {} contains special characters.'.format(param))
498
485
  raise MsprobeBaseException(MsprobeBaseException.INVALID_CHAR_ERROR)
486
+
487
+
488
+ class DumpPathAggregation:
489
+ dump_file_path = None
490
+ stack_file_path = None
491
+ construct_file_path = None
492
+ dump_tensor_data_dir = None
493
+ free_benchmark_file_path = None
494
+ debug_file_path = None
495
+
496
+
497
+ def is_save_variable_valid(variable, valid_special_types, depth=0):
498
+ if depth > Const.DUMP_MAX_DEPTH:
499
+ return False
500
+ if isinstance(variable, valid_special_types):
501
+ return True
502
+ elif isinstance(variable, (list, tuple)):
503
+ return all(is_save_variable_valid(item, valid_special_types, depth + 1) for item in variable)
504
+ elif isinstance(variable, dict):
505
+ return all(isinstance(key, str) and is_save_variable_valid(value, valid_special_types, depth + 1)
506
+ for key, value in variable.items())
507
+ else:
508
+ return False