mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
- mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +16 -21
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +185 -11
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +4 -0
- msprobe/core/common/utils.py +42 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +381 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +199 -69
- msprobe/core/data_dump/data_collector.py +100 -25
- msprobe/core/data_dump/data_processor/base.py +130 -28
- msprobe/core/data_dump/data_processor/factory.py +8 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
- msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
- msprobe/core/data_dump/json_writer.py +54 -8
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +121 -17
- msprobe/docs/02.config_introduction.md +18 -16
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +107 -58
- msprobe/docs/06.data_dump_MindSpore.md +95 -34
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +310 -220
- msprobe/docs/21.visualization_PyTorch.md +125 -35
- msprobe/docs/22.visualization_MindSpore.md +149 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +525 -0
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +11 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +68 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +28 -2
- msprobe/mindspore/debugger/precision_debugger.py +100 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +13 -3
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +347 -107
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -7
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +116 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +42 -17
- msprobe/pytorch/debugger/precision_debugger.py +56 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +98 -28
- msprobe/pytorch/monitor/csv2tb.py +164 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +543 -318
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +76 -56
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +84 -48
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +264 -115
- msprobe/visualization/builder/graph_builder.py +93 -10
- msprobe/visualization/builder/msprobe_adapter.py +30 -6
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +15 -19
- msprobe/visualization/graph/distributed_analyzer.py +395 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +100 -27
- msprobe/visualization/utils.py +24 -31
- mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
- msprobe/pytorch/functional/module_dump.py +0 -84
- msprobe/pytorch/module_processer.py +0 -150
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# MindSpore 场景的 kernel dump 说明
|
|
2
|
+
|
|
3
|
+
当使用 msprobe 数据采集功能时,level 配置为 "L2" 表示采集 kernel 层级的算子数据,仅支持昇腾 NPU 平台。
|
|
4
|
+
|
|
5
|
+
本文主要介绍 kernel dump 的配置示例和采集结果介绍, msprobe 数据采集功能的详细使用参考 《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》。
|
|
6
|
+
|
|
7
|
+
## 1 kernel dump 配置示例
|
|
8
|
+
|
|
9
|
+
使用 kernel dump 时,list 必须要填一个 API 名称,kernel dump 目前每个 step 只支持采集一个 API 的数据。
|
|
10
|
+
API 名称填写参考 L1 dump 结果文件 dump.json 中的API名称,命名格式为:`{api_type}.{api_name}.{API调用次数}.{forward/backward}`。
|
|
11
|
+
|
|
12
|
+
```json
|
|
13
|
+
{
|
|
14
|
+
"task": "tensor",
|
|
15
|
+
"dump_path": "/home/data_dump",
|
|
16
|
+
"level": "L2",
|
|
17
|
+
"rank": [],
|
|
18
|
+
"step": [],
|
|
19
|
+
"tensor": {
|
|
20
|
+
"scope": [],
|
|
21
|
+
"list": ["Functional.linear.0.backward"]
|
|
22
|
+
}
|
|
23
|
+
}
|
|
24
|
+
```
|
|
25
|
+
|
|
26
|
+
## 2 结果文件介绍
|
|
27
|
+
|
|
28
|
+
### 2.1 采集结果说明
|
|
29
|
+
|
|
30
|
+
如果 API kernel 级数据采集成功,会打印以下信息:
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
The kernel data of {api_name} is dumped successfully.
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
注意:如果打印该信息后,没有数据生成,参考**常见问题3.1**进行排查。
|
|
37
|
+
|
|
38
|
+
如果 kernel dump 遇到不支持的 API, 会打印以下信息:
|
|
39
|
+
|
|
40
|
+
```bash
|
|
41
|
+
The kernel dump does not support the {api_name} API.
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
其中 {api_name} 是对应的 API 名称。
|
|
45
|
+
|
|
46
|
+
### 2.2 输出文件说明
|
|
47
|
+
kernel dump 采集成功后,会在指定的 dump_path 目录下生成如下文件:
|
|
48
|
+
|
|
49
|
+
```
|
|
50
|
+
├── /home/data_dump/
|
|
51
|
+
│ ├── step0
|
|
52
|
+
│ │ ├── 20241201103000 # 日期时间格式,表示2024-12-01 10:30:00
|
|
53
|
+
│ │ │ ├── 0 # 表示 device id
|
|
54
|
+
│ │ │ │ ├──{op_type}.{op_name}.{task_id}.{stream_id}.{timestamp} # kernel 层算子数据
|
|
55
|
+
│ │ │ ...
|
|
56
|
+
│ │ ├── kernel_config_{device_id}.json # kernel dump 在接口调用过程中生成的中间文件,一般情况下无需关注
|
|
57
|
+
│ │ ...
|
|
58
|
+
│ ├── step1
|
|
59
|
+
│ ...
|
|
60
|
+
```
|
|
61
|
+
成功采集到数据后,可以使用 msprobe 工具提供的《[PyTorch 场景的数据解析](./14.data_parse_PyTorch.md)》功能分析数据。
|
|
62
|
+
|
|
63
|
+
## 3 常见问题
|
|
64
|
+
|
|
65
|
+
### 3.1 采集结果文件为空,有可能是什么原因?
|
|
66
|
+
|
|
67
|
+
1. 首先需要确认工具使用方式、配置文件内容、list 填写的 API 名称格式是否都正确无误。
|
|
68
|
+
|
|
69
|
+
2. 其次需要确认 API 是否运行在昇腾 NPU 上,如果是运行在其他设备上则不会存在 kernel 级数据。
|
msprobe/docs/FAQ.md
CHANGED
|
@@ -13,6 +13,29 @@
|
|
|
13
13
|
2. 如果存在namedtuple类型的数据作为nn.Module的输出,工具会将各字段数据dump下来,但是输出数据类型会被转成tuple,原因是什么?
|
|
14
14
|
- 这是由于pytorch框架自身,在注册module的backward hook时,会将namedtuple类型转成tuple类型。
|
|
15
15
|
|
|
16
|
+
3. 如果某个api在dump支持列表support_wrap_ops.yaml中,但没有dump该api的数据,原因是什么?
|
|
17
|
+
- 首先确认api调用是否在采集范围内,即需要在 **start** 和 **stop** 接口涵盖的范围内。
|
|
18
|
+
- 其次,由于工具只在被调用时才对api进行patch,从而使得数据可以被dump下来。因此当api是被直接import进行调用时,由于该api的地址已经确定,
|
|
19
|
+
工具无法再对其进行patch,故而该api数据无法被dump下来。如下示例,relu将无法被dump:
|
|
20
|
+
```python
|
|
21
|
+
import torch
|
|
22
|
+
from torch import relu # 此时relu地址已经确定,无法修改
|
|
23
|
+
|
|
24
|
+
from msprobe.pytorch import PrecisionDebugger
|
|
25
|
+
|
|
26
|
+
debugger = PrecisionDebugger(dump_path="./dump_data")
|
|
27
|
+
x = torch.randn(10)
|
|
28
|
+
debugger.start() # 此时会对torch下面的api进行patch,但已无法对import进来的api进行patch了
|
|
29
|
+
x = relu(x)
|
|
30
|
+
debugger.stop()
|
|
31
|
+
```
|
|
32
|
+
在上述场景中,若希望采集relu数据,只需要将`relu(x)`修改为`torch.relu(x)`即可。
|
|
33
|
+
|
|
34
|
+
4. 在使用L0 dump时,发现有些 module 的数据没有采集下来,原因是什么?
|
|
35
|
+
- 确认日志打印中是否存在`The {module_name} has registered deprecated register_backward_hook`信息,
|
|
36
|
+
该信息说明 module 挂载了被 PyTorch 框架废弃的 register_backward_hook,这与工具使用的 register_full_backward_hook 接口会产生冲突,故工具会跳过该 module 的反向数据采集。
|
|
37
|
+
- 如果您希望所有 module 数据都能采集下来,可以将模型中使用的 register_backward_hook 接口改为 PyTorch 框架推荐的 register_full_backward_pre_hook 或 register_full_backward_hook 接口。
|
|
38
|
+
|
|
16
39
|
# 2 精度预检(PyTorch)
|
|
17
40
|
|
|
18
41
|
1. 预检工具在 dump 和 run_ut 的过程中,是否需要同时开启或关闭 jit 编译(jit_compile)?
|
|
@@ -183,9 +206,10 @@ def npu_forward_fused_softmax(self, input_, mask):
|
|
|
183
206
|
|
|
184
207
|
答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 `Tensor: ` 下的 `- __getitem__`,工具会跳过采集该 API。如果是需要采集关键位置 API 也可以考虑根据报错堆栈信息注释引发报错的类型检查。
|
|
185
208
|
|
|
186
|
-
11.
|
|
209
|
+
11. 使用 msprobe 工具数据采集功能后,模型出现报错,报错信息为:`activation_func must be F.gelu` 或 `ValueError(Only support fusion of gelu and swiglu)`。
|
|
187
210
|
|
|
188
|
-
|
|
211
|
+
答:这一类报错常见于 Megatron/MindSpeed/ModelLink 等加速库或模型仓中,原因是工具本身会封装 torch 的 API(API 类型和地址会发生改变),而有些 API 在工具使能前类型和地址就已经确定,此时工具无法再对这类 API 进行封装;加速库中会对某些 API 进行类型检查,即把工具无法封装的原始 API 与工具封装之后的 API 进行比较,二者不一致,所以会报错。
|
|
212
|
+
规避方式有3种:① 将 PrecisionDebugger 的实例化放在文件的开始位置,即导包后的位置,确保所有 API 都被封装;② 注释 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中的 `- gelu` 或者 `- silu`,工具会跳过采集该 API;③ 可以考虑根据报错堆栈信息注释引发报错的类型检查。
|
|
189
213
|
|
|
190
214
|
12. 添加 msprobe 工具后触发与 AsStrided 算子相关、或者编译相关的报错,如:`Failed to compile Op [AsStrided]`。
|
|
191
215
|
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# MindSpore 场景的精度预检基线
|
|
2
|
+
|
|
3
|
+
## "multi_run_ut"模式精度预检耗时参考基线
|
|
4
|
+
|
|
5
|
+
该基线为MindSpore框架下,使用"multi_run_ut"模式精度预检耗时参考基线。本基线测试了38B语言大模型在不同卡数下耗时的变化。
|
|
6
|
+
|
|
7
|
+
### 38B语言大模型
|
|
8
|
+
|
|
9
|
+
| 卡数 | 总耗时 (分钟) | 备注 |
|
|
10
|
+
| ----- |----------|---------- |
|
|
11
|
+
| 1 卡 | 21.0 | 单卡基线 |
|
|
12
|
+
| 2 卡 | 11.5 | 双卡基线 |
|
|
13
|
+
| 4 卡 | 6.7 | 四卡基线 |
|
|
14
|
+
| 8 卡 | 3.5 | 八卡基线 |
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# MindSpore 场景的精度数据采集基线
|
|
2
|
+
|
|
3
|
+
## "tensor"模式采集数据量参考基线
|
|
4
|
+
|
|
5
|
+
该基线为MindSpore框架下,使用"tensor"模式采集数据量参考基线。本基线测试了38B语言大模型在不同采集模式下,不同global_batch_size下,单卡和8卡下,数据量的变化。
|
|
6
|
+
|
|
7
|
+
### 38B语言大模型
|
|
8
|
+
|
|
9
|
+
<table>
|
|
10
|
+
<tr><th>采集模式</th><th>global_batch_size</th><th>单卡</th><th>8卡</th></tr>
|
|
11
|
+
<tr><td rowspan="3">L0</td><td>1</td><td>262GB</td><td>2.1TB</td></tr>
|
|
12
|
+
<tr><td>2</td><td>480GB</td><td>3.8TB</td></tr>
|
|
13
|
+
<tr><td>3</td><td>928GB</td><td>7.4TB</td></tr>
|
|
14
|
+
<tr><td rowspan="3">L1</td><td>1</td><td>2.1TB</td><td>17.1TB</td></tr>
|
|
15
|
+
<tr><td>2</td><td>2.8TB</td><td>22.7TB</td></tr>
|
|
16
|
+
<tr><td>3</td><td>4.2TB</td><td>34.3TB</td></tr>
|
|
17
|
+
<tr><td rowspan="3">mix</td><td>1</td><td>2.4TB</td><td>19.2TB</td></tr>
|
|
18
|
+
<tr><td>2</td><td>3.3TB</td><td>26.6TB</td></tr>
|
|
19
|
+
<tr><td>3</td><td>5.1TB</td><td>41.4TB</td></tr>
|
|
20
|
+
|
|
21
|
+
</table>
|
|
22
|
+
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
# 模型分级可视化如何配置layer mapping映射文件
|
|
2
|
+
|
|
3
|
+
## 1.使用场景
|
|
4
|
+
同框架跨套件比对(例如PyTorch DeepSpeed vs Megatron),或者跨框架比对(例如PyTorch vs MindSpore)时,**由于代码实现的差异,一些模型层级和层级命名有所不同,导致无法进行匹配**,需要进行layer层名称映射,才能够比对。
|
|
5
|
+
|
|
6
|
+
## 2.模块命名说明
|
|
7
|
+
|
|
8
|
+
由于有些节点的名称比较长,例如Module.module.module.language_model.embedding.Embedding.forward.0,在图节点上由于字符串过长无法完整显示,forward或backward信息被省略,**因此节点中显示的名称字符串去掉了Module前缀,并将forward或backward信息提取到名称字符串的第二位展示**。
|
|
9
|
+
|
|
10
|
+

|
|
11
|
+
|
|
12
|
+

|
|
13
|
+
|
|
14
|
+
### 2.1 命名格式
|
|
15
|
+
|
|
16
|
+
**{Module}.{module_name}.{class_name}.{forward/backward}.{调用次数}**
|
|
17
|
+
|
|
18
|
+
**layer mapping主要是针对module_name的映射**
|
|
19
|
+
|
|
20
|
+
#### 2.1.1 命名示例
|
|
21
|
+
|
|
22
|
+
- **Module.module.Float16Module.forward.0** -----> Module{**Module**}.module{**module_name**}.Float16Module{**class_name**}.forward.0{**调用次数**}
|
|
23
|
+
- **Module.module.module.GPTModel.forward.0** -----> Module{**Module**}.module.module{**module_name**}.GPTModel{**class_name**}.forward.0{**调用次数**}
|
|
24
|
+
- **Module.module.module.language_model.TransformerLanguageModel.forward.0** -----> Module{**Module**}.module.module.language_model{**module_name**}.TransformerLanguageModel{**class_name**}.forward.0{**调用次数**}
|
|
25
|
+
- **Module.module.module.language_model.embedding.Embedding.forward.0** -----> Module{**Module**}.module.module.language_model.embedding{**module_name**}.Embedding{**class_name**}.forward.0{**调用次数**}
|
|
26
|
+
|
|
27
|
+
可以看到,module_name随着模型层级的深入在变长,**embedding层module_name拼接了它的上层language_model、上上层module和顶层module**。
|
|
28
|
+
|
|
29
|
+
## 3.示例
|
|
30
|
+
|
|
31
|
+
如图所示,左边为NPU模型,右边为GPU模型,由于代码实现上的差异,模型层级和层级命名有所不同,导致节点无法匹配,**图上节点显示为灰色,表示节点未匹配**。
|
|
32
|
+
|
|
33
|
+

|
|
34
|
+
|
|
35
|
+
### 3.1 看图分析
|
|
36
|
+
|
|
37
|
+
同一模型使用了不同套件或者框架,虽然两个模型的层级关系和层级命名可能有所不同,但也可以从图上的**节点名称**看出一些匹配关系,例如同是embedding层,代码里也是会命名为xxx_embedding,不会命名为xxx_norm,体现在节点名称上也是带有embedding的信息,并且层级关系也是大致相同的。
|
|
38
|
+
|
|
39
|
+

|
|
40
|
+
|
|
41
|
+
分析可知,节点匹配关系如下:
|
|
42
|
+
|
|
43
|
+
**注意,仅需关注module_name的差异**
|
|
44
|
+
|
|
45
|
+
| NPU节点名称 | GPU节点名称 | module_name差异 |
|
|
46
|
+
|-------------------|----------------------------------------------------------------|---------------------------|
|
|
47
|
+
| Module.module.Float16Module.forward.0 | Module.model.FloatModule.forward.0 | NPU为module,GPU为model |
|
|
48
|
+
| Module.module.module.GPTModel.forward.0 | Module.model.module.GPT2Model.forward.0 | NPU为module,GPU为module,无差异 |
|
|
49
|
+
| Module.module.module.language_model.TransformerLanguageModel.forward.0 | 无 | NPU多了一层 |
|
|
50
|
+
| Module.module.module.language_model.embedding.Embedding.forward.0 | Module.module.module.embedding.LanguageModelEmbedding.forward.0 | NPU为language_model.embedding,GPU为embedding |
|
|
51
|
+
| Module.module.module.language_model.rotary_pos_emb.RotaryEmbedding.forward.0 | Module.module.module.rotary_pos_emb.RotaryEmbedding.forward.0 | NPU为language_model.rotary_pos_emb,GPU为rotary_pos_emb |
|
|
52
|
+
| Module.module.module.language_model.encoder.ParallelTransformer.forward.0 | Module.module.module.decoder.TransformerBlock.forward.0 | NPU为language_model.encoder,GPU为decoder |
|
|
53
|
+
| Module.module.module.language_model.encoder.layers.0.ParallelTransformerLayer.forward.0 | Module.module.module.decoder.layers.0.TransformerLayer.forward.0 | 父层级有差异,本层级NPU和GPU都叫layers,无差异 |
|
|
54
|
+
|
|
55
|
+
### 3.2 构建layer_mapping配置文件
|
|
56
|
+
准备一个命名为mapping.yaml的文件,建立**module_name**的映射关系。
|
|
57
|
+
|
|
58
|
+
#### 3.2.1 顶层模块映射
|
|
59
|
+
NPU和GPU侧的模块Module.module.Float16Module.forward.0和Module.model.FloatModule.forward.0处于图的顶层,需要进行如下配置:
|
|
60
|
+
|
|
61
|
+

|
|
62
|
+
|
|
63
|
+
```yaml
|
|
64
|
+
TopLayer:
|
|
65
|
+
module: model
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
#### 3.2.2 其他模块映射
|
|
69
|
+
配置module下的子模块,虽然两边的class_name不同(NPU侧为GPTModel,GPU侧为GPT2Model),**但是仅需取NPU侧也就是左边图的class_name进行配置,无需关心右边图的class_name叫什么**。
|
|
70
|
+
|
|
71
|
+
**这里涉及到跨层级的配置,NPU多了一层language_model层**,将language_model作为embedding层、rotary_pos_emb层和encoder层的前缀,进行如下配置:
|
|
72
|
+
|
|
73
|
+

|
|
74
|
+
|
|
75
|
+
```yaml
|
|
76
|
+
GPTModel:
|
|
77
|
+
language_model.embedding: embedding
|
|
78
|
+
language_model.rotary_pos_emb: rotary_pos_emb
|
|
79
|
+
language_model.encoder: decoder
|
|
80
|
+
```
|
|
81
|
+
然后看Module.module.module.language_model.encoder.ParallelTransformer.forward.0层下的子模块:
|
|
82
|
+
|
|
83
|
+
此层下的若干个层,NPU和GPU的层名都叫layers,**当前层名称相同,则不用进行配置**。
|
|
84
|
+
|
|
85
|
+
### 3.3 查看效果
|
|
86
|
+
|
|
87
|
+
执行命令,指定-lm:
|
|
88
|
+
```
|
|
89
|
+
msprobe -f pytorch graph -i ./compare.json -o ./output -lm ./mapping.yaml
|
|
90
|
+
```
|
|
91
|
+
或
|
|
92
|
+
```
|
|
93
|
+
msprobe -f mindspore graph -i ./compare.json -o ./output -lm ./mapping.yaml
|
|
94
|
+
```
|
|
95
|
+
可以看到,除了language_model层(NPU多的一层,GPU没有层与其匹配),其余在mapping.yaml文件配置的层均匹配上了。
|
|
96
|
+
|
|
97
|
+

|
|
98
|
+
|
|
99
|
+
### 3.4 继续配置
|
|
100
|
+
|
|
101
|
+
展开节点过程中,如果发现还有未匹配节点,则继续配置mapping.yaml
|
|
102
|
+
|
|
103
|
+

|
|
104
|
+
|
|
105
|
+
按前一章过程进行分析配置,分析可知,节点匹配关系如下:
|
|
106
|
+
|
|
107
|
+
| NPU节点名称 | GPU节点名称 | 差异 |
|
|
108
|
+
|-------------------|------------------------------------------------------------------|---------------------------------------------|
|
|
109
|
+
| Module.module.module.language_model.encoder.layers.0.mlp.dense_h_to_4h.ColumnParallelLinear.forward.0 | Module.module.module.decoder.layers.0.mlp.linear_fc1.TELayerNormColumnParallelLinear.forward.0 | NPU为dense_h_to_4h,GPU为linear_fc1 |
|
|
110
|
+
| Module.module.module.language_model.encoder.layers.0.mlp.dense_4h_to_h.RowParallelLinear.forward.0 | Module.module.module.decoder.layers.0.mlp.linear_fc2.TERowParallelLinear.forward.0 | NPU为dense_4h_to_h,GPU为linear_fc2 |
|
|
111
|
+
|
|
112
|
+

|
|
113
|
+
|
|
114
|
+
追加mapping.yaml配置:
|
|
115
|
+
|
|
116
|
+
```yaml
|
|
117
|
+
TopLayer:
|
|
118
|
+
module: model
|
|
119
|
+
|
|
120
|
+
GPTModel:
|
|
121
|
+
language_model.embedding: embedding
|
|
122
|
+
language_model.rotary_pos_emb: rotary_pos_emb
|
|
123
|
+
language_model.encoder: decoder
|
|
124
|
+
|
|
125
|
+
ParallelMLP:
|
|
126
|
+
dense_h_to_4h: linear_fc1
|
|
127
|
+
dense_4h_to_h: linear_fc2
|
|
128
|
+
```
|
|
129
|
+
|
|
130
|
+
执行命令,查看效果,可以看到节点已成功匹配上。
|
|
131
|
+
|
|
132
|
+

|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
msprobe/mindspore/__init__.py
CHANGED
|
@@ -13,5 +13,16 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
from msprobe.lib import _msprobe_c
|
|
20
|
+
os.environ["MS_HOOK_ENABLE"] = "on"
|
|
21
|
+
os.environ["HOOK_TOOL_PATH"] = _msprobe_c.__file__
|
|
22
|
+
except ImportError:
|
|
23
|
+
from .common.log import logger
|
|
24
|
+
logger.info("Module _msprobe_c has not been installed. L2-Dump may not work normally.")
|
|
25
|
+
|
|
16
26
|
from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
|
|
17
27
|
from msprobe.mindspore.common.utils import seed_all
|
|
28
|
+
from msprobe.mindspore.monitor.module_hook import TrainerMon
|
|
@@ -26,10 +26,12 @@ from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager
|
|
|
26
26
|
from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context,
|
|
27
27
|
trim_output_compute_element_list)
|
|
28
28
|
from msprobe.mindspore.common.log import logger
|
|
29
|
+
from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer
|
|
29
30
|
|
|
30
31
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
31
32
|
yaml_path = os.path.join(cur_path, MsCompareConst.SUPPORTED_API_LIST_FILE)
|
|
32
33
|
|
|
34
|
+
|
|
33
35
|
class BasicInfoAndStatus:
|
|
34
36
|
def __init__(self, api_name, bench_dtype, tested_dtype, shape, status, err_msg) -> None:
|
|
35
37
|
self.api_name = api_name
|
|
@@ -49,6 +51,13 @@ class ResultCsvEntry:
|
|
|
49
51
|
self.overall_err_msg = None
|
|
50
52
|
|
|
51
53
|
|
|
54
|
+
class ProcessResultPacket:
|
|
55
|
+
def __init__(self, process_status, result, err_msg) -> None:
|
|
56
|
+
self.process_status = process_status
|
|
57
|
+
self.result = result
|
|
58
|
+
self.err_msg = err_msg
|
|
59
|
+
|
|
60
|
+
|
|
52
61
|
class ApiAccuracyChecker:
|
|
53
62
|
def __init__(self, args):
|
|
54
63
|
self.api_infos = dict()
|
|
@@ -56,7 +65,7 @@ class ApiAccuracyChecker:
|
|
|
56
65
|
|
|
57
66
|
@staticmethod
|
|
58
67
|
def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
|
|
59
|
-
|
|
68
|
+
"""
|
|
60
69
|
Args:
|
|
61
70
|
api_info: ApiInfo
|
|
62
71
|
api_name_str: str
|
|
@@ -70,13 +79,15 @@ class ApiAccuracyChecker:
|
|
|
70
79
|
get mindspore api output, run torch api and get output.
|
|
71
80
|
compare output.
|
|
72
81
|
record compare result.
|
|
73
|
-
|
|
82
|
+
"""
|
|
74
83
|
# get output
|
|
75
84
|
if global_context.get_is_constructed():
|
|
76
85
|
# constructed situation, need use constructed input to run mindspore api getting tested_output
|
|
77
|
-
tested_outputs = api_runner(api_input_aggregation, api_name_str,
|
|
86
|
+
tested_outputs = api_runner(api_input_aggregation, api_name_str,
|
|
87
|
+
forward_or_backward, global_context.get_framework())
|
|
78
88
|
else:
|
|
79
89
|
tested_outputs = api_info.get_compute_element_list(forward_or_backward, Const.OUTPUT)
|
|
90
|
+
|
|
80
91
|
bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)
|
|
81
92
|
tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward)
|
|
82
93
|
bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward)
|
|
@@ -104,8 +115,8 @@ class ApiAccuracyChecker:
|
|
|
104
115
|
err_msg = ""
|
|
105
116
|
else:
|
|
106
117
|
status = CompareConst.ERROR
|
|
107
|
-
err_msg = compare_result_dict.get(CompareConst.COSINE).err_msg +
|
|
108
|
-
|
|
118
|
+
err_msg = (compare_result_dict.get(CompareConst.COSINE).err_msg +
|
|
119
|
+
compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg)
|
|
109
120
|
basic_info_status = \
|
|
110
121
|
BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
|
|
111
122
|
output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
|
|
@@ -113,13 +124,13 @@ class ApiAccuracyChecker:
|
|
|
113
124
|
|
|
114
125
|
@staticmethod
|
|
115
126
|
def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD):
|
|
116
|
-
|
|
127
|
+
"""
|
|
117
128
|
Args:
|
|
118
129
|
api_info: ApiInfo
|
|
119
130
|
forward_or_backward: str
|
|
120
131
|
Returns:
|
|
121
132
|
ApiInputAggregation
|
|
122
|
-
|
|
133
|
+
"""
|
|
123
134
|
forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
|
|
124
135
|
kwargs = api_info.get_kwargs()
|
|
125
136
|
if forward_or_backward == Const.FORWARD:
|
|
@@ -145,13 +156,19 @@ class ApiAccuracyChecker:
|
|
|
145
156
|
real_api_str = Const.SEP.join(api_name_str_list[1:-2])
|
|
146
157
|
api_list = load_yaml(yaml_path)
|
|
147
158
|
supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY)
|
|
148
|
-
if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL)
|
|
159
|
+
if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL) \
|
|
160
|
+
and global_context.get_framework() == Const.MS_FRAMEWORK:
|
|
149
161
|
return True
|
|
150
|
-
if api_type_str
|
|
162
|
+
if api_type_str in MsCompareConst.MT_VALID_API_TYPES \
|
|
163
|
+
and global_context.get_framework() == Const.MT_FRAMEWORK:
|
|
164
|
+
return True
|
|
165
|
+
if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list \
|
|
166
|
+
and global_context.get_framework() == Const.MS_FRAMEWORK:
|
|
151
167
|
return True
|
|
152
168
|
return False
|
|
153
169
|
|
|
154
170
|
def parse(self, api_info_path):
|
|
171
|
+
|
|
155
172
|
api_info_dict = load_json(api_info_path)
|
|
156
173
|
|
|
157
174
|
# init global context
|
|
@@ -159,13 +176,25 @@ class ApiAccuracyChecker:
|
|
|
159
176
|
"task field in api_info.json", accepted_type=str,
|
|
160
177
|
accepted_value=(MsCompareConst.STATISTICS_TASK,
|
|
161
178
|
MsCompareConst.TENSOR_TASK))
|
|
179
|
+
try:
|
|
180
|
+
framework = check_and_get_from_json_dict(api_info_dict, MsCompareConst.FRAMEWORK,
|
|
181
|
+
"framework field in api_info.json", accepted_type=str,
|
|
182
|
+
accepted_value=(Const.MS_FRAMEWORK,
|
|
183
|
+
Const.MT_FRAMEWORK))
|
|
184
|
+
except Exception as e:
|
|
185
|
+
framework = Const.MS_FRAMEWORK
|
|
186
|
+
logger.warning(f"JSON parsing error in framework field: {e}")
|
|
187
|
+
|
|
188
|
+
if framework == Const.MT_FRAMEWORK and not torch_mindtorch_importer.is_valid_pt_mt_env:
|
|
189
|
+
raise Exception(f"Please check if you have a valid PyTorch and MindTorch environment")
|
|
190
|
+
|
|
162
191
|
is_constructed = task == MsCompareConst.STATISTICS_TASK
|
|
163
192
|
if not is_constructed:
|
|
164
193
|
dump_data_dir = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DUMP_DATA_DIR_FIELD,
|
|
165
194
|
"dump_data_dir field in api_info.json", accepted_type=str)
|
|
166
195
|
else:
|
|
167
196
|
dump_data_dir = ""
|
|
168
|
-
global_context.init(is_constructed, dump_data_dir)
|
|
197
|
+
global_context.init(is_constructed, dump_data_dir, framework)
|
|
169
198
|
|
|
170
199
|
api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
|
|
171
200
|
"data field in api_info.json", accepted_type=dict)
|
|
@@ -188,45 +217,65 @@ class ApiAccuracyChecker:
|
|
|
188
217
|
"""处理前向检查"""
|
|
189
218
|
if not api_info.check_forward_info():
|
|
190
219
|
logger.debug(f"api: {api_name_str} is lack of forward information, skip forward check.")
|
|
191
|
-
|
|
220
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.API_NOT_FOUND,
|
|
221
|
+
result=None,
|
|
222
|
+
err_msg=f"forward info of {api_name_str} is not found")
|
|
223
|
+
return process_result_packet
|
|
192
224
|
|
|
193
225
|
try:
|
|
194
226
|
forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
|
|
195
227
|
except Exception as e:
|
|
196
228
|
logger.warning(f"Exception occurs when getting inputs for {api_name_str} forward api. "
|
|
197
229
|
f"Skipping forward check. Detailed exception information: {e}.")
|
|
198
|
-
|
|
230
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
|
|
231
|
+
result=None, err_msg=f"{e}")
|
|
232
|
+
return process_result_packet
|
|
199
233
|
|
|
200
|
-
forward_output_list = None
|
|
201
234
|
try:
|
|
202
|
-
forward_output_list = self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation,
|
|
235
|
+
forward_output_list = self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation,
|
|
236
|
+
Const.FORWARD)
|
|
203
237
|
except Exception as e:
|
|
204
238
|
logger.warning(f"Exception occurs when running and comparing {api_name_str} forward api. "
|
|
205
239
|
f"Detailed exception information: {e}.")
|
|
206
|
-
|
|
240
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
|
|
241
|
+
result=None, err_msg=f"{e}")
|
|
242
|
+
return process_result_packet
|
|
243
|
+
|
|
244
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS,
|
|
245
|
+
result=forward_output_list, err_msg="")
|
|
246
|
+
return process_result_packet
|
|
207
247
|
|
|
208
248
|
def process_backward(self, api_name_str, api_info):
|
|
209
249
|
"""处理反向检查"""
|
|
210
250
|
if not api_info.check_backward_info():
|
|
211
251
|
logger.debug(f"api: {api_name_str} is lack of backward information, skipping backward check.")
|
|
212
|
-
|
|
252
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.API_NOT_FOUND,
|
|
253
|
+
result=None,
|
|
254
|
+
err_msg=f"backward info of {api_name_str} is not found")
|
|
255
|
+
return process_result_packet
|
|
213
256
|
|
|
214
257
|
try:
|
|
215
258
|
backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
|
|
216
259
|
except Exception as e:
|
|
217
260
|
logger.warning(f"Exception occurs when getting inputs for {api_name_str} backward api. "
|
|
218
261
|
f"Skipping backward check. Detailed exception information: {e}.")
|
|
219
|
-
|
|
262
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
|
|
263
|
+
result=None, err_msg=f"{e}")
|
|
264
|
+
return process_result_packet
|
|
220
265
|
|
|
221
|
-
backward_output_list = None
|
|
222
266
|
try:
|
|
223
|
-
backward_output_list = self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation,
|
|
267
|
+
backward_output_list = self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation,
|
|
268
|
+
Const.BACKWARD)
|
|
224
269
|
except Exception as e:
|
|
225
270
|
logger.warning(f"Exception occurs when running and comparing {api_name_str} backward api. "
|
|
226
271
|
f"Detailed exception information: {e}.")
|
|
227
|
-
|
|
228
|
-
|
|
272
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
|
|
273
|
+
result=None, err_msg=f"{e}")
|
|
274
|
+
return process_result_packet
|
|
229
275
|
|
|
276
|
+
process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS,
|
|
277
|
+
result=backward_output_list, err_msg="")
|
|
278
|
+
return process_result_packet
|
|
230
279
|
|
|
231
280
|
def run_and_compare(self):
|
|
232
281
|
for api_name_str, api_info in tqdm(self.api_infos.items()):
|
|
@@ -234,14 +283,17 @@ class ApiAccuracyChecker:
|
|
|
234
283
|
continue
|
|
235
284
|
|
|
236
285
|
# 处理前向
|
|
237
|
-
|
|
238
|
-
if
|
|
239
|
-
self.data_manager.record(
|
|
286
|
+
process_result_packet = self.process_forward(api_name_str, api_info)
|
|
287
|
+
if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS:
|
|
288
|
+
self.data_manager.record(process_result_packet.result)
|
|
289
|
+
elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
|
|
290
|
+
self.data_manager.record_exception_skip(api_name_str, Const.FORWARD, process_result_packet.err_msg)
|
|
240
291
|
|
|
241
292
|
# 处理反向
|
|
242
|
-
|
|
243
|
-
if
|
|
244
|
-
self.data_manager.record(
|
|
293
|
+
process_result_packet = self.process_backward(api_name_str, api_info)
|
|
294
|
+
if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS:
|
|
295
|
+
self.data_manager.record(process_result_packet.result)
|
|
296
|
+
elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
|
|
297
|
+
self.data_manager.record_exception_skip(api_name_str, Const.BACKWARD, process_result_packet.err_msg)
|
|
245
298
|
|
|
246
299
|
self.data_manager.save_results(api_name_str)
|
|
247
|
-
|