mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197) hide show
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +14 -19
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +155 -6
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +3 -0
  10. msprobe/core/common/utils.py +28 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +380 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/multiprocessing_compute.py +2 -2
  22. msprobe/core/compare/npy_compare.py +109 -147
  23. msprobe/core/compare/utils.py +189 -69
  24. msprobe/core/data_dump/data_collector.py +51 -21
  25. msprobe/core/data_dump/data_processor/base.py +38 -20
  26. msprobe/core/data_dump/data_processor/factory.py +5 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
  29. msprobe/core/data_dump/json_writer.py +29 -1
  30. msprobe/core/data_dump/scope.py +19 -18
  31. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  32. msprobe/core/overflow_check/checker.py +1 -1
  33. msprobe/core/overflow_check/utils.py +1 -1
  34. msprobe/docs/01.installation.md +96 -17
  35. msprobe/docs/02.config_introduction.md +5 -5
  36. msprobe/docs/05.data_dump_PyTorch.md +91 -61
  37. msprobe/docs/06.data_dump_MindSpore.md +57 -19
  38. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  39. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
  40. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  41. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  42. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  43. msprobe/docs/19.monitor.md +120 -27
  44. msprobe/docs/21.visualization_PyTorch.md +115 -35
  45. msprobe/docs/22.visualization_MindSpore.md +138 -41
  46. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  47. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  48. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  49. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  50. msprobe/docs/27.dump_json_instruction.md +521 -0
  51. msprobe/docs/FAQ.md +26 -2
  52. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  53. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  54. msprobe/docs/img/merge_result.png +0 -0
  55. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  56. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  57. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  58. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  59. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  60. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  61. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  63. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  64. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  65. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  66. msprobe/docs/visualization/GPTModel.png +0 -0
  67. msprobe/docs/visualization/ParallelMLP.png +0 -0
  68. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  69. msprobe/docs/visualization/mapping.png +0 -0
  70. msprobe/docs/visualization/mapping1.png +0 -0
  71. msprobe/docs/visualization/module_name.png +0 -0
  72. msprobe/docs/visualization/module_name1.png +0 -0
  73. msprobe/docs/visualization/no_mapping.png +0 -0
  74. msprobe/docs/visualization/no_mapping1.png +0 -0
  75. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  76. msprobe/docs/visualization/top_layer.png +0 -0
  77. msprobe/mindspore/__init__.py +10 -0
  78. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
  79. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  80. msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
  81. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  82. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  83. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  84. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  85. msprobe/mindspore/code_mapping/bind.py +264 -0
  86. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  87. msprobe/mindspore/code_mapping/graph.py +49 -0
  88. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  89. msprobe/mindspore/code_mapping/main.py +24 -0
  90. msprobe/mindspore/code_mapping/processor.py +34 -0
  91. msprobe/mindspore/common/const.py +3 -1
  92. msprobe/mindspore/common/utils.py +50 -5
  93. msprobe/mindspore/compare/distributed_compare.py +0 -2
  94. msprobe/mindspore/compare/ms_compare.py +105 -63
  95. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  96. msprobe/mindspore/debugger/debugger_config.py +3 -0
  97. msprobe/mindspore/debugger/precision_debugger.py +81 -12
  98. msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
  99. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  100. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  101. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  102. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  103. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  104. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  105. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  106. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  107. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  108. msprobe/mindspore/grad_probe/hook.py +13 -4
  109. msprobe/mindspore/mindtorch/__init__.py +18 -0
  110. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  111. msprobe/mindspore/ms_config.py +5 -1
  112. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  113. msprobe/mindspore/service.py +267 -101
  114. msprobe/msprobe.py +24 -3
  115. msprobe/pytorch/__init__.py +7 -6
  116. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  117. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  123. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  124. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
  125. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  126. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  127. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  128. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  129. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  130. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  131. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  132. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  133. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  134. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  135. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  136. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  140. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  141. msprobe/pytorch/common/parse_json.py +2 -1
  142. msprobe/pytorch/common/utils.py +45 -2
  143. msprobe/pytorch/compare/distributed_compare.py +17 -29
  144. msprobe/pytorch/compare/pt_compare.py +40 -20
  145. msprobe/pytorch/debugger/debugger_config.py +27 -12
  146. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  147. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  148. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  149. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
  150. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  151. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  152. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  153. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  154. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  155. msprobe/pytorch/hook_module/__init__.py +1 -1
  156. msprobe/pytorch/hook_module/hook_module.py +14 -11
  157. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  158. msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
  159. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  160. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  161. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  162. msprobe/pytorch/monitor/anomaly_detect.py +107 -22
  163. msprobe/pytorch/monitor/csv2tb.py +166 -0
  164. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  165. msprobe/pytorch/monitor/features.py +3 -3
  166. msprobe/pytorch/monitor/module_hook.py +483 -277
  167. msprobe/pytorch/monitor/module_metric.py +27 -48
  168. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  169. msprobe/pytorch/monitor/optimizer_collect.py +52 -14
  170. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  171. msprobe/pytorch/monitor/utils.py +77 -6
  172. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  173. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  174. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  175. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  176. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  177. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  178. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  179. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  180. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  181. msprobe/pytorch/service.py +176 -106
  182. msprobe/visualization/builder/graph_builder.py +62 -5
  183. msprobe/visualization/builder/msprobe_adapter.py +24 -2
  184. msprobe/visualization/compare/graph_comparator.py +64 -14
  185. msprobe/visualization/compare/mode_adapter.py +1 -15
  186. msprobe/visualization/graph/base_node.py +12 -17
  187. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  188. msprobe/visualization/graph/graph.py +9 -0
  189. msprobe/visualization/graph_service.py +97 -23
  190. msprobe/visualization/utils.py +14 -29
  191. msprobe/pytorch/functional/module_dump.py +0 -84
  192. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  193. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
  194. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
  195. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  196. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  197. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -0,0 +1,132 @@
1
+ # 模型分级可视化如何配置layer mapping映射文件
2
+
3
+ ## 1.使用场景
4
+ 同框架跨套件比对(例如PyTorch DeepSpeed vs Megatron),或者跨框架比对(例如PyTorch vs MindSpore),**由于代码实现的差异,一些模型层级和层级命名有所不同,无法直接进行匹配**,需要进行layer层名称映射,才能够比对。
5
+
6
+ ## 2.模块命名说明
7
+
8
+ 由于有些节点的名称比较长,例如Module.module.module.language_model.embedding.Embedding.forward.0,在图节点上由于字符串过长无法完整显示,forward或backward信息被省略,**因此节点中显示的名称字符串去掉了Module前缀,并将forward或backward信息提取到名称字符串的第二位展示**。
9
+
10
+ ![module_name.png](./module_name.png)
11
+
12
+ ![module_name1.png](./module_name1.png)
13
+
14
+ ### 2.1 命名格式
15
+
16
+ **{Module}.{module_name}.{class_name}.{forward/backward}.{调用次数}**
17
+
18
+ **layer mapping主要是针对module_name的映射**
19
+
20
+ #### 2.1.1 命名示例
21
+
22
+ - **Module.module.Float16Module.forward.0** -----> Module{**Module**}.module{**module_name**}.Float16Module{**class_name**}.forward.0{**调用次数**}
23
+ - **Module.module.module.GPTModel.forward.0** -----> Module{**Module**}.module.module{**module_name**}.GPTModel{**class_name**}.forward.0{**调用次数**}
24
+ - **Module.module.module.language_model.TransformerLanguageModel.forward.0** -----> Module{**Module**}.module.module.language_model{**module_name**}.TransformerLanguageModel{**class_name**}.forward.0{**调用次数**}
25
+ - **Module.module.module.language_model.embedding.Embedding.forward.0** -----> Module{**Module**}.module.module.language_model.embedding{**module_name**}.Embedding{**class_name**}.forward.0{**调用次数**}
26
+
27
+ 可以看到,module_name随着模型层级的深入在变长,**embedding层module_name拼接了它的上层language_model、上上层module和顶层module**。
28
+
29
+ ## 3.示例
30
+
31
+ 如图所示,左边为NPU模型,右边为GPU模型。由于代码实现上的差异,模型层级和层级命名有所不同,导致节点无法匹配,**图上节点显示为灰色,表示节点未匹配**。
32
+
33
+ ![no_mapping.png](./no_mapping.png)
34
+
35
+ ### 3.1 看图分析
36
+
37
+ 同一模型使用了不同套件或者框架时,虽然两个模型的层级关系和层级命名可能有所不同,但仍可以从图上的**节点名称**看出一些匹配关系。例如同为embedding层,代码里通常会命名为xxx_embedding,而不会命名为xxx_norm,体现在节点名称上也会带有embedding的信息,并且层级关系也大致相同。
38
+
39
+ ![no_mapping_analyze.png](./no_mapping_analyze.png)
40
+
41
+ 分析可知,节点匹配关系如下:
42
+
43
+ **注意,仅需关注module_name的差异**
44
+
45
+ | NPU节点名称 | GPU节点名称 | module_name差异 |
46
+ |-------------------|----------------------------------------------------------------|---------------------------|
47
+ | Module.module.Float16Module.forward.0 | Module.model.FloatModule.forward.0 | NPU为module,GPU为model |
48
+ | Module.module.module.GPTModel.forward.0 | Module.model.module.GPT2Model.forward.0 | NPU为module,GPU为module,无差异 |
49
+ | Module.module.module.language_model.TransformerLanguageModel.forward.0 | 无 | NPU多了一层 |
50
+ | Module.module.module.language_model.embedding.Embedding.forward.0 | Module.module.module.embedding.LanguageModelEmbedding.forward.0 | NPU为language_model.embedding,GPU为embedding |
51
+ | Module.module.module.language_model.rotary_pos_emb.RotaryEmbedding.forward.0 | Module.module.module.rotary_pos_emb.RotaryEmbedding.forward.0 | NPU为language_model.rotary_pos_emb,GPU为rotary_pos_emb |
52
+ | Module.module.module.language_model.encoder.ParallelTransformer.forward.0 | Module.module.module.decoder.TransformerBlock.forward.0 | NPU为language_model.encoder,GPU为decoder |
53
+ | Module.module.module.language_model.encoder.layers.0.ParallelTransformerLayer.forward.0 | Module.module.module.decoder.layers.0.TransformerLayer.forward.0 | 父层级有差异,本层级NPU和GPU都叫layers,无差异 |
54
+
55
+ ### 3.2 构建layer_mapping配置文件
56
+ 准备一个名为mapping.yaml的文件,建立**module_name**的映射关系。
57
+
58
+ #### 3.2.1 顶层模块映射
59
+ NPU和GPU侧的模块Module.module.Float16Module.forward.0和Module.model.FloatModule.forward.0处于图的顶层,需要进行如下配置:
60
+
61
+ ![top_layer.png](./top_layer.png)
62
+
63
+ ```yaml
64
+ TopLayer:
65
+ module: model
66
+ ```
67
+
68
+ #### 3.2.2 其他模块映射
69
+ 配置module下的子模块,虽然两边的class_name不同(NPU侧为GPTModel,GPU侧为GPT2Model),**但是仅需取NPU侧也就是左边图的class_name进行配置,无需关心右边图的class_name叫什么**。
70
+
71
+ **这里涉及到跨层级的配置,NPU多了一层language_model层**,将language_model作为embedding层、rotary_pos_emb层和encoder层的前缀,进行如下配置:
72
+
73
+ ![GPTModel.png](./GPTModel.png)
74
+
75
+ ```yaml
76
+ GPTModel:
77
+ language_model.embedding: embedding
78
+ language_model.rotary_pos_emb: rotary_pos_emb
79
+ language_model.encoder: decoder
80
+ ```
81
+ 然后看Module.module.module.language_model.encoder.ParallelTransformer.forward.0层下的子模块:
82
+
83
+ 此层下的若干个层,NPU和GPU的层名都叫layers,**当前层名称相同,则不用进行配置**。
84
+
85
+ ### 3.3 查看效果
86
+
87
+ 执行命令,指定-lm:
88
+ ```
89
+ msprobe -f pytorch graph -i ./compare.json -o ./output -lm ./mapping.yaml
90
+ ```
91
+
92
+ ```
93
+ msprobe -f mindspore graph -i ./compare.json -o ./output -lm ./mapping.yaml
94
+ ```
95
+ 可以看到,除了language_model层(NPU多的一层,GPU没有层与其匹配),其余在mapping.yaml文件配置的层均匹配上了。
96
+
97
+ ![mapping.png](./mapping.png)
98
+
99
+ ### 3.4 继续配置
100
+
101
+ 展开节点过程中,如果发现还有未匹配节点,则继续配置mapping.yaml
102
+
103
+ ![no_mapping1.png](./no_mapping1.png)
104
+
105
+ 按照前一节的过程进行分析和配置,分析可知,节点匹配关系如下:
106
+
107
+ | NPU节点名称 | GPU节点名称 | module_name差异 |
108
+ |-------------------|------------------------------------------------------------------|---------------------------------------------|
109
+ | Module.module.module.language_model.encoder.layers.0.mlp.dense_h_to_4h.ColumnParallelLinear.forward.0 | Module.module.module.decoder.layers.0.mlp.linear_fc1.TELayerNormColumnParallelLinear.forward.0 | NPU为dense_h_to_4h,GPU为linear_fc1 |
110
+ | Module.module.module.language_model.encoder.layers.0.mlp.dense_4h_to_h.RowParallelLinear.forward.0 | Module.module.module.decoder.layers.0.mlp.linear_fc2.TERowParallelLinear.forward.0 | NPU为dense_4h_to_h,GPU为linear_fc2 |
111
+
112
+ ![ParallelMLP.png](./ParallelMLP.png)
113
+
114
+ 追加mapping.yaml配置:
115
+
116
+ ```yaml
117
+ TopLayer:
118
+ module: model
119
+
120
+ GPTModel:
121
+ language_model.embedding: embedding
122
+ language_model.rotary_pos_emb: rotary_pos_emb
123
+ language_model.encoder: decoder
124
+
125
+ ParallelMLP:
126
+ dense_h_to_4h: linear_fc1
127
+ dense_4h_to_h: linear_fc2
128
+ ```
129
+
130
+ 执行命令,查看效果,可以看到节点已成功匹配上。
131
+
132
+ ![mapping1.png](./mapping1.png)
Binary file
Binary file
Binary file
@@ -13,5 +13,15 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import os
17
+
18
+ try:
19
+ from msprobe.lib import _msprobe_c
20
+ os.environ["MS_HOOK_ENABLE"] = "on"
21
+ os.environ["HOOK_TOOL_PATH"] = _msprobe_c.__file__
22
+ except ImportError:
23
+ from .common.log import logger
24
+ logger.info("Module _msprobe_c has not been installed. L2-Dump may not work normally.")
25
+
16
26
  from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
17
27
  from msprobe.mindspore.common.utils import seed_all
@@ -30,6 +30,7 @@ from msprobe.mindspore.common.log import logger
30
30
  cur_path = os.path.dirname(os.path.realpath(__file__))
31
31
  yaml_path = os.path.join(cur_path, MsCompareConst.SUPPORTED_API_LIST_FILE)
32
32
 
33
+
33
34
  class BasicInfoAndStatus:
34
35
  def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
35
36
  self.api_name = api_name
@@ -49,6 +50,13 @@ class ResultCsvEntry:
49
50
  self.overall_err_msg = None
50
51
 
51
52
 
53
+ class ProcessResultPacket:
54
+ def __init__(self, process_status, result, err_msg) -> None:
55
+ self.process_status = process_status
56
+ self.result = result
57
+ self.err_msg = err_msg
58
+
59
+
52
60
  class ApiAccuracyChecker:
53
61
  def __init__(self, args):
54
62
  self.api_infos = dict()
@@ -56,7 +64,7 @@ class ApiAccuracyChecker:
56
64
 
57
65
  @staticmethod
58
66
  def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
59
- '''
67
+ """
60
68
  Args:
61
69
  api_info: ApiInfo
62
70
  api_name_str: str
@@ -70,7 +78,7 @@ class ApiAccuracyChecker:
70
78
  get mindspore api output, run torch api and get output.
71
79
  compare output.
72
80
  record compare result.
73
- '''
81
+ """
74
82
  # get output
75
83
  if global_context.get_is_constructed():
76
84
  # constructed situation, need use constructed input to run mindspore api getting tested_output
@@ -104,8 +112,8 @@ class ApiAccuracyChecker:
104
112
  err_msg = ""
105
113
  else:
106
114
  status = CompareConst.ERROR
107
- err_msg = compare_result_dict.get(CompareConst.COSINE).err_msg + \
108
- compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg
115
+ err_msg = (compare_result_dict.get(CompareConst.COSINE).err_msg +
116
+ compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg)
109
117
  basic_info_status = \
110
118
  BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
111
119
  output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
@@ -113,13 +121,13 @@ class ApiAccuracyChecker:
113
121
 
114
122
  @staticmethod
115
123
  def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD):
116
- '''
124
+ """
117
125
  Args:
118
126
  api_info: ApiInfo
119
127
  forward_or_backward: str
120
128
  Returns:
121
129
  ApiInputAggregation
122
- '''
130
+ """
123
131
  forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
124
132
  kwargs = api_info.get_kwargs()
125
133
  if forward_or_backward == Const.FORWARD:
@@ -162,7 +170,8 @@ class ApiAccuracyChecker:
162
170
  is_constructed = task == MsCompareConst.STATISTICS_TASK
163
171
  if not is_constructed:
164
172
  dump_data_dir = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DUMP_DATA_DIR_FIELD,
165
- "dump_data_dir field in api_info.json", accepted_type=str)
173
+ "dump_data_dir field in api_info.json",
174
+ accepted_type=str)
166
175
  else:
167
176
  dump_data_dir = ""
168
177
  global_context.init(is_constructed, dump_data_dir)
@@ -188,45 +197,65 @@ class ApiAccuracyChecker:
188
197
  """处理前向检查"""
189
198
  if not api_info.check_forward_info():
190
199
  logger.debug(f"api: {api_name_str} is lack of forward information, skip forward check.")
191
- return Const.EXCEPTION_NONE
200
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.API_NOT_FOUND,
201
+ result=None,
202
+ err_msg=f"forward info of {api_name_str} is not found")
203
+ return process_result_packet
192
204
 
193
205
  try:
194
206
  forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
195
207
  except Exception as e:
196
208
  logger.warning(f"Exception occurs when getting inputs for {api_name_str} forward api. "
197
209
  f"Skipping forward check. Detailed exception information: {e}.")
198
- return Const.EXCEPTION_NONE
210
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
211
+ result=None, err_msg=f"{e}")
212
+ return process_result_packet
199
213
 
200
- forward_output_list = None
201
214
  try:
202
- forward_output_list = self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
215
+ forward_output_list = self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation,
216
+ Const.FORWARD)
203
217
  except Exception as e:
204
218
  logger.warning(f"Exception occurs when running and comparing {api_name_str} forward api. "
205
219
  f"Detailed exception information: {e}.")
206
- return forward_output_list
220
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
221
+ result=None, err_msg=f"{e}")
222
+ return process_result_packet
223
+
224
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS,
225
+ result=forward_output_list, err_msg="")
226
+ return process_result_packet
207
227
 
208
228
  def process_backward(self, api_name_str, api_info):
209
229
  """处理反向检查"""
210
230
  if not api_info.check_backward_info():
211
231
  logger.debug(f"api: {api_name_str} is lack of backward information, skipping backward check.")
212
- return Const.EXCEPTION_NONE
232
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.API_NOT_FOUND,
233
+ result=None,
234
+ err_msg=f"backward info of {api_name_str} is not found")
235
+ return process_result_packet
213
236
 
214
237
  try:
215
238
  backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
216
239
  except Exception as e:
217
240
  logger.warning(f"Exception occurs when getting inputs for {api_name_str} backward api. "
218
241
  f"Skipping backward check. Detailed exception information: {e}.")
219
- return Const.EXCEPTION_NONE
242
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
243
+ result=None, err_msg=f"{e}")
244
+ return process_result_packet
220
245
 
221
- backward_output_list = None
222
246
  try:
223
- backward_output_list = self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)
247
+ backward_output_list = self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation,
248
+ Const.BACKWARD)
224
249
  except Exception as e:
225
250
  logger.warning(f"Exception occurs when running and comparing {api_name_str} backward api. "
226
251
  f"Detailed exception information: {e}.")
227
- return backward_output_list
228
-
252
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
253
+ result=None, err_msg=f"{e}")
254
+ return process_result_packet
229
255
 
256
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS,
257
+ result=backward_output_list, err_msg="")
258
+ return process_result_packet
230
259
 
231
260
  def run_and_compare(self):
232
261
  for api_name_str, api_info in tqdm(self.api_infos.items()):
@@ -234,14 +263,17 @@ class ApiAccuracyChecker:
234
263
  continue
235
264
 
236
265
  # 处理前向
237
- forward_output_list = self.process_forward(api_name_str, api_info)
238
- if forward_output_list is not Const.EXCEPTION_NONE:
239
- self.data_manager.record(forward_output_list)
266
+ process_result_packet = self.process_forward(api_name_str, api_info)
267
+ if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS:
268
+ self.data_manager.record(process_result_packet.result)
269
+ elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
270
+ self.data_manager.record_exception_skip(api_name_str, Const.FORWARD, process_result_packet.err_msg)
240
271
 
241
272
  # 处理反向
242
- backward_output_list = self.process_backward(api_name_str, api_info)
243
- if backward_output_list is not Const.EXCEPTION_NONE:
244
- self.data_manager.record(backward_output_list)
273
+ process_result_packet = self.process_backward(api_name_str, api_info)
274
+ if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS:
275
+ self.data_manager.record(process_result_packet.result)
276
+ elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
277
+ self.data_manager.record_exception_skip(api_name_str, Const.BACKWARD, process_result_packet.err_msg)
245
278
 
246
279
  self.data_manager.save_results(api_name_str)
247
-
@@ -16,10 +16,10 @@
16
16
  import argparse
17
17
  import os
18
18
 
19
-
20
19
  from msprobe.core.common.file_utils import check_file_or_directory_path, create_directory
21
20
  from msprobe.core.common.utils import Const, MsprobeBaseException
22
21
 
22
+
23
23
  class UniqueDeviceAction(argparse.Action):
24
24
  def __call__(self, parser, namespace, values, option_string=None):
25
25
  unique_values = set(values)
@@ -40,6 +40,7 @@ def add_api_accuracy_checker_argument(parser):
40
40
  parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False,
41
41
  help="<optional> the exit csv for continue")
42
42
 
43
+
43
44
  def multi_add_api_accuracy_checker_argument(parser):
44
45
  parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", type=str, required=True,
45
46
  help="<Required> The api param tool result file: generate from api param tool, "
@@ -78,12 +78,10 @@ class ComputeElement:
78
78
  else:
79
79
  torch_dtype = dtype_str_to_torch_dtype.get(dtype_str)
80
80
 
81
- if dtype_str in float_dtype_str_list:
82
- middle_dtype = mindspore.float64
83
- elif dtype_str in int_dtype_str_list:
81
+ if dtype_str in int_dtype_str_list:
84
82
  middle_dtype = mindspore.int64
85
83
  else:
86
- middle_dtype = mindspore.uint64
84
+ middle_dtype = mindspore.float64
87
85
  np_ndarray = ms_tensor.astype(middle_dtype).numpy()
88
86
  torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
89
87
  return torch_tensor
@@ -106,10 +104,10 @@ class ComputeElement:
106
104
  else:
107
105
  ms_dtype = dtype_str_to_ms_dtype.get(dtype_str)
108
106
 
109
- if dtype_str in float_dtype_str_list:
110
- middle_dtype = torch.float64
111
- elif dtype_str in int_dtype_str_list:
107
+ if dtype_str in int_dtype_str_list:
112
108
  middle_dtype = torch.int64
109
+ else:
110
+ middle_dtype = torch.float64
113
111
  np_ndarray = torch_tensor.to(middle_dtype, copy=True).numpy()
114
112
  ms_tensor = mindspore.Tensor.from_numpy(np_ndarray).astype(ms_dtype)
115
113
  return ms_tensor
@@ -80,6 +80,7 @@ def check_csv_header(headers, required_constants, csv_path):
80
80
  class DataManager:
81
81
  def __init__(self, csv_dir, result_csv_path):
82
82
  self.results = {}
83
+ self.results_exception_skip = {}
83
84
  self.is_first_write = True # 标记用于添加表头
84
85
  self.csv_dir = csv_dir
85
86
  self.api_names_set = set() # 存储已经出现的 API 名称的集合
@@ -184,10 +185,21 @@ class DataManager:
184
185
  logger.debug(f"Updated self.results for key {key}: {self.results[key]}")
185
186
  logger.debug(f"Complete self.results after recording: {self.results}")
186
187
 
188
+ def record_exception_skip(self, api_name, forward_or_backward, err_msg):
189
+ '''
190
+ record exception_skip infomation into self.record_exception_skip.
191
+ self.record_exception_skip: dict{str: dict{"forward": str/None, "backward": str/None}}
192
+ string in key is api_name, string in value is err_msg
193
+ '''
194
+ if api_name not in self.results_exception_skip:
195
+ self.results_exception_skip[api_name] = {Const.FORWARD: None, Const.BACKWARD: None}
196
+ self.results_exception_skip[api_name][forward_or_backward] = err_msg
197
+
187
198
  def clear_results(self):
188
199
  """清空 self.results 数据"""
189
200
  logger.debug("Clearing self.results data.")
190
201
  self.results.clear()
202
+ self.results_exception_skip.clear()
191
203
 
192
204
  def to_detail_csv(self, csv_path):
193
205
  logger.debug("Preparing detail CSV headers and rows.")
@@ -218,6 +230,9 @@ class DataManager:
218
230
  logger.debug(f"Detail CSV written successfully to {csv_path}.")
219
231
 
220
232
  def to_result_csv(self, csv_path):
233
+ '''
234
+ depend on both self.results and self.results_exception_skip
235
+ '''
221
236
  logger.debug("Preparing result CSV data.")
222
237
  result_csv = []
223
238
 
@@ -254,8 +269,30 @@ class DataManager:
254
269
  entry.backward_pass_status,
255
270
  overall_err_msg
256
271
  ]
272
+ # change row if this api has excption_skip infomation
273
+ if api_name in self.results_exception_skip:
274
+ if self.results_exception_skip[api_name][Const.FORWARD] is not None:
275
+ row[1] = CompareConst.SKIP
276
+ row[-1] += self.results_exception_skip[api_name][Const.FORWARD]
277
+ if self.results_exception_skip[api_name][Const.BACKWARD] is not None:
278
+ row[2] = CompareConst.SKIP
279
+ row[-1] += self.results_exception_skip[api_name][Const.BACKWARD]
280
+ del self.results_exception_skip[api_name]
257
281
  result_csv.append(row)
258
282
  logger.debug(f"Result CSV row added: {row}")
283
+ for api_name in self.results_exception_skip:
284
+ current_exception_skip = self.results_exception_skip[api_name]
285
+ forward_status = None
286
+ backward_status = None
287
+ err_msg = ""
288
+ if current_exception_skip[Const.FORWARD] is not None:
289
+ forward_status = CompareConst.SKIP
290
+ err_msg += current_exception_skip[Const.FORWARD]
291
+ if current_exception_skip[Const.BACKWARD] is not None:
292
+ backward_status = CompareConst.SKIP
293
+ err_msg += current_exception_skip[Const.BACKWARD]
294
+ row = [api_name, forward_status, backward_status, err_msg]
295
+ result_csv.append(row)
259
296
 
260
297
  write_csv(result_csv, csv_path, mode="a+")
261
298
  logger.debug(f"Result CSV written successfully to {csv_path}.")
@@ -26,6 +26,7 @@ def api_checker_main(args):
26
26
  api_accuracy_checker.parse(args.api_info_file)
27
27
  api_accuracy_checker.run_and_compare()
28
28
 
29
+
29
30
  def mul_api_checker_main(args):
30
31
  check_args(args)
31
32
  api_accuracy_checker = MultiApiAccuracyChecker(args)
@@ -154,14 +154,16 @@ class MultiApiAccuracyChecker(ApiAccuracyChecker):
154
154
  """
155
155
  if not api_info.check_forward_info():
156
156
  logger.debug(
157
- f"[Device {self.current_device_id}] API: {api_name_str} lacks forward information, skipping forward check.")
157
+ f"[Device {self.current_device_id}] API: {api_name_str} lacks forward information, skipping "
158
+ f"forward check.")
158
159
  return Const.EXCEPTION_NONE
159
160
 
160
161
  try:
161
162
  forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
162
163
  except Exception as e:
163
164
  logger.warning(
164
- f"[Device {self.current_device_id}] Exception occurred while getting forward API inputs for {api_name_str}. Skipping forward check. Detailed exception information: {e}.")
165
+ f"[Device {self.current_device_id}] Exception occurred while getting forward API inputs for "
166
+ f"{api_name_str}. Skipping forward check. Detailed exception information: {e}.")
165
167
  return Const.EXCEPTION_NONE
166
168
 
167
169
  forward_output_list = None
@@ -170,7 +172,8 @@ class MultiApiAccuracyChecker(ApiAccuracyChecker):
170
172
  Const.FORWARD)
171
173
  except Exception as e:
172
174
  logger.warning(
173
- f"[Device {self.current_device_id}] Exception occurred while running and comparing {api_name_str} forward API. Detailed exception information: {e}.")
175
+ f"[Device {self.current_device_id}] Exception occurred while running and comparing {api_name_str} "
176
+ f"forward API. Detailed exception information: {e}.")
174
177
  return forward_output_list
175
178
 
176
179
  def process_backward(self, api_name_str, api_info):
@@ -186,14 +189,16 @@ class MultiApiAccuracyChecker(ApiAccuracyChecker):
186
189
  """
187
190
  if not api_info.check_backward_info():
188
191
  logger.debug(
189
- f"[Device {self.current_device_id}] API: {api_name_str} lacks backward information, skipping backward check.")
192
+ f"[Device {self.current_device_id}] API: {api_name_str} lacks backward information, skipping "
193
+ f"backward check.")
190
194
  return Const.EXCEPTION_NONE
191
195
 
192
196
  try:
193
197
  backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
194
198
  except Exception as e:
195
199
  logger.warning(
196
- f"[Device {self.current_device_id}] Exception occurred while getting backward API inputs for {api_name_str}. Skipping backward check. Detailed exception information: {e}.")
200
+ f"[Device {self.current_device_id}] Exception occurred while getting backward API inputs for "
201
+ f"{api_name_str}. Skipping backward check. Detailed exception information: {e}.")
197
202
  return Const.EXCEPTION_NONE
198
203
 
199
204
  backward_output_list = None
@@ -202,5 +207,6 @@ class MultiApiAccuracyChecker(ApiAccuracyChecker):
202
207
  Const.BACKWARD)
203
208
  except Exception as e:
204
209
  logger.warning(
205
- f"[Device {self.current_device_id}] Exception occurred while running and comparing {api_name_str} backward API. Detailed exception information: {e}.")
210
+ f"[Device {self.current_device_id}] Exception occurred while running and comparing {api_name_str} "
211
+ f"backward API. Detailed exception information: {e}.")
206
212
  return backward_output_list
@@ -17,7 +17,9 @@
17
17
  import multiprocessing
18
18
  import os
19
19
 
20
- from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager, ResultCsvEntry, write_csv_header, get_result_csv_header, get_detail_csv_header, check_csv_header
20
+ from msprobe.mindspore.api_accuracy_checker.data_manager import (DataManager, ResultCsvEntry, write_csv_header,
21
+ get_result_csv_header, get_detail_csv_header,
22
+ check_csv_header)
21
23
  from msprobe.mindspore.common.log import logger
22
24
 
23
25