mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +14 -19
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +155 -6
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +3 -0
  10. msprobe/core/common/utils.py +28 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +380 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/multiprocessing_compute.py +2 -2
  22. msprobe/core/compare/npy_compare.py +109 -147
  23. msprobe/core/compare/utils.py +189 -69
  24. msprobe/core/data_dump/data_collector.py +51 -21
  25. msprobe/core/data_dump/data_processor/base.py +38 -20
  26. msprobe/core/data_dump/data_processor/factory.py +5 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
  29. msprobe/core/data_dump/json_writer.py +29 -1
  30. msprobe/core/data_dump/scope.py +19 -18
  31. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  32. msprobe/core/overflow_check/checker.py +1 -1
  33. msprobe/core/overflow_check/utils.py +1 -1
  34. msprobe/docs/01.installation.md +96 -17
  35. msprobe/docs/02.config_introduction.md +5 -5
  36. msprobe/docs/05.data_dump_PyTorch.md +91 -61
  37. msprobe/docs/06.data_dump_MindSpore.md +57 -19
  38. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  39. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
  40. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  41. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  42. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  43. msprobe/docs/19.monitor.md +120 -27
  44. msprobe/docs/21.visualization_PyTorch.md +115 -35
  45. msprobe/docs/22.visualization_MindSpore.md +138 -41
  46. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  47. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  48. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  49. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  50. msprobe/docs/27.dump_json_instruction.md +521 -0
  51. msprobe/docs/FAQ.md +26 -2
  52. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  53. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  54. msprobe/docs/img/merge_result.png +0 -0
  55. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  56. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  57. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  58. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  59. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  60. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  61. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  63. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  64. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  65. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  66. msprobe/docs/visualization/GPTModel.png +0 -0
  67. msprobe/docs/visualization/ParallelMLP.png +0 -0
  68. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  69. msprobe/docs/visualization/mapping.png +0 -0
  70. msprobe/docs/visualization/mapping1.png +0 -0
  71. msprobe/docs/visualization/module_name.png +0 -0
  72. msprobe/docs/visualization/module_name1.png +0 -0
  73. msprobe/docs/visualization/no_mapping.png +0 -0
  74. msprobe/docs/visualization/no_mapping1.png +0 -0
  75. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  76. msprobe/docs/visualization/top_layer.png +0 -0
  77. msprobe/mindspore/__init__.py +10 -0
  78. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
  79. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  80. msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
  81. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  82. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  83. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  84. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  85. msprobe/mindspore/code_mapping/bind.py +264 -0
  86. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  87. msprobe/mindspore/code_mapping/graph.py +49 -0
  88. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  89. msprobe/mindspore/code_mapping/main.py +24 -0
  90. msprobe/mindspore/code_mapping/processor.py +34 -0
  91. msprobe/mindspore/common/const.py +3 -1
  92. msprobe/mindspore/common/utils.py +50 -5
  93. msprobe/mindspore/compare/distributed_compare.py +0 -2
  94. msprobe/mindspore/compare/ms_compare.py +105 -63
  95. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  96. msprobe/mindspore/debugger/debugger_config.py +3 -0
  97. msprobe/mindspore/debugger/precision_debugger.py +81 -12
  98. msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
  99. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  100. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  101. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  102. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  103. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  104. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  105. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  106. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  107. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  108. msprobe/mindspore/grad_probe/hook.py +13 -4
  109. msprobe/mindspore/mindtorch/__init__.py +18 -0
  110. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  111. msprobe/mindspore/ms_config.py +5 -1
  112. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  113. msprobe/mindspore/service.py +267 -101
  114. msprobe/msprobe.py +24 -3
  115. msprobe/pytorch/__init__.py +7 -6
  116. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  117. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  123. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  124. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
  125. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  126. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  127. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  128. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  129. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  130. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  131. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  132. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  133. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  134. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  135. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  136. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  140. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  141. msprobe/pytorch/common/parse_json.py +2 -1
  142. msprobe/pytorch/common/utils.py +45 -2
  143. msprobe/pytorch/compare/distributed_compare.py +17 -29
  144. msprobe/pytorch/compare/pt_compare.py +40 -20
  145. msprobe/pytorch/debugger/debugger_config.py +27 -12
  146. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  147. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  148. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  149. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
  150. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  151. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  152. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  153. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  154. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  155. msprobe/pytorch/hook_module/__init__.py +1 -1
  156. msprobe/pytorch/hook_module/hook_module.py +14 -11
  157. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  158. msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
  159. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  160. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  161. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  162. msprobe/pytorch/monitor/anomaly_detect.py +107 -22
  163. msprobe/pytorch/monitor/csv2tb.py +166 -0
  164. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  165. msprobe/pytorch/monitor/features.py +3 -3
  166. msprobe/pytorch/monitor/module_hook.py +483 -277
  167. msprobe/pytorch/monitor/module_metric.py +27 -48
  168. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  169. msprobe/pytorch/monitor/optimizer_collect.py +52 -14
  170. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  171. msprobe/pytorch/monitor/utils.py +77 -6
  172. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  173. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  174. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  175. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  176. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  177. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  178. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  179. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  180. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  181. msprobe/pytorch/service.py +176 -106
  182. msprobe/visualization/builder/graph_builder.py +62 -5
  183. msprobe/visualization/builder/msprobe_adapter.py +24 -2
  184. msprobe/visualization/compare/graph_comparator.py +64 -14
  185. msprobe/visualization/compare/mode_adapter.py +1 -15
  186. msprobe/visualization/graph/base_node.py +12 -17
  187. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  188. msprobe/visualization/graph/graph.py +9 -0
  189. msprobe/visualization/graph_service.py +97 -23
  190. msprobe/visualization/utils.py +14 -29
  191. msprobe/pytorch/functional/module_dump.py +0 -84
  192. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  193. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
  194. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
  195. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  196. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  197. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/README.md CHANGED
