mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
# 强化学习数据采集
|
|
2
|
+
|
|
3
|
+
## 介绍
|
|
4
|
+
在强化学习训练过程中,往往存在多个模型(actor、reward、reference)和两个阶段(推理、训练),问题定界困难。
|
|
5
|
+
|
|
6
|
+
本工具提供一种灵活存储强化学习训练过程中关键阶段性数据的能力,并支持对比两次采集的关键数据,以支持问题快速定界。
|
|
7
|
+
|
|
8
|
+
常用关键数据示例:prompt、response、reward、log_prob、ref_log_prob、old_log_prob、kl_loss。
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
## 安装教程
|
|
12
|
+
|
|
13
|
+
参见 msprobe [安装教程](./01.installation.md)。
|
|
14
|
+
|
|
15
|
+
## 使用说明
|
|
16
|
+
|
|
17
|
+
### 数据采集
|
|
18
|
+
|
|
19
|
+
用户识别脚本中需要采集数据的地方,然后通过插入代码的方式采集关键数据。
|
|
20
|
+
|
|
21
|
+
当确定需要采集数据的地方,例如response,可以按如下方式对数据进行存储:
|
|
22
|
+
```
|
|
23
|
+
from msprobe.core import SingleSave
|
|
24
|
+
SingleSave("./dump_path", fmk="pytorch")
|
|
25
|
+
SingleSave.save({"response": response})
|
|
26
|
+
```
|
|
27
|
+
其中"./dump_path"为输出路径,没有默认值,需要自己配置;fmk可选"pytorch"或者"mindspore",默认"pytorch"。
|
|
28
|
+
|
|
29
|
+
其中"response"是可以任意指定的key,response是训练过程中的真实tensor变量。
|
|
30
|
+
|
|
31
|
+
也支持一次性存储多个数据:
|
|
32
|
+
```
|
|
33
|
+
from msprobe.core import SingleSave
|
|
34
|
+
SingleSave("./dump_path", fmk="pytorch")
|
|
35
|
+
SingleSave.save({
|
|
36
|
+
"prompt": prompt,
|
|
37
|
+
"response": response
|
|
38
|
+
})
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
### 配置保存
|
|
42
|
+
|
|
43
|
+
当确定需要保存配置json的地方,可以按如下方式对配置进行存储:
|
|
44
|
+
```
|
|
45
|
+
from msprobe.core import SingleSave
|
|
46
|
+
SingleSave("./dump_path")
|
|
47
|
+
SingleSave.save_config(configurations_json)
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
采集到的数据目录结构如下:
|
|
52
|
+
```txt
|
|
53
|
+
dump_path/
|
|
54
|
+
├── data/ # 固定为data
|
|
55
|
+
│ └── response/ # 关键数据名称,来自SingleSave.save的时候的key
|
|
56
|
+
│ └── step0/ # step数
|
|
57
|
+
│ └── rank0/ # rank数
|
|
58
|
+
│ └── micro_step0/ # micro_step数
|
|
59
|
+
│ ├── response0.npy # 存储的关键数据的真实npy文件
|
|
60
|
+
│ └── response0.json # 存储的关键数据的统计量文件,包括tensor的最大值、最小值、均值、norm和shape
|
|
61
|
+
└── configurations.json # 配置json文件
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
### 结果比对
|
|
65
|
+
|
|
66
|
+
两次采集数据之后得到dump_path1和dump_path2,可以创建一个比对脚本,例如compare.py,将两次训练的dump_path传入:
|
|
67
|
+
```
|
|
68
|
+
from msprobe.core import SingleComparator
|
|
69
|
+
SingleComparator.compare(
|
|
70
|
+
"dump_path1",
|
|
71
|
+
"dump_path2",
|
|
72
|
+
"output_path")
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
会在output_path下对每种关键数据都生成excel结果表格,比如response.xlsx,形式为关键数据的名字加上.xlsx后缀。
|
|
76
|
+
|
|
77
|
+
表格会体现每一个对应tensor的差异,解释:
|
|
78
|
+
|
|
79
|
+
| 表头 | 解释 |
|
|
80
|
+
|-------|---------|
|
|
81
|
+
| step | 训练步数 |
|
|
82
|
+
| rank | 卡号 |
|
|
83
|
+
| micro_step | 梯度累计步数 |
|
|
84
|
+
| id | 数据编号,对应文件名中的序号(如response0.npy中的0) |
|
|
85
|
+
| shape1 | dump_path1中的数据形状 |
|
|
86
|
+
| shape2 | dump_path2中的数据形状 |
|
|
87
|
+
| 相同元素百分比 | 元素相同的个数占总元素个数的百分比 |
|
|
88
|
+
| 首个不匹配元素索引 | 首个匹配不上的元素是第几个 |
|
|
89
|
+
| 最大绝对误差 | 最大绝对误差 |
|
|
90
|
+
| 最大相对误差 | 最大相对误差 |
|
|
91
|
+
| 误差在千分之一内元素占比 | 误差在千分之一内元素个数占总元素个数的百分比 |
|
|
92
|
+
| 误差在百分之一内元素占比 | 误差在百分之一内元素个数占总元素个数的百分比 |
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# 整网首个溢出节点分析
|
|
2
|
+
|
|
3
|
+
## 介绍
|
|
4
|
+
在分析inf、nan的场景下,会采集多个rank下的多个step的dump数据,前面出现的异常会传播到同rank后续的节点,并通过通信算子传播到其他rank的后续节点中,因此如何分析首个nan出现的节点位置尤为重要。
|
|
5
|
+
|
|
6
|
+
通过nan_analyze工具可以对pytorch的dump数据进行分析。在多卡场景下,检测到每张卡中产生inf/nan的节点。若是经过通信导致的inf/nan,可以分析并找出首个产生inf/nan的rank和节点。
|
|
7
|
+
|
|
8
|
+
## 安装教程
|
|
9
|
+
|
|
10
|
+
参见 msprobe [安装教程](./01.installation.md)。
|
|
11
|
+
|
|
12
|
+
## 使用说明
|
|
13
|
+
|
|
14
|
+
当前仅支持分析pytorch的dump数据。
|
|
15
|
+
|
|
16
|
+
### 采集数据
|
|
17
|
+
|
|
18
|
+
参见 [PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)。
|
|
19
|
+
|
|
20
|
+
### 执行命令
|
|
21
|
+
|
|
22
|
+
```commandline
|
|
23
|
+
msprobe -f pytorch nan_analyze -i dump_step_path -o output_dir_path
|
|
24
|
+
```
|
|
25
|
+
|
|
26
|
+
| 参数 | 说明 |
|
|
27
|
+
|--------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
|
28
|
+
| -i 或 --input_path | dump数据的目录。需指定到step层级,如`-i /xxx/dump/step0/` |
|
|
29
|
+
| -o 或 --output_path | 输出文件的目录,可选,不填时默认在当前目录下创建 `./output/` 目录。 |
|
|
30
|
+
|
|
31
|
+
### 输出文件介绍
|
|
32
|
+
|
|
33
|
+
当日志打印
|
|
34
|
+
```
|
|
35
|
+
Cannot find any anomaly node, no need to generate analyze file.
|
|
36
|
+
```
|
|
37
|
+
时,分析认为不存在异常节点,不生成分析文件。
|
|
38
|
+
|
|
39
|
+
存在异常节点时,生成`anomaly_analyze_{timestamp}.json`文件,结构为:
|
|
40
|
+
```json
|
|
41
|
+
{
|
|
42
|
+
"rank_0": [ // 卡号
|
|
43
|
+
{
|
|
44
|
+
"op_name": "Tensor.op_name.0.forward", // 节点名
|
|
45
|
+
"data_info": {
|
|
46
|
+
"input_args": [], // input_args数据
|
|
47
|
+
"input_kwargs": {}, // input_kwargs数据
|
|
48
|
+
"output": [] // output数据
|
|
49
|
+
},
|
|
50
|
+
"construct_info": [], // 节点层级数据
|
|
51
|
+
"stack_info": {} // 堆栈数据
|
|
52
|
+
}
|
|
53
|
+
]
|
|
54
|
+
}
|
|
55
|
+
```
|
|
56
|
+
|
|
57
|
+
## 异常判定
|
|
58
|
+
|
|
59
|
+
### 异常计算节点判定
|
|
60
|
+
当某个计算节点的输入值正常(即Max或Min中不存在inf或nan),而输出值存在异常时,认为从此节点开始产生了溢出,并可能向后传递。
|
|
61
|
+
|
|
62
|
+
### 异常通信节点判定
|
|
63
|
+
通信节点按照功能分为有向节点,如`send`, `recv`, `scatter`, `gather`, `broadcast`, `reduce`等,以及无向节点,如`all_gather`, `all_reduce`, `reduce_scatter`, `all_to_all`等。
|
|
64
|
+
|
|
65
|
+
对于有向节点,当src节点的input存在异常时,通常认为传入的数据中本身就存在异常,因此考虑异常节点发生在src节点所在rank的上一个或多个计算节点中;当src节点的input正常而output存在异常值,或dst节点的output存在异常值时,考虑是通信节点本身的操作产生了异常数据。
|
|
66
|
+
|
|
67
|
+
对于无向节点,当节点input存在异常时,认为传入的数据中本身就存在异常,因此考虑异常节点发生在该节点所在rank的上一个或多个计算节点中;当input正常而output异常时,考虑是通信节点本身的操作产生了异常数据。
|
|
68
|
+
|
|
69
|
+
### 顺序判定
|
|
70
|
+
对于相连接的有向通信算子,认为src节点的异常发生早于dst节点;对于无向通信算子,认为异常是同时发生的。
|
|
71
|
+
|
|
72
|
+
对于计算节点按照dump的顺序排序。
|
msprobe/docs/FAQ.md
CHANGED
|
@@ -58,11 +58,7 @@
|
|
|
58
58
|
|
|
59
59
|
答:对于 fp16 的数据,CPU 会上升一个精度 fp32 去计算,这是和算子那边对齐的精度结论,CPU 用更高精度去计算会更接近真实值。
|
|
60
60
|
|
|
61
|
-
6.
|
|
62
|
-
|
|
63
|
-
答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 Tensor: 下的 `- __getitem__`,工具会跳过采集该 API。如果是需要 dump 关键位置 API 也可以考虑根据报错堆栈信息注释引发报错的类型检查。
|
|
64
|
-
|
|
65
|
-
7. Tensor 魔法函数具体对应什么操作?
|
|
61
|
+
6. Tensor 魔法函数具体对应什么操作?
|
|
66
62
|
|
|
67
63
|
答:
|
|
68
64
|
|
|
@@ -202,15 +198,11 @@ def npu_forward_fused_softmax(self, input_, mask):
|
|
|
202
198
|
|
|
203
199
|
答:正常现象,dataloader 通过 raise 结束程序,堆栈信息可忽略。
|
|
204
200
|
|
|
205
|
-
10.
|
|
206
|
-
|
|
207
|
-
答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 `Tensor: ` 下的 `- __getitem__`,工具会跳过采集该 API。如果是需要采集关键位置 API 也可以考虑根据报错堆栈信息注释引发报错的类型检查。
|
|
208
|
-
|
|
209
|
-
11. 使用 msprobe 工具数据采集功能后,模型出现报错,报错信息为:`activation_func must be F.gelu` 或 `ValueError(Only support fusion of gelu and swiglu)`。
|
|
201
|
+
10. 使用 msprobe 工具数据采集功能后,模型出现报错,报错信息为:`activation_func must be F.gelu` 或 `ValueError(Only support fusion of gelu and swiglu)`。
|
|
210
202
|
|
|
211
203
|
答:这一类报错常见于 Megatron/MindSpeed/ModelLink 等加速库或模型仓中,原因是工具本身会封装 torch 的 API(API类型和地址会发生改变),而有些 API 在工具使能前类型和地址就已经确定,此时工具无法对这类 API 再进行封装,而加速库中会对某些 API 进行类型检查,即会把工具无法封装的原始的 API和工具封装之后的 API 进行判断,所以会报错。
|
|
212
204
|
规避方式有3种:①将PrecisionDebugger的实例化放在文件的开始位置,即导包后的位置,确保所有API都被封装;②注释 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中的 `-gelu` 或者 `-silu`,工具会跳过采集该 API。③ 可以考虑根据报错堆栈信息注释引发报错的类型检查。
|
|
213
205
|
|
|
214
|
-
|
|
206
|
+
11. 添加 msprobe 工具后触发与 AsStrided 算子相关、或者编译相关的报错,如:`Failed to compile Op [AsStrided]`。
|
|
215
207
|
|
|
216
208
|
答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 `Tensor: `下的 `-t` 和 `- transpose`。
|
|
@@ -1,6 +1,17 @@
|
|
|
1
1
|
# MindSpore 场景的精度数据采集基线
|
|
2
2
|
|
|
3
|
-
## "
|
|
3
|
+
## "statistics"模式(未开启md5)采集**时间**膨胀参考基线
|
|
4
|
+
|
|
5
|
+
该基线为MindSpore框架下,使用"statistics"模式采集数据的性能膨胀参考基线。本基线测试了38B语言大模型在8卡、不同采集模式下的性能膨胀。
|
|
6
|
+
|
|
7
|
+
| 采集模式 | 无工具 (耗时) | 加工具但未使能 Dump (耗时) | 加工具并使能 Dump (耗时) |
|
|
8
|
+
|:--------:|:-------------:|:--------------------:|:----------------:|
|
|
9
|
+
| L0 | ≈340 ms | ≈340 ms (无膨胀) | ≈1.2 s (膨胀3.5倍) |
|
|
10
|
+
| L1 | ≈340 ms | ≈0.7–1.2 s (膨胀2~4倍) | ≈3.8 s (膨胀11倍) |
|
|
11
|
+
| mix | ≈340 ms | ≈0.7–1.2 s (膨胀2~4倍) | ≈5.5 s (膨胀16倍) |
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
## "tensor"模式采集**数据量**参考基线
|
|
4
15
|
|
|
5
16
|
该基线为MindSpore框架下,使用"tensor"模式采集数据量参考基线。本基线测试了38B语言大模型在不同采集模式下,不同global_batch_size下,单卡和8卡下,数据量的变化。
|
|
6
17
|
|
|
@@ -51,6 +51,7 @@ debugger = PrecisionDebugger(config_path=config_path)
|
|
|
51
51
|
|
|
52
52
|
# 设置 MindSpore 设备上下文
|
|
53
53
|
context.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend", device_id=0)
|
|
54
|
+
print("Context set successfully. Please wait for the training task.")
|
|
54
55
|
|
|
55
56
|
# 定义卷积层
|
|
56
57
|
def conv_layer(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="valid", has_bias=True):
|
|
@@ -199,7 +200,7 @@ python alexnet_model.py
|
|
|
199
200
|
|
|
200
201
|
## 5. 数据分析
|
|
201
202
|
|
|
202
|
-
在 `dump_path` 参数指定的路径下(本例中为 `./output`),会出现如下目录结构,后续精度数据分析操作可使用 msprobe 工具的精度预检和精度比对等功能,详细流程请参见[《msprobe使用手册》](../../README.md#2-精度预检)
|
|
203
|
+
在 `dump_path` 参数指定的路径下(本例中为 `./output`),会出现如下目录结构,后续精度数据分析操作可使用 msprobe 工具的精度预检和精度比对等功能,详细流程请参见[《msprobe使用手册》](../../README.md#2-精度预检)。
|
|
203
204
|
|
|
204
205
|
```bash
|
|
205
206
|
output/
|
|
@@ -208,4 +209,5 @@ output/
|
|
|
208
209
|
├── construct.json # level为L0时,保存Cell的层级关系信息。当前场景为空
|
|
209
210
|
├── dump.json # 保存API前反向输入输出数据的统计量信息
|
|
210
211
|
└── stack.json # 保存API的调用栈
|
|
212
|
+
......
|
|
211
213
|
```
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
msprobe/mindspore/__init__.py
CHANGED
|
@@ -17,12 +17,12 @@ import os
|
|
|
17
17
|
|
|
18
18
|
try:
|
|
19
19
|
from msprobe.lib import _msprobe_c
|
|
20
|
-
os.environ["MS_HOOK_ENABLE"] = "on"
|
|
21
20
|
os.environ["HOOK_TOOL_PATH"] = _msprobe_c.__file__
|
|
22
21
|
except ImportError:
|
|
23
22
|
from .common.log import logger
|
|
24
23
|
logger.info("Module _msprobe_c has not been installed. L2-Dump may not work normally.")
|
|
25
24
|
|
|
26
25
|
from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
|
|
27
|
-
from msprobe.mindspore.common.utils import seed_all
|
|
28
|
-
from msprobe.mindspore.monitor.module_hook import TrainerMon
|
|
26
|
+
from msprobe.mindspore.common.utils import seed_all, MsprobeStep, MsprobeInitStep
|
|
27
|
+
from msprobe.mindspore.monitor.module_hook import TrainerMon
|
|
28
|
+
from msprobe.mindspore.dump.graph_tensor_dump import save, save_grad
|
|
@@ -14,9 +14,11 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from typing import Any, Optional
|
|
17
19
|
from tqdm import tqdm
|
|
18
|
-
|
|
19
|
-
from msprobe.core.common.const import Const, CompareConst
|
|
20
|
+
import numpy as np
|
|
21
|
+
from msprobe.core.common.const import Const, CompareConst
|
|
20
22
|
from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, load_json, load_yaml
|
|
21
23
|
from msprobe.core.common.utils import add_time_as_suffix
|
|
22
24
|
from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
|
|
@@ -25,8 +27,12 @@ from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compar
|
|
|
25
27
|
from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager
|
|
26
28
|
from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context,
|
|
27
29
|
trim_output_compute_element_list)
|
|
30
|
+
from msprobe.mindspore.common.const import MsCompareConst
|
|
28
31
|
from msprobe.mindspore.common.log import logger
|
|
29
32
|
from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer
|
|
33
|
+
from msprobe.core.data_dump.data_collector import build_data_collector
|
|
34
|
+
from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
|
|
35
|
+
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
|
|
30
36
|
|
|
31
37
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
32
38
|
yaml_path = os.path.join(cur_path, MsCompareConst.SUPPORTED_API_LIST_FILE)
|
|
@@ -58,13 +64,128 @@ class ProcessResultPacket:
|
|
|
58
64
|
self.err_msg = err_msg
|
|
59
65
|
|
|
60
66
|
|
|
67
|
+
@dataclass
class Config:
    """Minimal dump configuration handed to ``build_data_collector``.

    ``ApiAccuracyChecker.init_save_error_data`` builds one of these when
    error data is to be saved. Field order is part of the interface
    (positional construction), so do not reorder the fields.
    """

    # Execution mode of the run, e.g. "pynative".
    execution_mode: str
    # Root directory for dump output.
    dump_path: str
    # Dump task type, e.g. "tensor".
    task: str
    # Dump level, e.g. "L1".
    level: str
    # Optional scope restriction; None means no restriction (assumed — confirm with the collector).
    scope: Optional[Any]
    # Optional API list filter. The name shadows the builtin only inside the
    # class body; it is kept because it is part of the attribute interface.
    list: Optional[Any]
    # Framework identifier, e.g. Const.MS_FRAMEWORK.
    framework: str
    # Which data to collect, e.g. "all".
    data_mode: str
    # On-disk tensor format, e.g. "npy".
    file_format: str
    # Directory for dumped tensor files.
    dump_tensor_data_dir: str
    # Whether dumping happens asynchronously.
    async_dump: bool
    # Optional summary mode; defaults to None.
    summary_mode: Optional[Any] = None
|
|
81
|
+
|
|
82
|
+
|
|
61
83
|
class ApiAccuracyChecker:
|
|
62
84
|
def __init__(self, args):
    """Initialise checker state and, if requested, the error-data collector.

    Args:
        args: parsed arguments; ``out_path``, ``result_csv_path`` and
            ``save_error_data`` are read here.

    Note:
        ``self.data_collector`` is only created when ``args.save_error_data``
        is truthy; the forward/backward hook methods rely on it.
    """
    self.api_infos = {}
    # The DataManager is created once, at construction time.
    self.data_manager = DataManager(args.out_path, args.result_csv_path)
    self.save_error_data = args.save_error_data
    if not self.save_error_data:
        return

    error_config, error_paths = self.init_save_error_data(args)
    self.data_collector = build_data_collector(error_config)
    self.data_collector.update_dump_paths(error_paths)
|
|
92
|
+
|
|
93
|
+
@staticmethod
def init_save_error_data(args):
    """Build the dump configuration and output paths for saving error data.

    Args:
        args: parsed arguments; only ``args.out_path`` is read.

    Returns:
        tuple: ``(config, dump_path_aggregation)``. ``dump.json`` and
        ``stack.json`` go directly under ``out_path``; tensor files go under
        ``out_path/error_data``.

    Side effects:
        Creates the ``out_path/error_data`` directory.
    """
    out_dir = f"{args.out_path}"
    config = Config(
        execution_mode="pynative",
        dump_path=out_dir,
        dump_tensor_data_dir=out_dir,
        task="tensor",  # dump real tensor data, not only statistics
        level="L1",
        scope=None,  # no scope restriction
        list=None,  # no API list filter
        framework=Const.MS_FRAMEWORK,
        data_mode="all",
        file_format="npy",
        async_dump=False,
    )

    error_data_dir = os.path.join(out_dir, "error_data")
    create_directory(error_data_dir)

    paths = DumpPathAggregation()
    paths.dump_file_path = os.path.join(out_dir, "dump.json")
    paths.stack_file_path = os.path.join(out_dir, "stack.json")
    paths.dump_tensor_data_dir = error_data_dir
    return config, paths
|
|
117
|
+
|
|
118
|
+
@staticmethod
def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD):
    """Gather the inputs needed to run one API in the given direction.

    Args:
        api_info: ApiInfo
        forward_or_backward: str; ``Const.FORWARD`` leaves gradient inputs
            unset, and any other value also fetches the backward inputs.

    Returns:
        ApiInputAggregation built from forward inputs, kwargs and
        (optionally) gradient inputs.
    """
    forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
    kwargs = api_info.get_kwargs()
    gradient_inputs = (
        None
        if forward_or_backward == Const.FORWARD
        else api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
    )
    return ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
|
|
65
134
|
|
|
66
135
|
@staticmethod
|
|
67
|
-
def
|
|
136
|
+
def is_api_checkable(api_name_str):
    '''
    Tell whether an API is checkable, based on its key in the "data" dict of api_info.json.

    Args:
        api_name_str: str, e.g. "MintFunctional.relu.0.forward", key in data field of api_info.json
    Returns:
        is_checkable: bool
    Description:
        The supported-Tensor-API YAML is read only when a Tensor API actually
        needs it. A missing key or empty YAML is treated as "no supported
        Tensor APIs" rather than raising TypeError.
    '''
    api_name_str_list = api_name_str.split(Const.SEP)
    if len(api_name_str_list) < MsCompareConst.API_NAME_STR_LENGTH:
        return False
    api_type_str = api_name_str_list[0]
    # Strip the leading api type and the trailing "<index>.<direction>" parts.
    real_api_str = Const.SEP.join(api_name_str_list[1:-2])
    framework = global_context.get_framework()

    if framework == Const.MS_FRAMEWORK:
        if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL):
            return True
        if api_type_str == MsCompareConst.FUNCTIONAL_API:
            return real_api_str in MsCompareConst.SUPPORTED_FUSION_LIST
        if api_type_str == MsCompareConst.TENSOR_API:
            # Lazy load: only Tensor APIs consult the YAML, so other types avoid file I/O.
            api_list = load_yaml(yaml_path) or {}
            supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY) or []
            return real_api_str in supported_tensor_api_list
        return False

    if framework == Const.MT_FRAMEWORK:
        return api_type_str in MsCompareConst.MT_VALID_API_TYPES
    return False
|
|
166
|
+
|
|
167
|
+
def post_forward_hook(self, api_or_module_name, primitive_instance, args, kwargs, output):
|
|
168
|
+
self.data_collector.update_api_or_module_name(api_or_module_name)
|
|
169
|
+
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
|
|
170
|
+
self.data_collector.forward_data_collect_only_tensor(
|
|
171
|
+
api_or_module_name,
|
|
172
|
+
primitive_instance,
|
|
173
|
+
os.getpid(),
|
|
174
|
+
module_input_output
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
def backward_hook(self, api_or_module_name, module, grad_input, grad_output):
|
|
178
|
+
self.data_collector.update_api_or_module_name(api_or_module_name)
|
|
179
|
+
|
|
180
|
+
module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
|
|
181
|
+
self.data_collector.backward_data_collect_only_tensor(
|
|
182
|
+
api_or_module_name,
|
|
183
|
+
module,
|
|
184
|
+
os.getpid(),
|
|
185
|
+
module_input_output
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
def run_and_compare_helper(self, api_info, api_name_str, api_input_aggregation, forward_or_backward):
|
|
68
189
|
"""
|
|
69
190
|
Args:
|
|
70
191
|
api_info: ApiInfo
|
|
@@ -82,13 +203,22 @@ class ApiAccuracyChecker:
|
|
|
82
203
|
"""
|
|
83
204
|
# get output
|
|
84
205
|
if global_context.get_is_constructed():
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
206
|
+
if forward_or_backward == Const.FORWARD:
|
|
207
|
+
tested_outputs, inputs, kwargs, forward_result_tuple = api_runner(api_input_aggregation, api_name_str,
|
|
208
|
+
forward_or_backward,
|
|
209
|
+
global_context.get_framework())
|
|
210
|
+
elif forward_or_backward == Const.BACKWARD:
|
|
211
|
+
tested_outputs, gradient_inputs, backward_result_tuple = api_runner(api_input_aggregation, api_name_str,
|
|
212
|
+
forward_or_backward,
|
|
213
|
+
global_context.get_framework())
|
|
214
|
+
else:
|
|
215
|
+
tested_outputs = api_runner(api_input_aggregation, api_name_str,
|
|
216
|
+
forward_or_backward, global_context.get_framework())
|
|
88
217
|
else:
|
|
89
218
|
tested_outputs = api_info.get_compute_element_list(forward_or_backward, Const.OUTPUT)
|
|
90
219
|
|
|
91
220
|
bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)
|
|
221
|
+
|
|
92
222
|
tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward)
|
|
93
223
|
bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward)
|
|
94
224
|
if len(tested_outputs) != len(bench_outputs):
|
|
@@ -113,60 +243,26 @@ class ApiAccuracyChecker:
|
|
|
113
243
|
compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
|
|
114
244
|
status = CompareConst.PASS
|
|
115
245
|
err_msg = ""
|
|
246
|
+
|
|
116
247
|
else:
|
|
117
248
|
status = CompareConst.ERROR
|
|
118
249
|
err_msg = (compare_result_dict.get(CompareConst.COSINE).err_msg +
|
|
119
250
|
compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg)
|
|
251
|
+
if forward_or_backward == Const.FORWARD and self.save_error_data \
|
|
252
|
+
and global_context.get_is_constructed():
|
|
253
|
+
api_name_str_backward = f"{api_name_str}{Const.SEP}{Const.FORWARD}"
|
|
254
|
+
self.post_forward_hook(api_name_str_backward, None, inputs, kwargs, forward_result_tuple)
|
|
255
|
+
|
|
256
|
+
if forward_or_backward == Const.BACKWARD and self.save_error_data \
|
|
257
|
+
and global_context.get_is_constructed():
|
|
258
|
+
api_name_str_backward = f"{api_name_str}{Const.SEP}{Const.BACKWARD}"
|
|
259
|
+
self.backward_hook(api_name_str_backward, None, gradient_inputs, backward_result_tuple)
|
|
260
|
+
|
|
120
261
|
basic_info_status = \
|
|
121
262
|
BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
|
|
122
263
|
output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
|
|
123
264
|
return output_list
|
|
124
265
|
|
|
125
|
-
@staticmethod
|
|
126
|
-
def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD):
|
|
127
|
-
"""
|
|
128
|
-
Args:
|
|
129
|
-
api_info: ApiInfo
|
|
130
|
-
forward_or_backward: str
|
|
131
|
-
Returns:
|
|
132
|
-
ApiInputAggregation
|
|
133
|
-
"""
|
|
134
|
-
forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
|
|
135
|
-
kwargs = api_info.get_kwargs()
|
|
136
|
-
if forward_or_backward == Const.FORWARD:
|
|
137
|
-
gradient_inputs = None
|
|
138
|
-
else:
|
|
139
|
-
gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
|
|
140
|
-
return ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
|
|
141
|
-
|
|
142
|
-
@staticmethod
|
|
143
|
-
def is_api_checkable(api_name_str):
|
|
144
|
-
'''
|
|
145
|
-
Args:
|
|
146
|
-
api_name_str: str, e.g. "MintFunctional.relu.0.forward", key in data field of api_info.json
|
|
147
|
-
Returns:
|
|
148
|
-
is_checkable: bool
|
|
149
|
-
Description:
|
|
150
|
-
tell whether this api is checkable based on the key in "data" dict in api_info.json
|
|
151
|
-
'''
|
|
152
|
-
api_name_str_list = api_name_str.split(Const.SEP)
|
|
153
|
-
if len(api_name_str_list) < MsCompareConst.API_NAME_STR_LENGTH:
|
|
154
|
-
return False
|
|
155
|
-
api_type_str = api_name_str_list[0]
|
|
156
|
-
real_api_str = Const.SEP.join(api_name_str_list[1:-2])
|
|
157
|
-
api_list = load_yaml(yaml_path)
|
|
158
|
-
supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY)
|
|
159
|
-
if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL) \
|
|
160
|
-
and global_context.get_framework() == Const.MS_FRAMEWORK:
|
|
161
|
-
return True
|
|
162
|
-
if api_type_str in MsCompareConst.MT_VALID_API_TYPES \
|
|
163
|
-
and global_context.get_framework() == Const.MT_FRAMEWORK:
|
|
164
|
-
return True
|
|
165
|
-
if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list \
|
|
166
|
-
and global_context.get_framework() == Const.MS_FRAMEWORK:
|
|
167
|
-
return True
|
|
168
|
-
return False
|
|
169
|
-
|
|
170
266
|
def parse(self, api_info_path):
|
|
171
267
|
|
|
172
268
|
api_info_dict = load_json(api_info_path)
|
|
@@ -178,9 +274,9 @@ class ApiAccuracyChecker:
|
|
|
178
274
|
MsCompareConst.TENSOR_TASK))
|
|
179
275
|
try:
|
|
180
276
|
framework = check_and_get_from_json_dict(api_info_dict, MsCompareConst.FRAMEWORK,
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
277
|
+
"framework field in api_info.json", accepted_type=str,
|
|
278
|
+
accepted_value=(Const.MS_FRAMEWORK,
|
|
279
|
+
Const.MT_FRAMEWORK))
|
|
184
280
|
except Exception as e:
|
|
185
281
|
framework = Const.MS_FRAMEWORK
|
|
186
282
|
logger.warning(f"JSON parsing error in framework field: {e}")
|
|
@@ -296,4 +392,4 @@ class ApiAccuracyChecker:
|
|
|
296
392
|
elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
|
|
297
393
|
self.data_manager.record_exception_skip(api_name_str, Const.BACKWARD, process_result_packet.err_msg)
|
|
298
394
|
|
|
299
|
-
self.data_manager.save_results(api_name_str)
|
|
395
|
+
self.data_manager.save_results(api_name_str)
|