mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
# 模型分级可视化如何配置layer mapping映射文件
|
|
2
|
+
|
|
3
|
+
## 1.使用场景
|
|
4
|
+
同框架跨套件比对(例如PyTorch DeepSpeed vs Megatron),或者跨框架比对(例如PyTorch vs MindSpore)时,**由于代码实现的差异,一些模型层级和层级命名有所不同,导致无法进行匹配**,需要进行layer层名称映射后才能比对。
|
|
5
|
+
|
|
6
|
+
## 2.模块命名说明
|
|
7
|
+
|
|
8
|
+
由于有些节点的名称比较长,例如Module.module.module.language_model.embedding.Embedding.forward.0,在图节点上由于字符串过长无法完整显示,forward或backward信息被省略,**因此节点中显示的名称字符串去掉了Module前缀,并将forward或backward信息提取到名称字符串的第二位展示**。
|
|
9
|
+
|
|
10
|
+

|
|
11
|
+
|
|
12
|
+

|
|
13
|
+
|
|
14
|
+
### 2.1 命名格式
|
|
15
|
+
|
|
16
|
+
**{Module}.{module_name}.{class_name}.{forward/backward}.{调用次数}**
|
|
17
|
+
|
|
18
|
+
**layer mapping主要是针对module_name的映射**
|
|
19
|
+
|
|
20
|
+
#### 2.1.1 命名示例
|
|
21
|
+
|
|
22
|
+
- **Module.module.Float16Module.forward.0** -----> Module{**Module**}.module{**module_name**}.Float16Module{**class_name**}.forward.0{**调用次数**}
|
|
23
|
+
- **Module.module.module.GPTModel.forward.0** -----> Module{**Module**}.module.module{**module_name**}.GPTModel{**class_name**}.forward.0{**调用次数**}
|
|
24
|
+
- **Module.module.module.language_model.TransformerLanguageModel.forward.0** -----> Module{**Module**}.module.module.language_model{**module_name**}.TransformerLanguageModel{**class_name**}.forward.0{**调用次数**}
|
|
25
|
+
- **Module.module.module.language_model.embedding.Embedding.forward.0** -----> Module{**Module**}.module.module.language_model.embedding{**module_name**}.Embedding{**class_name**}.forward.0{**调用次数**}
|
|
26
|
+
|
|
27
|
+
可以看到,module_name随着模型层级的深入在变长,**embedding层module_name拼接了它的上层language_model、上上层module和顶层module**。
|
|
28
|
+
|
|
29
|
+
## 3.示例
|
|
30
|
+
|
|
31
|
+
如图所示,左边为NPU模型,右边为GPU模型。由于代码实现上的差异,模型层级和层级命名有所不同,导致节点无法匹配,**图上节点显示为灰色,表示节点未匹配**。
|
|
32
|
+
|
|
33
|
+

|
|
34
|
+
|
|
35
|
+
### 3.1 看图分析
|
|
36
|
+
|
|
37
|
+
同一模型使用不同套件或框架时,虽然两个模型的层级关系和层级命名可能有所不同,但仍可以从图上的**节点名称**看出一些匹配关系。例如同为embedding层,代码里通常会命名为xxx_embedding,而不会命名为xxx_norm,体现在节点名称上也会带有embedding的信息,并且两侧的层级关系也大致相同。
|
|
38
|
+
|
|
39
|
+

|
|
40
|
+
|
|
41
|
+
分析可知,节点匹配关系如下:
|
|
42
|
+
|
|
43
|
+
**注意,仅需关注module_name的差异**
|
|
44
|
+
|
|
45
|
+
| NPU节点名称 | GPU节点名称 | module_name差异 |
|
|
46
|
+
|-------------------|----------------------------------------------------------------|---------------------------|
|
|
47
|
+
| Module.module.Float16Module.forward.0 | Module.model.FloatModule.forward.0 | NPU为module,GPU为model |
|
|
48
|
+
| Module.module.module.GPTModel.forward.0 | Module.model.module.GPT2Model.forward.0 | NPU为module,GPU为module,无差异 |
|
|
49
|
+
| Module.module.module.language_model.TransformerLanguageModel.forward.0 | 无 | NPU多了一层 |
|
|
50
|
+
| Module.module.module.language_model.embedding.Embedding.forward.0 | Module.module.module.embedding.LanguageModelEmbedding.forward.0 | NPU为language_model.embedding,GPU为embedding |
|
|
51
|
+
| Module.module.module.language_model.rotary_pos_emb.RotaryEmbedding.forward.0 | Module.module.module.rotary_pos_emb.RotaryEmbedding.forward.0 | NPU为language_model.rotary_pos_emb,GPU为rotary_pos_emb |
|
|
52
|
+
| Module.module.module.language_model.encoder.ParallelTransformer.forward.0 | Module.module.module.decoder.TransformerBlock.forward.0 | NPU为language_model.encoder,GPU为decoder |
|
|
53
|
+
| Module.module.module.language_model.encoder.layers.0.ParallelTransformerLayer.forward.0 | Module.module.module.decoder.layers.0.TransformerLayer.forward.0 | 父层级有差异,本层级NPU和GPU都叫layers,无差异 |
|
|
54
|
+
|
|
55
|
+
### 3.2 构建layer_mapping配置文件
|
|
56
|
+
准备一个名为mapping.yaml的文件,建立**module_name**的映射关系。
|
|
57
|
+
|
|
58
|
+
#### 3.2.1 顶层模块映射
|
|
59
|
+
NPU和GPU侧的模块Module.module.Float16Module.forward.0和Module.model.FloatModule.forward.0处于图的顶层,需要进行如下配置:
|
|
60
|
+
|
|
61
|
+

|
|
62
|
+
|
|
63
|
+
```yaml
|
|
64
|
+
TopLayer:
|
|
65
|
+
module: model
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
#### 3.2.2 其他模块映射
|
|
69
|
+
配置module下的子模块,虽然两边的class_name不同(NPU侧为GPTModel,GPU侧为GPT2Model),**但是仅需取NPU侧也就是左边图的class_name进行配置,无需关心右边图的class_name叫什么**。
|
|
70
|
+
|
|
71
|
+
**这里涉及到跨层级的配置,NPU多了一层language_model层**,将language_model作为embedding层、rotary_pos_emb层和encoder层的前缀,进行如下配置:
|
|
72
|
+
|
|
73
|
+

|
|
74
|
+
|
|
75
|
+
```yaml
|
|
76
|
+
GPTModel:
|
|
77
|
+
language_model.embedding: embedding
|
|
78
|
+
language_model.rotary_pos_emb: rotary_pos_emb
|
|
79
|
+
language_model.encoder: decoder
|
|
80
|
+
```
|
|
81
|
+
然后看Module.module.module.language_model.encoder.ParallelTransformer.forward.0层下的子模块:
|
|
82
|
+
|
|
83
|
+
此层下的若干个层,NPU和GPU的层名都叫layers,**当前层名称相同,则不用进行配置**。
|
|
84
|
+
|
|
85
|
+
### 3.3 查看效果
|
|
86
|
+
|
|
87
|
+
执行命令,通过-lm参数指定映射文件:
|
|
88
|
+
```
|
|
89
|
+
msprobe -f pytorch graph -i ./compare.json -o ./output -lm ./mapping.yaml
|
|
90
|
+
```
|
|
91
|
+
或
|
|
92
|
+
```
|
|
93
|
+
msprobe -f mindspore graph -i ./compare.json -o ./output -lm ./mapping.yaml
|
|
94
|
+
```
|
|
95
|
+
可以看到,除了language_model层(NPU侧多出的一层,GPU侧没有与之匹配的层),其余在mapping.yaml文件中配置的层均已匹配上。
|
|
96
|
+
|
|
97
|
+

|
|
98
|
+
|
|
99
|
+
### 3.4 继续配置
|
|
100
|
+
|
|
101
|
+
展开节点过程中,如果发现还有未匹配的节点,则继续配置mapping.yaml。
|
|
102
|
+
|
|
103
|
+

|
|
104
|
+
|
|
105
|
+
按前一节的过程进行分析配置,可知节点匹配关系如下:
|
|
106
|
+
|
|
107
|
+
| NPU节点名称 | GPU节点名称 | module_name差异 |
|
|
108
|
+
|-------------------|------------------------------------------------------------------|---------------------------------------------|
|
|
109
|
+
| Module.module.module.language_model.encoder.layers.0.mlp.dense_h_to_4h.ColumnParallelLinear.forward.0 | Module.module.module.decoder.layers.0.mlp.linear_fc1.TELayerNormColumnParallelLinear.forward.0 | NPU为dense_h_to_4h,GPU为linear_fc1 |
|
|
110
|
+
| Module.module.module.language_model.encoder.layers.0.mlp.dense_4h_to_h.RowParallelLinear.forward.0 | Module.module.module.decoder.layers.0.mlp.linear_fc2.TERowParallelLinear.forward.0 | NPU为dense_4h_to_h,GPU为linear_fc2 |
|
|
111
|
+
|
|
112
|
+

|
|
113
|
+
|
|
114
|
+
在mapping.yaml中追加配置,追加后的完整文件如下:
|
|
115
|
+
|
|
116
|
+
```yaml
|
|
117
|
+
TopLayer:
|
|
118
|
+
module: model
|
|
119
|
+
|
|
120
|
+
GPTModel:
|
|
121
|
+
language_model.embedding: embedding
|
|
122
|
+
language_model.rotary_pos_emb: rotary_pos_emb
|
|
123
|
+
language_model.encoder: decoder
|
|
124
|
+
|
|
125
|
+
ParallelMLP:
|
|
126
|
+
dense_h_to_4h: linear_fc1
|
|
127
|
+
dense_4h_to_h: linear_fc2
|
|
128
|
+
```
|
|
129
|
+
|
|
130
|
+
执行命令,查看效果,可以看到节点已成功匹配上。
|
|
131
|
+
|
|
132
|
+

|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
msprobe/mindspore/__init__.py
CHANGED
|
@@ -13,5 +13,15 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
from msprobe.lib import _msprobe_c
|
|
20
|
+
os.environ["MS_HOOK_ENABLE"] = "on"
|
|
21
|
+
os.environ["HOOK_TOOL_PATH"] = _msprobe_c.__file__
|
|
22
|
+
except ImportError:
|
|
23
|
+
from .common.log import logger
|
|
24
|
+
logger.info("Module _msprobe_c has not been installed. L2-Dump may not work normally.")
|
|
25
|
+
|
|
16
26
|
from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
|
|
17
27
|
from msprobe.mindspore.common.utils import seed_all
|
|
@@ -30,6 +30,7 @@ from msprobe.mindspore.common.log import logger
|
|
|
30
30
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
31
31
|
yaml_path = os.path.join(cur_path, MsCompareConst.SUPPORTED_API_LIST_FILE)
|
|
32
32
|
|
|
33
|
+
|
|
33
34
|
class BasicInfoAndStatus:
|
|
34
35
|
def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
|
|
35
36
|
self.api_name = api_name
|
|
@@ -49,6 +50,13 @@ class ResultCsvEntry:
|
|
|
49
50
|
self.overall_err_msg = None
|
|
50
51
|
|
|
51
52
|
|
|
53
|
+
class ProcessResultPacket:
|
|
54
|
+
def __init__(self, process_status, result, err_msg) -> None:
|
|
55
|
+
self.process_status = process_status
|
|
56
|
+
self.result = result
|
|
57
|
+
self.err_msg = err_msg
|
|
58
|
+
|
|
59
|
+
|
|
52
60
|
class ApiAccuracyChecker:
|
|
53
61
|
def __init__(self, args):
|
|
54
62
|
self.api_infos = dict()
|
|
@@ -56,7 +64,7 @@ class ApiAccuracyChecker:
|
|
|
56
64
|
|
|
57
65
|
@staticmethod
|
|
58
66
|
def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
|
|
59
|
-
|
|
67
|
+
"""
|
|
60
68
|
Args:
|
|
61
69
|
api_info: ApiInfo
|
|
62
70
|
api_name_str: str
|
|
@@ -70,7 +78,7 @@ class ApiAccuracyChecker:
|
|
|
70
78
|
get mindspore api output, run torch api and get output.
|
|
71
79
|
compare output.
|
|
72
80
|
record compare result.
|
|
73
|
-
|
|
81
|
+
"""
|
|
74
82
|
# get output
|
|
75
83
|
if global_context.get_is_constructed():
|
|
76
84
|
# constructed situation, need use constructed input to run mindspore api getting tested_output
|
|
@@ -104,8 +112,8 @@ class ApiAccuracyChecker:
|
|
|
104
112
|
err_msg = ""
|
|
105
113
|
else:
|
|
106
114
|
status = CompareConst.ERROR
|
|
107
|
-
err_msg = compare_result_dict.get(CompareConst.COSINE).err_msg +
|
|
108
|
-
|
|
115
|
+
err_msg = (compare_result_dict.get(CompareConst.COSINE).err_msg +
|
|
116
|
+
compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg)
|
|
109
117
|
basic_info_status = \
|
|
110
118
|
BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
|
|
111
119
|
output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
|
|
@@ -113,13 +121,13 @@ class ApiAccuracyChecker:
|
|
|
113
121
|
|
|
114
122
|
@staticmethod
|
|
115
123
|
def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD):
|
|
116
|
-
|
|
124
|
+
"""
|
|
117
125
|
Args:
|
|
118
126
|
api_info: ApiInfo
|
|
119
127
|
forward_or_backward: str
|
|
120
128
|
Returns:
|
|
121
129
|
ApiInputAggregation
|
|
122
|
-
|
|
130
|
+
"""
|
|
123
131
|
forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
|
|
124
132
|
kwargs = api_info.get_kwargs()
|
|
125
133
|
if forward_or_backward == Const.FORWARD:
|
|
@@ -162,7 +170,8 @@ class ApiAccuracyChecker:
|
|
|
162
170
|
is_constructed = task == MsCompareConst.STATISTICS_TASK
|
|
163
171
|
if not is_constructed:
|
|
164
172
|
dump_data_dir = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DUMP_DATA_DIR_FIELD,
|
|
165
|
-
"dump_data_dir field in api_info.json",
|
|
173
|
+
"dump_data_dir field in api_info.json",
|
|
174
|
+
accepted_type=str)
|
|
166
175
|
else:
|
|
167
176
|
dump_data_dir = ""
|
|
168
177
|
global_context.init(is_constructed, dump_data_dir)
|
|
@@ -188,45 +197,65 @@ class ApiAccuracyChecker:
|
|
|
188
197
|
"""处理前向检查"""
|
|
189
198
|
if not api_info.check_forward_info():
|
|
190
199
|
logger.debug(f"api: {api_name_str} is lack of forward information, skip forward check.")
|
|
191
|
-
|
|
200
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.API_NOT_FOUND,
|
|
201
|
+
result=None,
|
|
202
|
+
err_msg=f"forward info of {api_name_str} is not found")
|
|
203
|
+
return process_result_packet
|
|
192
204
|
|
|
193
205
|
try:
|
|
194
206
|
forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
|
|
195
207
|
except Exception as e:
|
|
196
208
|
logger.warning(f"Exception occurs when getting inputs for {api_name_str} forward api. "
|
|
197
209
|
f"Skipping forward check. Detailed exception information: {e}.")
|
|
198
|
-
|
|
210
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
|
|
211
|
+
result=None, err_msg=f"{e}")
|
|
212
|
+
return process_result_packet
|
|
199
213
|
|
|
200
|
-
forward_output_list = None
|
|
201
214
|
try:
|
|
202
|
-
forward_output_list = self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation,
|
|
215
|
+
forward_output_list = self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation,
|
|
216
|
+
Const.FORWARD)
|
|
203
217
|
except Exception as e:
|
|
204
218
|
logger.warning(f"Exception occurs when running and comparing {api_name_str} forward api. "
|
|
205
219
|
f"Detailed exception information: {e}.")
|
|
206
|
-
|
|
220
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
|
|
221
|
+
result=None, err_msg=f"{e}")
|
|
222
|
+
return process_result_packet
|
|
223
|
+
|
|
224
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS,
|
|
225
|
+
result=forward_output_list, err_msg="")
|
|
226
|
+
return process_result_packet
|
|
207
227
|
|
|
208
228
|
def process_backward(self, api_name_str, api_info):
|
|
209
229
|
"""处理反向检查"""
|
|
210
230
|
if not api_info.check_backward_info():
|
|
211
231
|
logger.debug(f"api: {api_name_str} is lack of backward information, skipping backward check.")
|
|
212
|
-
|
|
232
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.API_NOT_FOUND,
|
|
233
|
+
result=None,
|
|
234
|
+
err_msg=f"backward info of {api_name_str} is not found")
|
|
235
|
+
return process_result_packet
|
|
213
236
|
|
|
214
237
|
try:
|
|
215
238
|
backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
|
|
216
239
|
except Exception as e:
|
|
217
240
|
logger.warning(f"Exception occurs when getting inputs for {api_name_str} backward api. "
|
|
218
241
|
f"Skipping backward check. Detailed exception information: {e}.")
|
|
219
|
-
|
|
242
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
|
|
243
|
+
result=None, err_msg=f"{e}")
|
|
244
|
+
return process_result_packet
|
|
220
245
|
|
|
221
|
-
backward_output_list = None
|
|
222
246
|
try:
|
|
223
|
-
backward_output_list = self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation,
|
|
247
|
+
backward_output_list = self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation,
|
|
248
|
+
Const.BACKWARD)
|
|
224
249
|
except Exception as e:
|
|
225
250
|
logger.warning(f"Exception occurs when running and comparing {api_name_str} backward api. "
|
|
226
251
|
f"Detailed exception information: {e}.")
|
|
227
|
-
|
|
228
|
-
|
|
252
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
|
|
253
|
+
result=None, err_msg=f"{e}")
|
|
254
|
+
return process_result_packet
|
|
229
255
|
|
|
256
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS,
|
|
257
|
+
result=backward_output_list, err_msg="")
|
|
258
|
+
return process_result_packet
|
|
230
259
|
|
|
231
260
|
def run_and_compare(self):
|
|
232
261
|
for api_name_str, api_info in tqdm(self.api_infos.items()):
|
|
@@ -234,14 +263,17 @@ class ApiAccuracyChecker:
|
|
|
234
263
|
continue
|
|
235
264
|
|
|
236
265
|
# 处理前向
|
|
237
|
-
|
|
238
|
-
if
|
|
239
|
-
self.data_manager.record(
|
|
266
|
+
process_result_packet = self.process_forward(api_name_str, api_info)
|
|
267
|
+
if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS:
|
|
268
|
+
self.data_manager.record(process_result_packet.result)
|
|
269
|
+
elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
|
|
270
|
+
self.data_manager.record_exception_skip(api_name_str, Const.FORWARD, process_result_packet.err_msg)
|
|
240
271
|
|
|
241
272
|
# 处理反向
|
|
242
|
-
|
|
243
|
-
if
|
|
244
|
-
self.data_manager.record(
|
|
273
|
+
process_result_packet = self.process_backward(api_name_str, api_info)
|
|
274
|
+
if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS:
|
|
275
|
+
self.data_manager.record(process_result_packet.result)
|
|
276
|
+
elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
|
|
277
|
+
self.data_manager.record_exception_skip(api_name_str, Const.BACKWARD, process_result_packet.err_msg)
|
|
245
278
|
|
|
246
279
|
self.data_manager.save_results(api_name_str)
|
|
247
|
-
|
|
@@ -16,10 +16,10 @@
|
|
|
16
16
|
import argparse
|
|
17
17
|
import os
|
|
18
18
|
|
|
19
|
-
|
|
20
19
|
from msprobe.core.common.file_utils import check_file_or_directory_path, create_directory
|
|
21
20
|
from msprobe.core.common.utils import Const, MsprobeBaseException
|
|
22
21
|
|
|
22
|
+
|
|
23
23
|
class UniqueDeviceAction(argparse.Action):
|
|
24
24
|
def __call__(self, parser, namespace, values, option_string=None):
|
|
25
25
|
unique_values = set(values)
|
|
@@ -40,6 +40,7 @@ def add_api_accuracy_checker_argument(parser):
|
|
|
40
40
|
parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, required=False,
|
|
41
41
|
help="<optional> the exit csv for continue")
|
|
42
42
|
|
|
43
|
+
|
|
43
44
|
def multi_add_api_accuracy_checker_argument(parser):
|
|
44
45
|
parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", type=str, required=True,
|
|
45
46
|
help="<Required> The api param tool result file: generate from api param tool, "
|
|
@@ -78,12 +78,10 @@ class ComputeElement:
|
|
|
78
78
|
else:
|
|
79
79
|
torch_dtype = dtype_str_to_torch_dtype.get(dtype_str)
|
|
80
80
|
|
|
81
|
-
if dtype_str in
|
|
82
|
-
middle_dtype = mindspore.float64
|
|
83
|
-
elif dtype_str in int_dtype_str_list:
|
|
81
|
+
if dtype_str in int_dtype_str_list:
|
|
84
82
|
middle_dtype = mindspore.int64
|
|
85
83
|
else:
|
|
86
|
-
middle_dtype = mindspore.
|
|
84
|
+
middle_dtype = mindspore.float64
|
|
87
85
|
np_ndarray = ms_tensor.astype(middle_dtype).numpy()
|
|
88
86
|
torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
|
|
89
87
|
return torch_tensor
|
|
@@ -106,10 +104,10 @@ class ComputeElement:
|
|
|
106
104
|
else:
|
|
107
105
|
ms_dtype = dtype_str_to_ms_dtype.get(dtype_str)
|
|
108
106
|
|
|
109
|
-
if dtype_str in
|
|
110
|
-
middle_dtype = torch.float64
|
|
111
|
-
elif dtype_str in int_dtype_str_list:
|
|
107
|
+
if dtype_str in int_dtype_str_list:
|
|
112
108
|
middle_dtype = torch.int64
|
|
109
|
+
else:
|
|
110
|
+
middle_dtype = torch.float64
|
|
113
111
|
np_ndarray = torch_tensor.to(middle_dtype, copy=True).numpy()
|
|
114
112
|
ms_tensor = mindspore.Tensor.from_numpy(np_ndarray).astype(ms_dtype)
|
|
115
113
|
return ms_tensor
|
|
@@ -80,6 +80,7 @@ def check_csv_header(headers, required_constants, csv_path):
|
|
|
80
80
|
class DataManager:
|
|
81
81
|
def __init__(self, csv_dir, result_csv_path):
|
|
82
82
|
self.results = {}
|
|
83
|
+
self.results_exception_skip = {}
|
|
83
84
|
self.is_first_write = True # 标记用于添加表头
|
|
84
85
|
self.csv_dir = csv_dir
|
|
85
86
|
self.api_names_set = set() # 存储已经出现的 API 名称的集合
|
|
@@ -184,10 +185,21 @@ class DataManager:
|
|
|
184
185
|
logger.debug(f"Updated self.results for key {key}: {self.results[key]}")
|
|
185
186
|
logger.debug(f"Complete self.results after recording: {self.results}")
|
|
186
187
|
|
|
188
|
+
def record_exception_skip(self, api_name, forward_or_backward, err_msg):
|
|
189
|
+
'''
|
|
190
|
+
record exception_skip information into self.results_exception_skip.
|
|
191
|
+
self.results_exception_skip: dict{str: dict{"forward": str/None, "backward": str/None}}
|
|
192
|
+
string in key is api_name, string in value is err_msg
|
|
193
|
+
'''
|
|
194
|
+
if api_name not in self.results_exception_skip:
|
|
195
|
+
self.results_exception_skip[api_name] = {Const.FORWARD: None, Const.BACKWARD: None}
|
|
196
|
+
self.results_exception_skip[api_name][forward_or_backward] = err_msg
|
|
197
|
+
|
|
187
198
|
def clear_results(self):
|
|
188
199
|
"""清空 self.results 数据"""
|
|
189
200
|
logger.debug("Clearing self.results data.")
|
|
190
201
|
self.results.clear()
|
|
202
|
+
self.results_exception_skip.clear()
|
|
191
203
|
|
|
192
204
|
def to_detail_csv(self, csv_path):
|
|
193
205
|
logger.debug("Preparing detail CSV headers and rows.")
|
|
@@ -218,6 +230,9 @@ class DataManager:
|
|
|
218
230
|
logger.debug(f"Detail CSV written successfully to {csv_path}.")
|
|
219
231
|
|
|
220
232
|
def to_result_csv(self, csv_path):
|
|
233
|
+
'''
|
|
234
|
+
depend on both self.results and self.results_exception_skip
|
|
235
|
+
'''
|
|
221
236
|
logger.debug("Preparing result CSV data.")
|
|
222
237
|
result_csv = []
|
|
223
238
|
|
|
@@ -254,8 +269,30 @@ class DataManager:
|
|
|
254
269
|
entry.backward_pass_status,
|
|
255
270
|
overall_err_msg
|
|
256
271
|
]
|
|
272
|
+
# change row if this api has excption_skip infomation
|
|
273
|
+
if api_name in self.results_exception_skip:
|
|
274
|
+
if self.results_exception_skip[api_name][Const.FORWARD] is not None:
|
|
275
|
+
row[1] = CompareConst.SKIP
|
|
276
|
+
row[-1] += self.results_exception_skip[api_name][Const.FORWARD]
|
|
277
|
+
if self.results_exception_skip[api_name][Const.BACKWARD] is not None:
|
|
278
|
+
row[2] = CompareConst.SKIP
|
|
279
|
+
row[-1] += self.results_exception_skip[api_name][Const.BACKWARD]
|
|
280
|
+
del self.results_exception_skip[api_name]
|
|
257
281
|
result_csv.append(row)
|
|
258
282
|
logger.debug(f"Result CSV row added: {row}")
|
|
283
|
+
for api_name in self.results_exception_skip:
|
|
284
|
+
current_exception_skip = self.results_exception_skip[api_name]
|
|
285
|
+
forward_status = None
|
|
286
|
+
backward_status = None
|
|
287
|
+
err_msg = ""
|
|
288
|
+
if current_exception_skip[Const.FORWARD] is not None:
|
|
289
|
+
forward_status = CompareConst.SKIP
|
|
290
|
+
err_msg += current_exception_skip[Const.FORWARD]
|
|
291
|
+
if current_exception_skip[Const.BACKWARD] is not None:
|
|
292
|
+
backward_status = CompareConst.SKIP
|
|
293
|
+
err_msg += current_exception_skip[Const.BACKWARD]
|
|
294
|
+
row = [api_name, forward_status, backward_status, err_msg]
|
|
295
|
+
result_csv.append(row)
|
|
259
296
|
|
|
260
297
|
write_csv(result_csv, csv_path, mode="a+")
|
|
261
298
|
logger.debug(f"Result CSV written successfully to {csv_path}.")
|
|
@@ -154,14 +154,16 @@ class MultiApiAccuracyChecker(ApiAccuracyChecker):
|
|
|
154
154
|
"""
|
|
155
155
|
if not api_info.check_forward_info():
|
|
156
156
|
logger.debug(
|
|
157
|
-
f"[Device {self.current_device_id}] API: {api_name_str} lacks forward information, skipping
|
|
157
|
+
f"[Device {self.current_device_id}] API: {api_name_str} lacks forward information, skipping "
|
|
158
|
+
f"forward check.")
|
|
158
159
|
return Const.EXCEPTION_NONE
|
|
159
160
|
|
|
160
161
|
try:
|
|
161
162
|
forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
|
|
162
163
|
except Exception as e:
|
|
163
164
|
logger.warning(
|
|
164
|
-
f"[Device {self.current_device_id}] Exception occurred while getting forward API inputs for
|
|
165
|
+
f"[Device {self.current_device_id}] Exception occurred while getting forward API inputs for "
|
|
166
|
+
f"{api_name_str}. Skipping forward check. Detailed exception information: {e}.")
|
|
165
167
|
return Const.EXCEPTION_NONE
|
|
166
168
|
|
|
167
169
|
forward_output_list = None
|
|
@@ -170,7 +172,8 @@ class MultiApiAccuracyChecker(ApiAccuracyChecker):
|
|
|
170
172
|
Const.FORWARD)
|
|
171
173
|
except Exception as e:
|
|
172
174
|
logger.warning(
|
|
173
|
-
f"[Device {self.current_device_id}] Exception occurred while running and comparing {api_name_str}
|
|
175
|
+
f"[Device {self.current_device_id}] Exception occurred while running and comparing {api_name_str} "
|
|
176
|
+
f"forward API. Detailed exception information: {e}.")
|
|
174
177
|
return forward_output_list
|
|
175
178
|
|
|
176
179
|
def process_backward(self, api_name_str, api_info):
|
|
@@ -186,14 +189,16 @@ class MultiApiAccuracyChecker(ApiAccuracyChecker):
|
|
|
186
189
|
"""
|
|
187
190
|
if not api_info.check_backward_info():
|
|
188
191
|
logger.debug(
|
|
189
|
-
f"[Device {self.current_device_id}] API: {api_name_str} lacks backward information, skipping
|
|
192
|
+
f"[Device {self.current_device_id}] API: {api_name_str} lacks backward information, skipping "
|
|
193
|
+
f"backward check.")
|
|
190
194
|
return Const.EXCEPTION_NONE
|
|
191
195
|
|
|
192
196
|
try:
|
|
193
197
|
backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
|
|
194
198
|
except Exception as e:
|
|
195
199
|
logger.warning(
|
|
196
|
-
f"[Device {self.current_device_id}] Exception occurred while getting backward API inputs for
|
|
200
|
+
f"[Device {self.current_device_id}] Exception occurred while getting backward API inputs for "
|
|
201
|
+
f"{api_name_str}. Skipping backward check. Detailed exception information: {e}.")
|
|
197
202
|
return Const.EXCEPTION_NONE
|
|
198
203
|
|
|
199
204
|
backward_output_list = None
|
|
@@ -202,5 +207,6 @@ class MultiApiAccuracyChecker(ApiAccuracyChecker):
|
|
|
202
207
|
Const.BACKWARD)
|
|
203
208
|
except Exception as e:
|
|
204
209
|
logger.warning(
|
|
205
|
-
f"[Device {self.current_device_id}] Exception occurred while running and comparing {api_name_str}
|
|
210
|
+
f"[Device {self.current_device_id}] Exception occurred while running and comparing {api_name_str} "
|
|
211
|
+
f"backward API. Detailed exception information: {e}.")
|
|
206
212
|
return backward_output_list
|
|
@@ -17,7 +17,9 @@
|
|
|
17
17
|
import multiprocessing
|
|
18
18
|
import os
|
|
19
19
|
|
|
20
|
-
from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager, ResultCsvEntry, write_csv_header,
|
|
20
|
+
from msprobe.mindspore.api_accuracy_checker.data_manager import (DataManager, ResultCsvEntry, write_csv_header,
|
|
21
|
+
get_result_csv_header, get_detail_csv_header,
|
|
22
|
+
check_csv_header)
|
|
21
23
|
from msprobe.mindspore.common.log import logger
|
|
22
24
|
|
|
23
25
|
|