@@ -51,15 +51,21 @@ export MSPROBE_LOG_LEVEL={x}
51
51
 
52
52
  **1. Pytorch 框架下,工具暂不支持 Fully Sharded Data Parallel(FSDP)。**
53
53
 
54
+ **2. 工具读写的所有路径,如config_path、dump_path等,只允许包含大小写字母、数字、下划线、斜杠、点和短横线。**
55
+
54
56
  ## ⚙️ [安装](./docs/01.installation.md)
55
57
 
58
+ ## 🌟 新版本特性
59
+
60
+ 请参见[特性变更说明](./docs/01.installation.md#特性变更说明)。
61
+
56
62
  ## 🛠️ config.json [介绍](./docs/02.config_introduction.md) 和 [示例](./docs/03.config_examples.md)
57
63
 
58
64
  ## 🧰 主要功能
59
65
 
60
66
  ### 0 用前必看
61
67
 
62
- 使用工具前,建议先浏览[**工具功能模块简介、适用场景和当前版本局限性**](./docs/23.tool_function_introduction.md),了解功能特性。
68
+ 使用工具前,建议先浏览[**工具功能模块简介、适用场景和当前版本局限性**](./docs/25.tool_function_introduction.md),了解功能特性。
63
69
 
64
70
  ### 1 数据采集
65
71
 
@@ -131,29 +137,18 @@ MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.
131
137
 
132
138
  [MindSpore 场景的分级可视化构图比对](./docs/22.visualization_MindSpore.md)
133
139
 
134
- ## 🌟 新版本特性
135
140
 
136
- 若查看历史版本特性,请参见[安装](./docs/01.installation.md)。
141
+ ### 11 单算子API自动生成脚本
142
+
143
+ 该功能将msprobe工具dump的精度数据进行解析,自动生成单API脚本,用于复现整网中出现的算子问题,降低用户复现问题的成本,供开发分析算子问题。
137
144
 
138
- 【数据采集】
139
- - 支持 config.json 中的 step 传入范围;
140
- - 优化了指定 step 的机制,指定 step 结束后工具不再采集数据,但训练会继续运行。工具结束运行后,日志提示信息如下:
141
- ```bash
142
- ****************************************
143
- * msprobe ends successfully. *
144
- ****************************************
145
- ```
146
- 注:在多卡场景,每张卡进程训练到指定 step 之后都会打印一次上述信息。
145
+ [PyTorch 单算子API自动生成脚本](./docs/23.generate_operator_PyTorch.md)
147
146
 
148
- 【精度预检】
149
- - 在 PyTorch 场景,支持部分 NPU 融合算子预检。
147
+ ### 12 数码关联
150
148
 
151
- 【精度比对】
152
- - 解决了使用 MindSpore 需要安装 PyTorch 的问题。
149
+ 该功能只支持 MindSpore 静态图场景,用于将IR图与dump数据进行关联,获取dump数据和代码调用栈的关联关系。
153
150
 
154
- 【无标杆比对】
155
- - 补充在 PyTorch 场景的性能基线报告;
156
- - 支持 MindSpore 场景的 change_value 扰动模式。
151
+ [MindSpore 场景的数码关联](./docs/24.code_mapping_Mindspore.md)
157
152
 
158
153
  ## 📑 补充材料
159
154
 
msprobe/config.json CHANGED
@@ -5,6 +5,7 @@
5
5
  "step": [],
6
6
  "level": "L1",
7
7
  "enable_dataloader": false,
8
+ "async_dump": false,
8
9
  "tensor": {
9
10
  "scope": [],
10
11
  "list":[],
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -29,6 +29,7 @@ class Const:
29
29
  SEP = "."
30
30
  REGEX_PREFIX_MAX_LENGTH = 20
31
31
  REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$"
32
+ REGEX_FORWARD_BACKWARD = r'\.(forward|backward)\.'
32
33
  FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
33
34
  STRING_BLACKLIST = r"^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]"
34
35
  COMMA = ","
@@ -65,6 +66,7 @@ class Const:
65
66
  ONLINE_DUMP_MODE = [ALL, LIST, AUTO, OFF]
66
67
  SUMMARY = "summary"
67
68
  MD5 = "md5"
69
+ VALUE = "value"
68
70
  SUMMARY_MODE = [ALL, SUMMARY, MD5]
69
71
 
70
72
  WRITE_FLAGS = os.O_WRONLY | os.O_CREAT
@@ -73,6 +75,7 @@ class Const:
73
75
 
74
76
  PKL_SUFFIX = ".pkl"
75
77
  NUMPY_SUFFIX = ".npy"
78
+ NUMPY_PATTERN = "*.npy"
76
79
  PT_SUFFIX = ".pt"
77
80
  ONE_GB = 1073741824 # 1 * 1024 * 1024 * 1024
78
81
  TEN_GB = 10737418240 # 10 * 1024 * 1024 * 1024
@@ -87,6 +90,8 @@ class Const:
87
90
  INPUT_KWARGS = 'input_kwargs'
88
91
  GRAD_INPUT = 'grad_input'
89
92
  GRAD_OUTPUT = 'grad_output'
93
+ PARAMS = 'parameters'
94
+ PARAMS_GRAD = 'parameters_grad'
90
95
  START = "start"
91
96
  STOP = "stop"
92
97
  ENV_ENABLE = "1"
@@ -112,6 +117,7 @@ class Const:
112
117
  DATA = "data"
113
118
  PT_FRAMEWORK = "pytorch"
114
119
  MS_FRAMEWORK = "mindspore"
120
+ MT_FRAMEWORK = "mindtorch"
115
121
  UNKNOWN_FRAMEWORK = "unknown"
116
122
  DIRECTORY_LENGTH = 4096
117
123
  FILE_NAME_LENGTH = 255
@@ -122,9 +128,12 @@ class Const:
122
128
  NPU_LOWERCASE = 'npu'
123
129
  CPU_LOWERCASE = 'cpu'
124
130
  CUDA_LOWERCASE = 'cuda'
131
+ DEVICE = 'device'
125
132
  DISTRIBUTED = 'Distributed'
126
- DUMP_PREFIX = ["Distributed", "Functional", "Torch", "Tensor", "Mint", "MintFunctional", "Primitive",
133
+ DUMP_PREFIX = ["Distributed", "Functional", "Torch", "Tensor", "Mint", "MintFunctional", "Primitive",
127
134
  "Aten", "VF", "NPU", "Jit"]
135
+ MODULE_PREFIX = ["Module", "Cell"]
136
+ FORWARD_NAME_SUFFIX = ".forward"
128
137
 
129
138
  # struct json param
130
139
  ORIGIN_DATA = "origin_data"
@@ -145,10 +154,13 @@ class Const:
145
154
  SCOPE_ID_INDEX = -1
146
155
  SCOPE_DIRECTION_INDEX = -2
147
156
  TYPE_NAME_INDEX = -3
157
+ PARAMS_GRAD_TYPE_NAME_INDEX = -2
148
158
  LAYER_NAME_INDEX = -4
159
+ PARAMS_GRAD_NAME_INDEX = -3
149
160
  API_TYPE_INDEX = 0
150
161
  LEFT_MOVE_INDEX = -1
151
162
  RIGHT_MOVE_INDEX = 1
163
+ LAST_INDEX = -1
152
164
 
153
165
  TOP_LAYER = "TopLayer"
154
166
  CELL = "Cell"
@@ -162,12 +174,16 @@ class Const:
162
174
 
163
175
  CONVERT = {
164
176
  "int32_to_int64": ["torch.int32", "torch.int64"],
177
+ "int64_to_fp32": ["torch.int64", "torch.float32"]
165
178
  }
166
179
 
167
180
  CONVERT_API = {
168
- "int32_to_int64": ["cross_entropy"]
181
+ "int32_to_int64": ["cross_entropy"],
182
+ "int64_to_fp32": ["histc"]
169
183
  }
170
184
 
185
+ FA_SPECIAL_SPARSE_MODE = [2, 3, 4]
186
+
171
187
  FILL_CHAR_NUMS = 50
172
188
  TOOL_ENDS_SUCCESSFULLY = f"{TOOL_NAME} ends successfully."
173
189
  WITHOUT_CALL_STACK = "The call stack retrieval failed."
@@ -179,6 +195,8 @@ class Const:
179
195
  STEP_RANK_MAXIMUM_VALUE = int(1e6)
180
196
 
181
197
  # data type const
198
+ TORCH_INT_DTYPE = ["torch.int8", "torch.int32", "torch.int64"]
199
+ TORCH_FLOAT_DTYPE = ["torch.bfloat16", "torch.float16", "torch.float32", "torch.float64"]
182
200
  FLOAT16 = "Float16"
183
201
  FLOAT32 = "Float32"
184
202
  BFLOAT16 = "BFloat16"
@@ -193,6 +211,23 @@ class Const:
193
211
  MEAN = 'Mean'
194
212
  NORM = 'Norm'
195
213
 
214
+ CODE_STACK = 'Code Stack'
215
+ OP_NAME = 'Op Name'
216
+ SCOPE_NAME = 'Scope Name'
217
+ CODE_STACKS = 'Code Stacks'
218
+ FILE_PATH = 'File Path'
219
+ NEW_LINE = '\n'
220
+ CSV_NEWLINE_SEPARATOR = ',\n'
221
+ # 分隔符常量
222
+ SCOPE_SEPARATOR = "/"
223
+ REPLACEMENT_CHARACTER = "_"
224
+
225
+ OPTIMIZER = "optimizer"
226
+ CLIP_GRAD = "clip_grad"
227
+ END_PREFIX = "end_"
228
+
229
+ TENSOR_STAT_LEN = 2
230
+
196
231
 
197
232
  class CompareConst:
198
233
  """
@@ -239,13 +274,58 @@ class CompareConst:
239
274
  INPUT_STRUCT = "input_struct"
240
275
  KWARGS_STRUCT = "kwargs_struct"
241
276
  OUTPUT_STRUCT = "output_struct"
277
+ PARAMS_STRUCT = "params_struct"
278
+ PARAMS_GRAD_STRUCT = "params_grad_struct"
242
279
  SUMMARY = "summary"
280
+ COMPARE_RESULT = "compare_result"
281
+ COMPARE_MESSAGE = "compare_message"
243
282
  MAX_EXCEL_LENGTH = 1048576
244
283
  YES = "Yes"
245
284
  NO = "No"
246
285
  STATISTICS_INDICATOR_NUM = 4
247
286
  EPSILON = 1e-10
248
287
  COMPARE_ENDS_SUCCESSFULLY = "msprobe compare ends successfully."
288
+ DEFAULT_RATIO_VALUE = 10000
289
+ THOUSANDTH_PASS_VALUE = 0.999
290
+ ZERO_SHAPE = '(0,)'
291
+
292
+ BENCHMARK_COMPARE_ALGORITHM_NAME = "标杆比对法"
293
+ ULP_COMPARE_ALGORITHM_NAME = "ULP误差比对法"
294
+ BINARY_CONSISTENCY_ALGORITHM_NAME = "二进制一致法"
295
+ ABSOLUTE_THRESHOLD_ALGORITHM_NAME = "绝对阈值法"
296
+ THOUSANDTH_STANDARD_ALGORITHM_NAME = "双千指标法"
297
+ ACCUMULATIVE_ERROR_COMPARE_ALGORITHM_NAME = "累积误差比对法"
298
+
299
+ ABSOLUTE_THRESHOLD = 'absolute_threshold'
300
+ BINARY_CONSISTENCY = 'binary_consistency'
301
+ ULP_COMPARE = 'ulp_compare'
302
+ THOUSANDTH_STANDARD = 'thousandth_threshold'
303
+ BENCHMARK = 'benchmark'
304
+ ACCUMULATIVE_ERROR_COMPARE = 'accumulative_error_compare'
305
+
306
+ SMALL_VALUE_ERR_RATIO = "small_value_err_ratio"
307
+ RMSE_RATIO = "rmse_ratio"
308
+ MAX_REL_ERR_RATIO = "max_rel_err_ratio"
309
+ MEAN_REL_ERR_RATIO = "mean_rel_err_ratio"
310
+ EB_RATIO = "eb_ratio"
311
+
312
+ SMALL_VALUE = "small_value"
313
+ RMSE = "rmse"
314
+ MAX_REL_ERR = "max_rel_err"
315
+ MEAN_REL_ERR = "mean_rel_err"
316
+ EB = "eb"
317
+
318
+ SMALL_VALUE_ERR_STATUS = "small_value_err_status"
319
+ RMSE_STATUS = "rmse_status"
320
+ MAX_REL_ERR_STATUS = "max_rel_err_status"
321
+ MEAN_REL_ERR_STATUS = "mean_rel_err_status"
322
+ EB_STATUS = "eb_status"
323
+
324
+ MEAN_ULP_ERR = "mean_ulp_err"
325
+ ULP_ERR_PROPORTION = "ulp_err_proportion"
326
+ ULP_ERR_PROPORTION_RATIO = "ulp_err_proportion_ratio"
327
+
328
+ ULP_ERR_STATUS = "ulp_err_status"
249
329
 
250
330
  COMPARE_RESULT_HEADER = [
251
331
  NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR,
@@ -263,12 +343,57 @@ class CompareConst:
263
343
  NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_MD5, BENCH_MD5, RESULT
264
344
  ]
265
345
 
346
+ COMPARE_RESULT_HEADER_STACK = COMPARE_RESULT_HEADER + [STACK]
347
+
348
+ SUMMARY_COMPARE_RESULT_HEADER_STACK = SUMMARY_COMPARE_RESULT_HEADER + [STACK]
349
+
350
+ MD5_COMPARE_RESULT_HEADER_STACK = MD5_COMPARE_RESULT_HEADER + [STACK]
351
+
266
352
  HEAD_OF_COMPARE_MODE = {
267
353
  Const.ALL: COMPARE_RESULT_HEADER,
268
354
  Const.SUMMARY: SUMMARY_COMPARE_RESULT_HEADER,
269
355
  Const.MD5: MD5_COMPARE_RESULT_HEADER
270
356
  }
271
357
 
358
+ ALL_COMPARE_INDEX = [COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO]
359
+ SUMMARY_COMPARE_INDEX = [MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF,
360
+ MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR]
361
+
362
+ # dtype match
363
+ MS_TYPE = [
364
+ [Const.FLOAT16, Const.FLOAT32], [Const.FLOAT32, Const.FLOAT16],
365
+ [Const.FLOAT16, Const.BFLOAT16], [Const.BFLOAT16, Const.FLOAT16]
366
+ ]
367
+ TORCH_TYPE = [
368
+ [Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16],
369
+ [Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16]
370
+ ]
371
+
372
+ # read_op
373
+ IO_NAME_MAPPING = {
374
+ Const.INPUT_ARGS: '.input',
375
+ Const.INPUT_KWARGS: '.input',
376
+ Const.INPUT: '.input',
377
+ Const.OUTPUT: '.output',
378
+ Const.PARAMS: '.parameters'
379
+ }
380
+
381
+ # state to struct mapping
382
+ STATE_TO_STRUCT_MAPPING = {
383
+ Const.INPUT: INPUT_STRUCT,
384
+ Const.KWARGS: INPUT_STRUCT,
385
+ Const.OUTPUT: OUTPUT_STRUCT,
386
+ Const.PARAMS: PARAMS_STRUCT,
387
+ Const.PARAMS_GRAD: PARAMS_GRAD_STRUCT
388
+ }
389
+
390
+ STRUCT_COMPARE_KEY = [
391
+ INPUT_STRUCT,
392
+ OUTPUT_STRUCT,
393
+ PARAMS_STRUCT,
394
+ PARAMS_GRAD_STRUCT
395
+ ]
396
+
272
397
  # compare standard
273
398
  HUNDRED_RATIO_THRESHOLD = 0.01
274
399
  THOUSAND_RATIO_THRESHOLD = 0.001
@@ -350,6 +475,8 @@ class CompareConst:
350
475
  INPUT_PATTERN = Const.SEP + Const.INPUT + Const.SEP
351
476
  KWARGS_PATTERN = Const.SEP + Const.KWARGS + Const.SEP
352
477
  OUTPUT_PATTERN = Const.SEP + Const.OUTPUT + Const.SEP
478
+ PARAMS_PATTERN = Const.SEP + Const.PARAMS + Const.SEP
479
+ PARAMS_GRAD_PATTERN = Const.SEP + Const.PARAMS_GRAD + Const.SEP
353
480
  COMPARE_KEY = 'compare_key'
354
481
  COMPARE_SHAPE = 'compare_shape'
355
482
  INTERNAL_API_MAPPING_FILE = 'ms_to_pt_api.yaml'
@@ -372,13 +499,17 @@ class FileCheckConst:
372
499
  JSON_SUFFIX = ".json"
373
500
  PT_SUFFIX = ".pt"
374
501
  CSV_SUFFIX = ".csv"
502
+ XLSX_SUFFIX = ".xlsx"
375
503
  YAML_SUFFIX = ".yaml"
504
+ IR_SUFFIX = ".ir"
376
505
  MAX_PKL_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
377
506
  MAX_NUMPY_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
378
507
  MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
379
508
  MAX_PT_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
380
509
  MAX_CSV_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
510
+ MAX_XLSX_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
381
511
  MAX_YAML_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
512
+ MAX_IR_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
382
513
  COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024
383
514
  DIR = "dir"
384
515
  FILE = "file"
@@ -390,7 +521,9 @@ class FileCheckConst:
390
521
  JSON_SUFFIX: MAX_JSON_SIZE,
391
522
  PT_SUFFIX: MAX_PT_SIZE,
392
523
  CSV_SUFFIX: MAX_CSV_SIZE,
393
- YAML_SUFFIX: MAX_YAML_SIZE
524
+ XLSX_SUFFIX: MAX_XLSX_SIZE,
525
+ YAML_SUFFIX: MAX_YAML_SIZE,
526
+ IR_SUFFIX: MAX_IR_SIZE
394
527
  }
395
528
  CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]'
396
529
 
@@ -437,6 +570,11 @@ class MsCompareConst:
437
570
 
438
571
  EPSILON = 1e-8
439
572
 
573
+ class ProcessStatus:
574
+ SUCCESS = "success"
575
+ API_NOT_FOUND = "api_not_found"
576
+ EXCEPTION_SKIP = "exception_skip"
577
+
440
578
 
441
579
  class MsgConst:
442
580
  """
@@ -474,15 +612,20 @@ class MonitorConst:
474
612
  """
475
613
  Class for monitor const
476
614
  """
477
- OP_LIST = ["min", "max", "norm", "zeros", "nans", "id", "mean"]
615
+ OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean"]
478
616
  MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR"
479
617
  DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output"
480
618
  DATABASE = "database"
481
619
  EMAIL = "email"
482
620
  OPT_TY = ['Megatron_DistributedOptimizer', 'Megatron_Float16OptimizerWithFloat16Params']
483
- DEEPSPEED_OPT_TY = ("DeepSpeedZeroOptimizer_Stage0", "DeepSpeedZeroOptimizer_Stage1_or_2", "DeepSpeedZeroOptimizer_Stage3")
621
+ DEEPSPEED_OPT_TY = (
622
+ "DeepSpeedZeroOptimizer_Stage0",
623
+ "DeepSpeedZeroOptimizer_Stage1_or_2",
624
+ "DeepSpeedZeroOptimizer_Stage3"
625
+ )
484
626
  RULE_NAME = ['AnomalyTurbulence']
485
627
 
628
+ SLICE_SIZE = 20480
486
629
  DOT = "."
487
630
  VPP_SEP = ":"
488
631
  ACTV_IN = "input"
@@ -491,12 +634,18 @@ class MonitorConst:
491
634
  ACTVGRAD_OUT = "output_grad"
492
635
  POST_GRAD = "post_grad"
493
636
  PRE_GRAD = "pre_grad"
637
+ ACC_GRAD = "acc_grad"
494
638
  PREFIX_POST = "post"
495
639
  PREFIX_PRE = "pre"
640
+ OUTPUT_DIR_PATTERN = r"([\w-]{0,20})-rank(\d{1,5})-"
496
641
 
642
+ EXP_AVG = "exp_avg"
643
+ EFXP_AVG_SQ = "efxp_avg_sq"
497
644
 
498
645
  ANOMALY_JSON = "anomaly.json"
499
646
  ANALYSE_JSON = "anomaly_analyse.json"
500
647
  TENSORBOARD = "tensorboard"
501
648
  CSV = "csv"
502
649
  API = "api"
650
+ OPS_START_INDEX = 3
651
+ HEADER_NAME_INDEX = 1
@@ -27,11 +27,13 @@ class MsprobeException(CodedException):
27
27
  INVALID_PARAM_ERROR = 0
28
28
  OVERFLOW_NUMS_ERROR = 1
29
29
  RECURSION_LIMIT_ERROR = 2
30
+ INTERFACE_USAGE_ERROR = 3
30
31
 
31
32
  err_strs = {
32
33
  INVALID_PARAM_ERROR: "[msprobe] 无效参数:",
33
34
  OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:",
34
- RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:"
35
+ RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:",
36
+ INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: "
35
37
  }
36
38
 
37
39
 
@@ -22,7 +22,6 @@ import re
22
22
  import shutil
23
23
  from datetime import datetime, timezone
24
24
  from dateutil import parser
25
- import OpenSSL
26
25
  import yaml
27
26
  import numpy as np
28
27
  import pandas as pd
@@ -419,20 +418,36 @@ def save_yaml(yaml_path, data):
419
418
 
420
419
 
421
420
  def save_excel(path, data):
421
+ def validate_data(data):
422
+ """Validate that the data is a DataFrame or a list of (DataFrame, sheet_name) pairs."""
423
+ if isinstance(data, pd.DataFrame):
424
+ return "single"
425
+ elif isinstance(data, list):
426
+ if all(isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], pd.DataFrame) for item in data):
427
+ return "list"
428
+ raise ValueError("Data must be a DataFrame or a list of (DataFrame, sheet_name) pairs.")
429
+
422
430
  check_path_before_create(path)
423
431
  path = os.path.realpath(path)
432
+
433
+ # 验证数据类型
434
+ data_type = validate_data(data)
435
+
424
436
  try:
425
- if isinstance(data, pd.DataFrame):
437
+ if data_type == "single":
426
438
  data.to_excel(path, index=False)
427
- else:
428
- logger.error(f'unsupported data type.')
429
- return
439
+ elif data_type == "list":
440
+ with pd.ExcelWriter(path) as writer:
441
+ for data_df, sheet_name in data:
442
+ data_df.to_excel(writer, sheet_name=sheet_name, index=False)
430
443
  except Exception as e:
431
444
  logger.error(f'Save excel file "{os.path.basename(path)}" failed.')
432
445
  raise RuntimeError(f"Save excel file {path} failed.") from e
433
446
  change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
434
447
 
435
448
 
449
+
450
+
436
451
  def move_file(src_path, dst_path):
437
452
  check_file_or_directory_path(src_path)
438
453
  check_path_before_create(dst_path)
@@ -522,11 +537,11 @@ def write_csv(data, filepath, mode="a+", malicious_check=False):
522
537
  change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
523
538
 
524
539
 
525
- def read_csv(filepath, as_pd=True):
540
+ def read_csv(filepath, as_pd=True, header='infer'):
526
541
  check_file_or_directory_path(filepath)
527
542
  try:
528
543
  if as_pd:
529
- csv_data = pd.read_csv(filepath)
544
+ csv_data = pd.read_csv(filepath, header=header)
530
545
  else:
531
546
  with FileOpen(filepath, 'r', encoding='utf-8-sig') as f:
532
547
  csv_reader = csv.reader(f, delimiter=',')
@@ -630,6 +645,7 @@ def check_crt_valid(pem_path):
630
645
  Raises:
631
646
  RuntimeError: If the SSL certificate is invalid or expired.
632
647
  """
648
+ import OpenSSL
633
649
  try:
634
650
  with FileOpen(pem_path, "r") as f:
635
651
  pem_data = f.read()
@@ -645,3 +661,13 @@ def check_crt_valid(pem_path):
645
661
  now_utc = datetime.now(tz=timezone.utc)
646
662
  if cert.has_expired() or not (pem_start <= now_utc <= pem_end):
647
663
  raise RuntimeError(f"The SSL certificate has expired and needs to be replaced, {pem_path}")
664
+
665
+
666
+ def read_xlsx(file_path):
667
+ check_file_or_directory_path(file_path)
668
+ try:
669
+ result_df = pd.read_excel(file_path, keep_default_na=False)
670
+ except Exception as e:
671
+ logger.error(f"The xlsx file failed to load. Please check the path: {file_path}.")
672
+ raise RuntimeError(f"Read xlsx file {file_path} failed.") from e
673
+ return result_df
@@ -157,6 +157,9 @@ inplace_tensor_op:
157
157
  - trunc_
158
158
  - unsqueeze_
159
159
  - xlogy_
160
+ - bitwise_left_shift_
161
+ - bitwise_right_shift_
162
+ - arctan2_
160
163
 
161
164
  inplace_torch_op:
162
165
  - _add_relu_
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -29,6 +29,7 @@ from msprobe.core.common.const import Const, CompareConst
29
29
  from msprobe.core.common.log import logger
30
30
  from msprobe.core.common.exceptions import MsprobeException
31
31
 
32
+
32
33
  device = collections.namedtuple('device', ['type', 'index'])
33
34
  prefixes = ['api_stack', 'list', 'range', 'acl']
34
35
 
@@ -71,6 +72,9 @@ class MsprobeBaseException(Exception):
71
72
  BACKWARD_DATA_COLLECTION_ERROR = 30
72
73
  INVALID_KEY_ERROR = 31
73
74
  MISSING_HEADER_ERROR = 32
75
+ MERGE_COMPARE_RESULT_ERROR = 33
76
+ NAMES_STRUCTS_MATCH_ERROR = 34
77
+ INVALID_STATE_ERROR = 35
74
78
 
75
79
  def __init__(self, code, error_info: str = ""):
76
80
  super(MsprobeBaseException, self).__init__()
@@ -109,7 +113,7 @@ def is_json_file(file_path):
109
113
  return False
110
114
 
111
115
 
112
- def check_compare_param(input_param, output_path, dump_mode):
116
+ def check_compare_param(input_param, output_path, dump_mode, stack_mode):
113
117
  if not isinstance(input_param, dict):
114
118
  logger.error(f"Invalid input parameter 'input_param', the expected type dict but got {type(input_param)}.")
115
119
  raise CompareException(CompareException.INVALID_PARAM_ERROR)
@@ -127,7 +131,8 @@ def check_compare_param(input_param, output_path, dump_mode):
127
131
 
128
132
  check_json_path("npu_json_path")
129
133
  check_json_path("bench_json_path")
130
- check_json_path("stack_json_path")
134
+ if stack_mode:
135
+ check_json_path("stack_json_path")
131
136
 
132
137
  if dump_mode == Const.ALL:
133
138
  check_file_or_directory_path(input_param.get("npu_dump_data_dir"), True)
@@ -135,9 +140,12 @@ def check_compare_param(input_param, output_path, dump_mode):
135
140
  check_file_or_directory_path(output_path, True)
136
141
 
137
142
  with FileOpen(input_param.get("npu_json_path"), "r") as npu_json, \
138
- FileOpen(input_param.get("bench_json_path"), "r") as bench_json, \
139
- FileOpen(input_param.get("stack_json_path"), "r") as stack_json:
140
- check_json_file(input_param, npu_json, bench_json, stack_json)
143
+ FileOpen(input_param.get("bench_json_path"), "r") as bench_json:
144
+ _check_json(npu_json, input_param.get("npu_json_path"))
145
+ _check_json(bench_json, input_param.get("bench_json_path"))
146
+ if stack_mode:
147
+ with FileOpen(input_param.get("stack_json_path"), "r") as stack_json:
148
+ _check_json(stack_json, input_param.get("stack_json_path"))
141
149
 
142
150
 
143
151
  def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True):
@@ -395,20 +403,23 @@ def get_real_step_or_rank(step_or_rank_input, obj):
395
403
  if not is_int(element) and not isinstance(element, str):
396
404
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
397
405
  f"{obj} element {element} must be an integer or string.")
398
- if isinstance(element, int) and element < 0:
399
- raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
400
- f"Each element of {obj} must be non-negative, currently it is {element}.")
401
- if isinstance(element, int) and Const.STEP_RANK_MINIMUM_VALUE <= element <= Const.STEP_RANK_MAXIMUM_VALUE:
406
+ if is_int(element):
407
+ if not Const.STEP_RANK_MINIMUM_VALUE <= element <= Const.STEP_RANK_MAXIMUM_VALUE:
408
+ raise MsprobeException(
409
+ MsprobeException.INVALID_PARAM_ERROR,
410
+ f"Each element of {obj} must be between {Const.STEP_RANK_MINIMUM_VALUE} and "
411
+ f"{Const.STEP_RANK_MAXIMUM_VALUE}, currently it is {element}."
412
+ )
402
413
  real_step_or_rank.append(element)
403
- elif isinstance(element, str) and Const.HYPHEN in element:
404
- continual_step_or_rank = get_step_or_rank_from_string(element, obj)
405
- real_step_or_rank.extend(continual_step_or_rank)
414
+ continue
415
+ continual_step_or_rank = get_step_or_rank_from_string(element, obj)
416
+ real_step_or_rank.extend(continual_step_or_rank)
406
417
  real_step_or_rank = list(set(real_step_or_rank))
407
418
  real_step_or_rank.sort()
408
419
  return real_step_or_rank
409
420
 
410
421
 
411
- def check_seed_all(seed, mode):
422
+ def check_seed_all(seed, mode, rm_dropout):
412
423
  if is_int(seed):
413
424
  if seed < 0 or seed > Const.MAX_SEED_VALUE:
414
425
  logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
@@ -419,6 +430,9 @@ def check_seed_all(seed, mode):
419
430
  if not isinstance(mode, bool):
420
431
  logger.error("seed_all mode must be bool.")
421
432
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
433
+ if not isinstance(rm_dropout, bool):
434
+ logger.error("The rm_dropout parameter must be bool.")
435
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
422
436
 
423
437
 
424
438
  def safe_get_value(container, index, container_name, key=None):
@@ -27,6 +27,7 @@ class CommonConfig:
27
27
  self.step = get_real_step_or_rank(json_config.get('step'), Const.STEP)
28
28
  self.level = json_config.get('level')
29
29
  self.enable_dataloader = json_config.get('enable_dataloader', False)
30
+ self.async_dump = json_config.get("async_dump", False)
30
31
  self._check_config()
31
32
 
32
33
  def _check_config(self):
@@ -42,6 +43,11 @@ class CommonConfig:
42
43
  if not isinstance(self.enable_dataloader, bool):
43
44
  logger.error_log_with_exp("enable_dataloader is invalid, it should be a boolean",
44
45
  MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
46
+ if not isinstance(self.async_dump, bool):
47
+ logger.error_log_with_exp("async_dump is invalid, it should be a boolean",
48
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
49
+ elif self.async_dump:
50
+ logger.warning("async_dump is True, it may cause OOM when dumping large tensor.")
45
51
 
46
52
 
47
53
  class BaseConfig: