mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -29,6 +29,7 @@ class Const:
29
29
  SEP = "."
30
30
  REGEX_PREFIX_MAX_LENGTH = 20
31
31
  REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$"
32
+ REGEX_FORWARD_BACKWARD = r'\.(forward|backward)\.'
32
33
  FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
33
34
  STRING_BLACKLIST = r"^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]"
34
35
  COMMA = ","
@@ -65,6 +66,7 @@ class Const:
65
66
  ONLINE_DUMP_MODE = [ALL, LIST, AUTO, OFF]
66
67
  SUMMARY = "summary"
67
68
  MD5 = "md5"
69
+ VALUE = "value"
68
70
  SUMMARY_MODE = [ALL, SUMMARY, MD5]
69
71
 
70
72
  WRITE_FLAGS = os.O_WRONLY | os.O_CREAT
@@ -73,6 +75,7 @@ class Const:
73
75
 
74
76
  PKL_SUFFIX = ".pkl"
75
77
  NUMPY_SUFFIX = ".npy"
78
+ NUMPY_PATTERN = "*.npy"
76
79
  PT_SUFFIX = ".pt"
77
80
  ONE_GB = 1073741824 # 1 * 1024 * 1024 * 1024
78
81
  TEN_GB = 10737418240 # 10 * 1024 * 1024 * 1024
@@ -87,6 +90,8 @@ class Const:
87
90
  INPUT_KWARGS = 'input_kwargs'
88
91
  GRAD_INPUT = 'grad_input'
89
92
  GRAD_OUTPUT = 'grad_output'
93
+ PARAMS = 'parameters'
94
+ PARAMS_GRAD = 'parameters_grad'
90
95
  START = "start"
91
96
  STOP = "stop"
92
97
  ENV_ENABLE = "1"
@@ -98,20 +103,23 @@ class Const:
98
103
  FREE_BENCHMARK = "free_benchmark"
99
104
  RUN_UT = "run_ut"
100
105
  GRAD_PROBE = "grad_probe"
101
- TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE]
102
- DUMP_DATA_COLLECTION_LIST = [STATISTICS, TENSOR]
106
+ STRUCTURE = "structure"
107
+ TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE, STRUCTURE]
108
+ DUMP_DATA_COLLECTION_LIST = [STATISTICS, TENSOR, STRUCTURE]
103
109
  DUMP_DATA_MODE_LIST = [ALL, INPUT, OUTPUT, FORWARD, BACKWARD]
104
110
  LEVEL_L0 = "L0"
105
111
  LEVEL_L1 = "L1"
106
112
  LEVEL_L2 = "L2"
107
113
  LEVEL_MIX = "mix"
108
- LEVEL_LIST = [LEVEL_L0, LEVEL_L1, LEVEL_L2, LEVEL_MIX]
114
+ LEVEL_DEBUG = "debug"
115
+ LEVEL_LIST = [LEVEL_L0, LEVEL_L1, LEVEL_L2, LEVEL_MIX, LEVEL_DEBUG]
109
116
  ATTR_NAME_PREFIX = "wrap_"
110
117
  ATTR_NAME_PREFIX_LEN = len(ATTR_NAME_PREFIX)
111
118
  KERNEL_DUMP = "kernel_dump"
112
119
  DATA = "data"
113
120
  PT_FRAMEWORK = "pytorch"
114
121
  MS_FRAMEWORK = "mindspore"
122
+ MT_FRAMEWORK = "mindtorch"
115
123
  UNKNOWN_FRAMEWORK = "unknown"
116
124
  DIRECTORY_LENGTH = 4096
117
125
  FILE_NAME_LENGTH = 255
@@ -122,9 +130,12 @@ class Const:
122
130
  NPU_LOWERCASE = 'npu'
123
131
  CPU_LOWERCASE = 'cpu'
124
132
  CUDA_LOWERCASE = 'cuda'
133
+ DEVICE = 'device'
125
134
  DISTRIBUTED = 'Distributed'
126
- DUMP_PREFIX = ["Distributed", "Functional", "Torch", "Tensor", "Mint", "MintFunctional", "Primitive",
135
+ DUMP_PREFIX = ["Distributed", "Functional", "Torch", "Tensor", "Mint", "MintFunctional", "Primitive",
127
136
  "Aten", "VF", "NPU", "Jit"]
137
+ MODULE_PREFIX = ["Module", "Cell"]
138
+ FORWARD_NAME_SUFFIX = ".forward"
128
139
 
129
140
  # struct json param
130
141
  ORIGIN_DATA = "origin_data"
@@ -145,10 +156,13 @@ class Const:
145
156
  SCOPE_ID_INDEX = -1
146
157
  SCOPE_DIRECTION_INDEX = -2
147
158
  TYPE_NAME_INDEX = -3
159
+ PARAMS_GRAD_TYPE_NAME_INDEX = -2
148
160
  LAYER_NAME_INDEX = -4
161
+ PARAMS_GRAD_NAME_INDEX = -3
149
162
  API_TYPE_INDEX = 0
150
163
  LEFT_MOVE_INDEX = -1
151
164
  RIGHT_MOVE_INDEX = 1
165
+ LAST_INDEX = -1
152
166
 
153
167
  TOP_LAYER = "TopLayer"
154
168
  CELL = "Cell"
@@ -162,12 +176,16 @@ class Const:
162
176
 
163
177
  CONVERT = {
164
178
  "int32_to_int64": ["torch.int32", "torch.int64"],
179
+ "int64_to_fp32": ["torch.int64", "torch.float32"]
165
180
  }
166
181
 
167
182
  CONVERT_API = {
168
- "int32_to_int64": ["cross_entropy"]
183
+ "int32_to_int64": ["cross_entropy"],
184
+ "int64_to_fp32": ["histc"]
169
185
  }
170
186
 
187
+ FA_SPECIAL_SPARSE_MODE = [2, 3, 4]
188
+
171
189
  FILL_CHAR_NUMS = 50
172
190
  TOOL_ENDS_SUCCESSFULLY = f"{TOOL_NAME} ends successfully."
173
191
  WITHOUT_CALL_STACK = "The call stack retrieval failed."
@@ -179,6 +197,8 @@ class Const:
179
197
  STEP_RANK_MAXIMUM_VALUE = int(1e6)
180
198
 
181
199
  # data type const
200
+ TORCH_INT_DTYPE = ["torch.int8", "torch.int32", "torch.int64"]
201
+ TORCH_FLOAT_DTYPE = ["torch.bfloat16", "torch.float16", "torch.float32", "torch.float64"]
182
202
  FLOAT16 = "Float16"
183
203
  FLOAT32 = "Float32"
184
204
  BFLOAT16 = "BFloat16"
@@ -193,6 +213,23 @@ class Const:
193
213
  MEAN = 'Mean'
194
214
  NORM = 'Norm'
195
215
 
216
+ CODE_STACK = 'Code Stack'
217
+ OP_NAME = 'Op Name'
218
+ SCOPE_NAME = 'Scope Name'
219
+ CODE_STACKS = 'Code Stacks'
220
+ FILE_PATH = 'File Path'
221
+ NEW_LINE = '\n'
222
+ CSV_NEWLINE_SEPARATOR = ',\n'
223
+ # 分隔符常量
224
+ SCOPE_SEPARATOR = "/"
225
+ REPLACEMENT_CHARACTER = "_"
226
+
227
+ OPTIMIZER = "optimizer"
228
+ CLIP_GRAD = "clip_grad"
229
+ END_PREFIX = "end_"
230
+
231
+ TENSOR_STAT_LEN = 2
232
+
196
233
 
197
234
  class CompareConst:
198
235
  """
@@ -239,13 +276,58 @@ class CompareConst:
239
276
  INPUT_STRUCT = "input_struct"
240
277
  KWARGS_STRUCT = "kwargs_struct"
241
278
  OUTPUT_STRUCT = "output_struct"
279
+ PARAMS_STRUCT = "params_struct"
280
+ PARAMS_GRAD_STRUCT = "params_grad_struct"
242
281
  SUMMARY = "summary"
282
+ COMPARE_RESULT = "compare_result"
283
+ COMPARE_MESSAGE = "compare_message"
243
284
  MAX_EXCEL_LENGTH = 1048576
244
285
  YES = "Yes"
245
286
  NO = "No"
246
287
  STATISTICS_INDICATOR_NUM = 4
247
288
  EPSILON = 1e-10
248
289
  COMPARE_ENDS_SUCCESSFULLY = "msprobe compare ends successfully."
290
+ DEFAULT_RATIO_VALUE = 10000
291
+ THOUSANDTH_PASS_VALUE = 0.999
292
+ ZERO_SHAPE = '(0,)'
293
+
294
+ BENCHMARK_COMPARE_ALGORITHM_NAME = "标杆比对法"
295
+ ULP_COMPARE_ALGORITHM_NAME = "ULP误差比对法"
296
+ BINARY_CONSISTENCY_ALGORITHM_NAME = "二进制一致法"
297
+ ABSOLUTE_THRESHOLD_ALGORITHM_NAME = "绝对阈值法"
298
+ THOUSANDTH_STANDARD_ALGORITHM_NAME = "双千指标法"
299
+ ACCUMULATIVE_ERROR_COMPARE_ALGORITHM_NAME = "累积误差比对法"
300
+
301
+ ABSOLUTE_THRESHOLD = 'absolute_threshold'
302
+ BINARY_CONSISTENCY = 'binary_consistency'
303
+ ULP_COMPARE = 'ulp_compare'
304
+ THOUSANDTH_STANDARD = 'thousandth_threshold'
305
+ BENCHMARK = 'benchmark'
306
+ ACCUMULATIVE_ERROR_COMPARE = 'accumulative_error_compare'
307
+
308
+ SMALL_VALUE_ERR_RATIO = "small_value_err_ratio"
309
+ RMSE_RATIO = "rmse_ratio"
310
+ MAX_REL_ERR_RATIO = "max_rel_err_ratio"
311
+ MEAN_REL_ERR_RATIO = "mean_rel_err_ratio"
312
+ EB_RATIO = "eb_ratio"
313
+
314
+ SMALL_VALUE = "small_value"
315
+ RMSE = "rmse"
316
+ MAX_REL_ERR = "max_rel_err"
317
+ MEAN_REL_ERR = "mean_rel_err"
318
+ EB = "eb"
319
+
320
+ SMALL_VALUE_ERR_STATUS = "small_value_err_status"
321
+ RMSE_STATUS = "rmse_status"
322
+ MAX_REL_ERR_STATUS = "max_rel_err_status"
323
+ MEAN_REL_ERR_STATUS = "mean_rel_err_status"
324
+ EB_STATUS = "eb_status"
325
+
326
+ MEAN_ULP_ERR = "mean_ulp_err"
327
+ ULP_ERR_PROPORTION = "ulp_err_proportion"
328
+ ULP_ERR_PROPORTION_RATIO = "ulp_err_proportion_ratio"
329
+
330
+ ULP_ERR_STATUS = "ulp_err_status"
249
331
 
250
332
  COMPARE_RESULT_HEADER = [
251
333
  NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR,
@@ -263,12 +345,57 @@ class CompareConst:
263
345
  NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_MD5, BENCH_MD5, RESULT
264
346
  ]
265
347
 
348
+ COMPARE_RESULT_HEADER_STACK = COMPARE_RESULT_HEADER + [STACK]
349
+
350
+ SUMMARY_COMPARE_RESULT_HEADER_STACK = SUMMARY_COMPARE_RESULT_HEADER + [STACK]
351
+
352
+ MD5_COMPARE_RESULT_HEADER_STACK = MD5_COMPARE_RESULT_HEADER + [STACK]
353
+
266
354
  HEAD_OF_COMPARE_MODE = {
267
355
  Const.ALL: COMPARE_RESULT_HEADER,
268
356
  Const.SUMMARY: SUMMARY_COMPARE_RESULT_HEADER,
269
357
  Const.MD5: MD5_COMPARE_RESULT_HEADER
270
358
  }
271
359
 
360
+ ALL_COMPARE_INDEX = [COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO]
361
+ SUMMARY_COMPARE_INDEX = [MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF,
362
+ MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR]
363
+
364
+ # dtype match
365
+ MS_TYPE = [
366
+ [Const.FLOAT16, Const.FLOAT32], [Const.FLOAT32, Const.FLOAT16],
367
+ [Const.FLOAT16, Const.BFLOAT16], [Const.BFLOAT16, Const.FLOAT16]
368
+ ]
369
+ TORCH_TYPE = [
370
+ [Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16],
371
+ [Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16]
372
+ ]
373
+
374
+ # read_op
375
+ IO_NAME_MAPPING = {
376
+ Const.INPUT_ARGS: '.input',
377
+ Const.INPUT_KWARGS: '.input',
378
+ Const.INPUT: '.input',
379
+ Const.OUTPUT: '.output',
380
+ Const.PARAMS: '.parameters'
381
+ }
382
+
383
+ # state to struct mapping
384
+ STATE_TO_STRUCT_MAPPING = {
385
+ Const.INPUT: INPUT_STRUCT,
386
+ Const.KWARGS: INPUT_STRUCT,
387
+ Const.OUTPUT: OUTPUT_STRUCT,
388
+ Const.PARAMS: PARAMS_STRUCT,
389
+ Const.PARAMS_GRAD: PARAMS_GRAD_STRUCT
390
+ }
391
+
392
+ STRUCT_COMPARE_KEY = [
393
+ INPUT_STRUCT,
394
+ OUTPUT_STRUCT,
395
+ PARAMS_STRUCT,
396
+ PARAMS_GRAD_STRUCT
397
+ ]
398
+
272
399
  # compare standard
273
400
  HUNDRED_RATIO_THRESHOLD = 0.01
274
401
  THOUSAND_RATIO_THRESHOLD = 0.001
@@ -350,6 +477,8 @@ class CompareConst:
350
477
  INPUT_PATTERN = Const.SEP + Const.INPUT + Const.SEP
351
478
  KWARGS_PATTERN = Const.SEP + Const.KWARGS + Const.SEP
352
479
  OUTPUT_PATTERN = Const.SEP + Const.OUTPUT + Const.SEP
480
+ PARAMS_PATTERN = Const.SEP + Const.PARAMS + Const.SEP
481
+ PARAMS_GRAD_PATTERN = Const.SEP + Const.PARAMS_GRAD + Const.SEP
353
482
  COMPARE_KEY = 'compare_key'
354
483
  COMPARE_SHAPE = 'compare_shape'
355
484
  INTERNAL_API_MAPPING_FILE = 'ms_to_pt_api.yaml'
@@ -372,13 +501,17 @@ class FileCheckConst:
372
501
  JSON_SUFFIX = ".json"
373
502
  PT_SUFFIX = ".pt"
374
503
  CSV_SUFFIX = ".csv"
504
+ XLSX_SUFFIX = ".xlsx"
375
505
  YAML_SUFFIX = ".yaml"
506
+ IR_SUFFIX = ".ir"
376
507
  MAX_PKL_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
377
508
  MAX_NUMPY_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
378
509
  MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
379
510
  MAX_PT_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
380
511
  MAX_CSV_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
512
+ MAX_XLSX_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
381
513
  MAX_YAML_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
514
+ MAX_IR_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
382
515
  COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024
383
516
  DIR = "dir"
384
517
  FILE = "file"
@@ -390,7 +523,9 @@ class FileCheckConst:
390
523
  JSON_SUFFIX: MAX_JSON_SIZE,
391
524
  PT_SUFFIX: MAX_PT_SIZE,
392
525
  CSV_SUFFIX: MAX_CSV_SIZE,
393
- YAML_SUFFIX: MAX_YAML_SIZE
526
+ XLSX_SUFFIX: MAX_XLSX_SIZE,
527
+ YAML_SUFFIX: MAX_YAML_SIZE,
528
+ IR_SUFFIX: MAX_IR_SIZE
394
529
  }
395
530
  CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]'
396
531
 
@@ -410,9 +545,24 @@ class MsCompareConst:
410
545
  TENSOR_API = "Tensor"
411
546
 
412
547
  API_NAME_STR_LENGTH = 4
548
+ MAX_RECURSION_DEPTH = 20
549
+
550
+ # Mindtorch api_info field
551
+ MINDTORCH_TENSOR = "Tensor"
552
+ MINDTORCH = "Torch"
553
+ MINDTORCH_FUNC = "Functional"
554
+ MINDTORCH_NPU = "NPU"
555
+ MINDTORCH_DIST = "Distributed"
556
+
557
+
558
+
559
+ MT_VALID_API_TYPES = [
560
+ MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR
561
+ ]
413
562
 
414
563
  TASK_FIELD = "task"
415
564
  STATISTICS_TASK = "statistics"
565
+ FRAMEWORK = "framework"
416
566
  TENSOR_TASK = "tensor"
417
567
  DUMP_DATA_DIR_FIELD = "dump_data_dir"
418
568
  DATA_FIELD = "data"
@@ -437,6 +587,11 @@ class MsCompareConst:
437
587
 
438
588
  EPSILON = 1e-8
439
589
 
590
+ class ProcessStatus:
591
+ SUCCESS = "success"
592
+ API_NOT_FOUND = "api_not_found"
593
+ EXCEPTION_SKIP = "exception_skip"
594
+
440
595
 
441
596
  class MsgConst:
442
597
  """
@@ -474,29 +629,48 @@ class MonitorConst:
474
629
  """
475
630
  Class for monitor const
476
631
  """
477
- OP_LIST = ["min", "max", "norm", "zeros", "nans", "id", "mean"]
632
+ OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean"]
478
633
  MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR"
479
634
  DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output"
480
635
  DATABASE = "database"
481
636
  EMAIL = "email"
482
637
  OPT_TY = ['Megatron_DistributedOptimizer', 'Megatron_Float16OptimizerWithFloat16Params']
483
- DEEPSPEED_OPT_TY = ("DeepSpeedZeroOptimizer_Stage0", "DeepSpeedZeroOptimizer_Stage1_or_2", "DeepSpeedZeroOptimizer_Stage3")
638
+ DEEPSPEED_OPT_TY = (
639
+ "DeepSpeedZeroOptimizer_Stage0",
640
+ "DeepSpeedZeroOptimizer_Stage1_or_2",
641
+ "DeepSpeedZeroOptimizer_Stage3"
642
+ )
643
+ DEEPSPEED_ZERO_OPT_FILTER = "DeepSpeedZeroOptimizer"
484
644
  RULE_NAME = ['AnomalyTurbulence']
485
645
 
646
+ SLICE_SIZE = 20480
647
+ # used for name
486
648
  DOT = "."
487
- VPP_SEP = ":"
649
+ NAME_SEP = ":"
650
+ INPUT_GRAD = "input_grad"
651
+ OUTPUT_GRAD = "output_grad"
488
652
  ACTV_IN = "input"
489
653
  ACTV_OUT = "output"
490
654
  ACTVGRAD_IN = "input_grad"
491
655
  ACTVGRAD_OUT = "output_grad"
656
+ # used for tasks
657
+ ACTV = "actv"
658
+ ACTVGRAD = "actv_grad"
492
659
  POST_GRAD = "post_grad"
493
660
  PRE_GRAD = "pre_grad"
661
+ ACC_GRAD = "acc_grad"
494
662
  PREFIX_POST = "post"
495
663
  PREFIX_PRE = "pre"
664
+ EXP_AVG = "exp_avg"
665
+ EXP_AVG_SQ = "exp_avg_sq"
666
+ PARAM = "param"
496
667
 
497
-
668
+ CSV_HEADER = ["vpp_stage", "name", "step"]
669
+ CSV_HEADER_XY = ["vpp_stage", "name", "step", "micro_step"]
670
+ OUTPUT_DIR_PATTERN = r"([\w-]{0,20})-rank(\d{1,5})-"
498
671
  ANOMALY_JSON = "anomaly.json"
499
672
  ANALYSE_JSON = "anomaly_analyse.json"
500
673
  TENSORBOARD = "tensorboard"
501
674
  CSV = "csv"
502
675
  API = "api"
676
+ HEADER_NAME = 'name'
@@ -27,11 +27,13 @@ class MsprobeException(CodedException):
27
27
  INVALID_PARAM_ERROR = 0
28
28
  OVERFLOW_NUMS_ERROR = 1
29
29
  RECURSION_LIMIT_ERROR = 2
30
+ INTERFACE_USAGE_ERROR = 3
30
31
 
31
32
  err_strs = {
32
33
  INVALID_PARAM_ERROR: "[msprobe] 无效参数:",
33
34
  OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:",
34
- RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:"
35
+ RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:",
36
+ INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: "
35
37
  }
36
38
 
37
39
 
@@ -22,7 +22,6 @@ import re
22
22
  import shutil
23
23
  from datetime import datetime, timezone
24
24
  from dateutil import parser
25
- import OpenSSL
26
25
  import yaml
27
26
  import numpy as np
28
27
  import pandas as pd
@@ -419,20 +418,36 @@ def save_yaml(yaml_path, data):
419
418
 
420
419
 
421
420
  def save_excel(path, data):
421
+ def validate_data(data):
422
+ """Validate that the data is a DataFrame or a list of (DataFrame, sheet_name) pairs."""
423
+ if isinstance(data, pd.DataFrame):
424
+ return "single"
425
+ elif isinstance(data, list):
426
+ if all(isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], pd.DataFrame) for item in data):
427
+ return "list"
428
+ raise ValueError("Data must be a DataFrame or a list of (DataFrame, sheet_name) pairs.")
429
+
422
430
  check_path_before_create(path)
423
431
  path = os.path.realpath(path)
432
+
433
+ # 验证数据类型
434
+ data_type = validate_data(data)
435
+
424
436
  try:
425
- if isinstance(data, pd.DataFrame):
437
+ if data_type == "single":
426
438
  data.to_excel(path, index=False)
427
- else:
428
- logger.error(f'unsupported data type.')
429
- return
439
+ elif data_type == "list":
440
+ with pd.ExcelWriter(path) as writer:
441
+ for data_df, sheet_name in data:
442
+ data_df.to_excel(writer, sheet_name=sheet_name, index=False)
430
443
  except Exception as e:
431
444
  logger.error(f'Save excel file "{os.path.basename(path)}" failed.')
432
445
  raise RuntimeError(f"Save excel file {path} failed.") from e
433
446
  change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
434
447
 
435
448
 
449
+
450
+
436
451
  def move_file(src_path, dst_path):
437
452
  check_file_or_directory_path(src_path)
438
453
  check_path_before_create(dst_path)
@@ -522,11 +537,11 @@ def write_csv(data, filepath, mode="a+", malicious_check=False):
522
537
  change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
523
538
 
524
539
 
525
- def read_csv(filepath, as_pd=True):
540
+ def read_csv(filepath, as_pd=True, header='infer'):
526
541
  check_file_or_directory_path(filepath)
527
542
  try:
528
543
  if as_pd:
529
- csv_data = pd.read_csv(filepath)
544
+ csv_data = pd.read_csv(filepath, header=header)
530
545
  else:
531
546
  with FileOpen(filepath, 'r', encoding='utf-8-sig') as f:
532
547
  csv_reader = csv.reader(f, delimiter=',')
@@ -630,6 +645,7 @@ def check_crt_valid(pem_path):
630
645
  Raises:
631
646
  RuntimeError: If the SSL certificate is invalid or expired.
632
647
  """
648
+ import OpenSSL
633
649
  try:
634
650
  with FileOpen(pem_path, "r") as f:
635
651
  pem_data = f.read()
@@ -645,3 +661,13 @@ def check_crt_valid(pem_path):
645
661
  now_utc = datetime.now(tz=timezone.utc)
646
662
  if cert.has_expired() or not (pem_start <= now_utc <= pem_end):
647
663
  raise RuntimeError(f"The SSL certificate has expired and needs to be replaced, {pem_path}")
664
+
665
+
666
+ def read_xlsx(file_path):
667
+ check_file_or_directory_path(file_path)
668
+ try:
669
+ result_df = pd.read_excel(file_path, keep_default_na=False)
670
+ except Exception as e:
671
+ logger.error(f"The xlsx file failed to load. Please check the path: {file_path}.")
672
+ raise RuntimeError(f"Read xlsx file {file_path} failed.") from e
673
+ return result_df
@@ -157,6 +157,9 @@ inplace_tensor_op:
157
157
  - trunc_
158
158
  - unsqueeze_
159
159
  - xlogy_
160
+ - bitwise_left_shift_
161
+ - bitwise_right_shift_
162
+ - arctan2_
160
163
 
161
164
  inplace_torch_op:
162
165
  - _add_relu_
@@ -247,5 +250,6 @@ inplace_distributed_op:
247
250
  - all_to_all
248
251
  - all_gather_into_tensor
249
252
  - reduce_scatter_tensor
253
+ - batch_isend_irecv
250
254
 
251
255
 
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -29,6 +29,7 @@ from msprobe.core.common.const import Const, CompareConst
29
29
  from msprobe.core.common.log import logger
30
30
  from msprobe.core.common.exceptions import MsprobeException
31
31
 
32
+
32
33
  device = collections.namedtuple('device', ['type', 'index'])
33
34
  prefixes = ['api_stack', 'list', 'range', 'acl']
34
35
 
@@ -71,6 +72,9 @@ class MsprobeBaseException(Exception):
71
72
  BACKWARD_DATA_COLLECTION_ERROR = 30
72
73
  INVALID_KEY_ERROR = 31
73
74
  MISSING_HEADER_ERROR = 32
75
+ MERGE_COMPARE_RESULT_ERROR = 33
76
+ NAMES_STRUCTS_MATCH_ERROR = 34
77
+ INVALID_STATE_ERROR = 35
74
78
 
75
79
  def __init__(self, code, error_info: str = ""):
76
80
  super(MsprobeBaseException, self).__init__()
@@ -109,7 +113,7 @@ def is_json_file(file_path):
109
113
  return False
110
114
 
111
115
 
112
- def check_compare_param(input_param, output_path, dump_mode):
116
+ def check_compare_param(input_param, output_path, dump_mode, stack_mode):
113
117
  if not isinstance(input_param, dict):
114
118
  logger.error(f"Invalid input parameter 'input_param', the expected type dict but got {type(input_param)}.")
115
119
  raise CompareException(CompareException.INVALID_PARAM_ERROR)
@@ -127,7 +131,8 @@ def check_compare_param(input_param, output_path, dump_mode):
127
131
 
128
132
  check_json_path("npu_json_path")
129
133
  check_json_path("bench_json_path")
130
- check_json_path("stack_json_path")
134
+ if stack_mode:
135
+ check_json_path("stack_json_path")
131
136
 
132
137
  if dump_mode == Const.ALL:
133
138
  check_file_or_directory_path(input_param.get("npu_dump_data_dir"), True)
@@ -135,9 +140,12 @@ def check_compare_param(input_param, output_path, dump_mode):
135
140
  check_file_or_directory_path(output_path, True)
136
141
 
137
142
  with FileOpen(input_param.get("npu_json_path"), "r") as npu_json, \
138
- FileOpen(input_param.get("bench_json_path"), "r") as bench_json, \
139
- FileOpen(input_param.get("stack_json_path"), "r") as stack_json:
140
- check_json_file(input_param, npu_json, bench_json, stack_json)
143
+ FileOpen(input_param.get("bench_json_path"), "r") as bench_json:
144
+ _check_json(npu_json, input_param.get("npu_json_path"))
145
+ _check_json(bench_json, input_param.get("bench_json_path"))
146
+ if stack_mode:
147
+ with FileOpen(input_param.get("stack_json_path"), "r") as stack_json:
148
+ _check_json(stack_json, input_param.get("stack_json_path"))
141
149
 
142
150
 
143
151
  def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True):
@@ -231,6 +239,8 @@ def md5_find(data):
231
239
  for data_detail in data[key_op][api_info]:
232
240
  if data_detail and 'md5' in data_detail:
233
241
  return True
242
+ if isinstance(data[key_op][api_info], bool):
243
+ continue
234
244
  elif data[key_op][api_info] and 'md5' in data[key_op][api_info]:
235
245
  return True
236
246
  return False
@@ -295,6 +305,9 @@ def get_dump_mode(input_param):
295
305
  if npu_task == Const.TENSOR:
296
306
  return Const.ALL
297
307
 
308
+ if npu_task == Const.STRUCTURE:
309
+ return Const.STRUCTURE
310
+
298
311
  if npu_task == Const.STATISTICS:
299
312
  npu_md5_compare = md5_find(npu_json_data['data'])
300
313
  bench_md5_compare = md5_find(bench_json_data['data'])
@@ -395,20 +408,23 @@ def get_real_step_or_rank(step_or_rank_input, obj):
395
408
  if not is_int(element) and not isinstance(element, str):
396
409
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
397
410
  f"{obj} element {element} must be an integer or string.")
398
- if isinstance(element, int) and element < 0:
399
- raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
400
- f"Each element of {obj} must be non-negative, currently it is {element}.")
401
- if isinstance(element, int) and Const.STEP_RANK_MINIMUM_VALUE <= element <= Const.STEP_RANK_MAXIMUM_VALUE:
411
+ if is_int(element):
412
+ if not Const.STEP_RANK_MINIMUM_VALUE <= element <= Const.STEP_RANK_MAXIMUM_VALUE:
413
+ raise MsprobeException(
414
+ MsprobeException.INVALID_PARAM_ERROR,
415
+ f"Each element of {obj} must be between {Const.STEP_RANK_MINIMUM_VALUE} and "
416
+ f"{Const.STEP_RANK_MAXIMUM_VALUE}, currently it is {element}."
417
+ )
402
418
  real_step_or_rank.append(element)
403
- elif isinstance(element, str) and Const.HYPHEN in element:
404
- continual_step_or_rank = get_step_or_rank_from_string(element, obj)
405
- real_step_or_rank.extend(continual_step_or_rank)
419
+ continue
420
+ continual_step_or_rank = get_step_or_rank_from_string(element, obj)
421
+ real_step_or_rank.extend(continual_step_or_rank)
406
422
  real_step_or_rank = list(set(real_step_or_rank))
407
423
  real_step_or_rank.sort()
408
424
  return real_step_or_rank
409
425
 
410
426
 
411
- def check_seed_all(seed, mode):
427
+ def check_seed_all(seed, mode, rm_dropout):
412
428
  if is_int(seed):
413
429
  if seed < 0 or seed > Const.MAX_SEED_VALUE:
414
430
  logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
@@ -419,6 +435,9 @@ def check_seed_all(seed, mode):
419
435
  if not isinstance(mode, bool):
420
436
  logger.error("seed_all mode must be bool.")
421
437
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
438
+ if not isinstance(rm_dropout, bool):
439
+ logger.error("The rm_dropout parameter must be bool.")
440
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
422
441
 
423
442
 
424
443
  def safe_get_value(container, index, container_name, key=None):
@@ -482,3 +501,12 @@ def check_str_param(param):
482
501
  if not re.match(Const.REGEX_PREFIX_PATTERN, param):
483
502
  logger.error('The parameter {} contains special characters.'.format(param))
484
503
  raise MsprobeBaseException(MsprobeBaseException.INVALID_CHAR_ERROR)
504
+
505
+
506
+ class DumpPathAggregation:
507
+ dump_file_path = None
508
+ stack_file_path = None
509
+ construct_file_path = None
510
+ dump_tensor_data_dir = None
511
+ free_benchmark_file_path = None
512
+ debug_file_path = None
@@ -27,6 +27,7 @@ class CommonConfig:
27
27
  self.step = get_real_step_or_rank(json_config.get('step'), Const.STEP)
28
28
  self.level = json_config.get('level')
29
29
  self.enable_dataloader = json_config.get('enable_dataloader', False)
30
+ self.async_dump = json_config.get("async_dump", False)
30
31
  self._check_config()
31
32
 
32
33
  def _check_config(self):
@@ -42,6 +43,11 @@ class CommonConfig:
42
43
  if not isinstance(self.enable_dataloader, bool):
43
44
  logger.error_log_with_exp("enable_dataloader is invalid, it should be a boolean",
44
45
  MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
46
+ if not isinstance(self.async_dump, bool):
47
+ logger.error_log_with_exp("async_dump is invalid, it should be a boolean",
48
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
49
+ elif self.async_dump:
50
+ logger.warning("async_dump is True, it may cause OOM when dumping large tensor.")
45
51
 
46
52
 
47
53
  class BaseConfig: