mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
msprobe 工具主要通过在训练脚本内添加 dump 接口、启动训练的方式采集精度数据。
|
|
4
4
|
|
|
5
|
-
dump
|
|
5
|
+
dump "statistics"模式的性能膨胀大小"与"tensor"模式采集的数据量大小,可以参考[dump基线](./26.data_dump_PyTorch_baseline.md)。
|
|
6
6
|
|
|
7
7
|
本工具提供固定的 API 支持列表,若需要删除或增加 dump 的 API,可以在 msprobe/pytorch/hook_module/support_wrap_ops.yaml 文件内手动修改,如下示例:
|
|
8
8
|
|
|
@@ -15,6 +15,52 @@ functional: # functional为算子类别,找到对应的类别,在该类别
|
|
|
15
15
|
|
|
16
16
|
删除API的场景:部分模型代码逻辑会存在API原生类型校验,工具执行dump操作时,对模型的API封装可能与模型的原生API类型不一致,此时可能引发校验失败,详见《[FAQ](FAQ.md)》中“异常情况”的第10和11条。
|
|
17
17
|
|
|
18
|
+
## 快速上手
|
|
19
|
+
|
|
20
|
+
这个示例定义了一个 nn.Module 类型的简单网络,使用原型函数 PrecisionDebugger 进行数据采集。
|
|
21
|
+
|
|
22
|
+
```python
|
|
23
|
+
# 根据需要import包
|
|
24
|
+
import torch
|
|
25
|
+
import torch.nn as nn
|
|
26
|
+
import torch.nn.functional as F
|
|
27
|
+
|
|
28
|
+
# 导入工具的数据采集接口
|
|
29
|
+
from msprobe.pytorch import PrecisionDebugger, seed_all
|
|
30
|
+
|
|
31
|
+
# 在模型训练开始前固定随机性
|
|
32
|
+
seed_all()
|
|
33
|
+
|
|
34
|
+
# 在模型训练开始前实例化PrecisionDebugger
|
|
35
|
+
debugger = PrecisionDebugger()
|
|
36
|
+
|
|
37
|
+
# 定义网络
|
|
38
|
+
class ModuleOP(nn.Module):
|
|
39
|
+
def __init__(self) -> None:
|
|
40
|
+
super().__init__()
|
|
41
|
+
self.linear_1 = nn.Linear(in_features=8, out_features=4)
|
|
42
|
+
self.linear_2 = nn.Linear(in_features=4, out_features=2)
|
|
43
|
+
|
|
44
|
+
def forward(self, x):
|
|
45
|
+
x1 = self.linear_1(x)
|
|
46
|
+
x2 = self.linear_2(x1)
|
|
47
|
+
r1 = F.relu(x2)
|
|
48
|
+
return r1
|
|
49
|
+
|
|
50
|
+
if __name__ == "__main__":
|
|
51
|
+
module = ModuleOP()
|
|
52
|
+
|
|
53
|
+
# 开启数据 dump
|
|
54
|
+
debugger.start(model=module)
|
|
55
|
+
x = torch.randn(10, 8)
|
|
56
|
+
out = module(x)
|
|
57
|
+
loss = out.sum()
|
|
58
|
+
loss.backward()
|
|
59
|
+
|
|
60
|
+
# 关闭数据 dump
|
|
61
|
+
debugger.stop()
|
|
62
|
+
```
|
|
63
|
+
|
|
18
64
|
## 1 接口介绍
|
|
19
65
|
|
|
20
66
|
### 1.1 PrecisionDebugger
|
|
@@ -30,9 +76,11 @@ PrecisionDebugger(config_path=None, task=None, dump_path=None, level=None, model
|
|
|
30
76
|
1. config_path:指定 dump 配置文件路径;
|
|
31
77
|
2. model:指定需要采集 Module 级数据的模型,支持传入 torch.nn.Module 或 list[torch.nn.Module] 类型,默认未配置。
|
|
32
78
|
level 配置为"L0"或"mix"时,必须在该接口或 **start** 接口中配置该参数。该参数在将来会从该接口移除,建议在 **start** 接口中配置该参数。
|
|
33
|
-
3. 其他参数均在
|
|
79
|
+
3. 其他参数均在 config.json 文件中可配,详细配置可见 [config.json 介绍](./02.config_introduction.md)。
|
|
34
80
|
|
|
35
|
-
|
|
81
|
+
此接口的参数均不是必要(均不配置的情况下默认采集所有 rank 和 step 的 L1 级别的统计数据),且优先级高于 config.json 文件中的配置,但可配置的参数相比 config.json 较少。
|
|
82
|
+
|
|
83
|
+
注:此接口的初始化需与采集目标在同一个进程中,否则将无法采集目标数据。
|
|
36
84
|
|
|
37
85
|
### 1.2 start
|
|
38
86
|
|
|
@@ -41,12 +89,15 @@ level 配置为"L0"或"mix"时,必须在该接口或 **start** 接口中配置
|
|
|
41
89
|
**原型**:
|
|
42
90
|
|
|
43
91
|
```Python
|
|
44
|
-
debugger.start(model=None)
|
|
92
|
+
debugger.start(model=None, token_range=None)
|
|
45
93
|
```
|
|
46
94
|
|
|
47
95
|
1. model:指定需要采集 Module 级数据的模型,支持传入 torch.nn.Module、list[torch.nn.Module]或Tuple[torch.nn.Module] 类型,默认未配置。
|
|
48
|
-
level 配置为"L0"
|
|
96
|
+
level 配置为"L0"|"mix"或token_range不为None时,必须在该接口或 **PrecisionDebugger** 接口中配置该参数。
|
|
49
97
|
本接口中的 model 比 PrecisionDebugger 中 model 参数优先级更高,会覆盖 PrecisionDebugger 中的 model 参数。
|
|
98
|
+
<br>对于复杂模型,如果仅需要监控一部分(如model.A,model.A extends torch.nn.Module),传入需要监控的部分(如model.A)即可。
|
|
99
|
+
注意:传入的当前层不会被dump,工具只会dump传入层的子层级。如传入了model.A,A本身不会被dump,而是会dump A.x, A.x.xx等。
|
|
100
|
+
2. token_range:指定推理模型采集时的token循环始末范围,支持传入[int, int]类型,代表[start, end],范围包含边界,默认未配置。
|
|
50
101
|
|
|
51
102
|
### 1.3 stop
|
|
52
103
|
|
|
@@ -183,58 +234,65 @@ save(variable, name, save_backward=True)
|
|
|
183
234
|
**参数说明**:
|
|
184
235
|
| 参数名称 | 参数含义 | 支持数据类型 | 是否必选|
|
|
185
236
|
| ---------- | ------------------| ------------------- | ------------------- |
|
|
186
|
-
| variable | 需要保存的变量 |dict, list, torch.tensor, int, float, str | 是 |
|
|
237
|
+
| variable | 需要保存的变量 |dict, list, tuple, torch.tensor, int, float, str | 是 |
|
|
187
238
|
| name | 指定的名称 | str | 是 |
|
|
188
239
|
| save_backward | 是否保存反向数据 | boolean | 否 |
|
|
189
240
|
|
|
190
|
-
|
|
241
|
+
### 1.10 set_init_step
|
|
191
242
|
|
|
192
|
-
|
|
243
|
+
**功能说明**:设置起始step数,step数默认从0开始计数,使用该接口后step从指定值开始计数。该函数需要写在训练迭代的循环开始前,不能写在循环内。
|
|
193
244
|
|
|
194
|
-
|
|
245
|
+
**原型**:
|
|
195
246
|
|
|
196
|
-
```
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
import torch.nn as nn
|
|
200
|
-
import torch.nn.functional as F
|
|
247
|
+
```Python
|
|
248
|
+
debugger.set_init_step(step)
|
|
249
|
+
```
|
|
201
250
|
|
|
202
|
-
|
|
203
|
-
from msprobe.pytorch import PrecisionDebugger, seed_all
|
|
251
|
+
**参数说明**:
|
|
204
252
|
|
|
205
|
-
|
|
206
|
-
seed_all()
|
|
207
|
-
# 在模型训练开始前实例化PrecisionDebugger
|
|
208
|
-
debugger = PrecisionDebugger(config_path='./config.json')
|
|
253
|
+
1.step: 指定的起始step数。
|
|
209
254
|
|
|
210
|
-
|
|
211
|
-
class ModuleOP(nn.Module):
|
|
212
|
-
def __init__(self) -> None:
|
|
213
|
-
super().__init__()
|
|
214
|
-
self.linear_1 = nn.Linear(in_features=8, out_features=4)
|
|
215
|
-
self.linear_2 = nn.Linear(in_features=4, out_features=2)
|
|
255
|
+
### 1.11 register_custom_api
|
|
216
256
|
|
|
217
|
-
|
|
218
|
-
x1 = self.linear_1(x)
|
|
219
|
-
x2 = self.linear_2(x1)
|
|
220
|
-
r1 = F.relu(x2)
|
|
221
|
-
return r1
|
|
257
|
+
**功能说明**:注册用户自定义的api到工具用于 L1 dump 。
|
|
222
258
|
|
|
223
|
-
|
|
224
|
-
module = ModuleOP()
|
|
225
|
-
# 开启数据 dump
|
|
226
|
-
debugger.start(model=module)
|
|
259
|
+
**原型**:
|
|
227
260
|
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
261
|
+
```Python
|
|
262
|
+
debugger.register_custom_api(module, api_name, api_prefix)
|
|
263
|
+
```
|
|
264
|
+
**参数说明**:
|
|
232
265
|
|
|
233
|
-
|
|
234
|
-
|
|
266
|
+
以 torch.matmul api 为例
|
|
267
|
+
|
|
268
|
+
1.module: api 所属的包,即传入 torch。
|
|
269
|
+
|
|
270
|
+
2.api_name: api 名,string类型,即传入 "matmul"。
|
|
271
|
+
|
|
272
|
+
3.api_prefix: [dump.json](./27.dump_json_instruction.md) 中 api 名的前缀,可选,默认为包名的字符串格式, 即 "torch"。
|
|
273
|
+
|
|
274
|
+
### 1.12 restore_custom_api
|
|
275
|
+
|
|
276
|
+
**功能说明**:恢复用户原有的自定义的api,取消 dump 。
|
|
277
|
+
|
|
278
|
+
**原型**:
|
|
279
|
+
|
|
280
|
+
```Python
|
|
281
|
+
debugger.restore_custom_api(module, api_name)
|
|
235
282
|
```
|
|
283
|
+
**参数说明**:
|
|
284
|
+
|
|
285
|
+
以 torch.matmul api 为例
|
|
286
|
+
|
|
287
|
+
1.module: api 所属的包,即传入 torch。
|
|
288
|
+
|
|
289
|
+
2.api_name: api 名,string类型,即传入 "matmul"。
|
|
236
290
|
|
|
237
|
-
|
|
291
|
+
|
|
292
|
+
## 2 示例代码
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
### 2.1 采集完整的前反向数据
|
|
238
296
|
|
|
239
297
|
```Python
|
|
240
298
|
from msprobe.pytorch import PrecisionDebugger, seed_all
|
|
@@ -255,7 +313,7 @@ for data, label in data_loader:
|
|
|
255
313
|
debugger.step() # 结束一个step的dump
|
|
256
314
|
```
|
|
257
315
|
|
|
258
|
-
### 2.
|
|
316
|
+
### 2.2 采集指定代码块的前反向数据
|
|
259
317
|
|
|
260
318
|
```Python
|
|
261
319
|
from msprobe.pytorch import PrecisionDebugger, seed_all
|
|
@@ -279,7 +337,7 @@ for data, label in data_loader:
|
|
|
279
337
|
debugger.step() # 结束一个step的dump
|
|
280
338
|
```
|
|
281
339
|
|
|
282
|
-
### 2.
|
|
340
|
+
### 2.3 采集函数模块化数据
|
|
283
341
|
|
|
284
342
|
```Python
|
|
285
343
|
# 根据需要import包
|
|
@@ -321,6 +379,80 @@ if __name__ == "__main__":
|
|
|
321
379
|
debugger.stop()
|
|
322
380
|
```
|
|
323
381
|
|
|
382
|
+
### 2.4 跨文件采集数据
|
|
383
|
+
为了确保所有API都被工具封装,PrecisionDebugger的实例化通常放在训练工程的入口位置,但有的时候,模型定义会在另一个文件中。 假设有两个文件,train.py(为训练工程入口)module.py(为模型定义文件),为了采集module.py中定义的ModuleOP模块中某些子模块或API的前反向数据,需要在train.py和module.py文件中分别导入PrecisionDebugger并进行如下配置。
|
|
384
|
+
|
|
385
|
+
train.py文件:
|
|
386
|
+
|
|
387
|
+
```Python
|
|
388
|
+
# 根据需要import包
|
|
389
|
+
import torch
|
|
390
|
+
from module import ModuleOP
|
|
391
|
+
|
|
392
|
+
# 导入工具的数据采集接口
|
|
393
|
+
from msprobe.pytorch import PrecisionDebugger
|
|
394
|
+
|
|
395
|
+
# 将PrecisionDebugger的实例化放在文件的开始位置,即导包后的位置,确保所有API都被封装
|
|
396
|
+
debugger = PrecisionDebugger(config_path='./config.json')
|
|
397
|
+
|
|
398
|
+
if __name__ == "__main__":
|
|
399
|
+
module = ModuleOP()
|
|
400
|
+
|
|
401
|
+
x = torch.randn(10, 8)
|
|
402
|
+
out = module(x)
|
|
403
|
+
loss = out.sum()
|
|
404
|
+
loss.backward()
|
|
405
|
+
```
|
|
406
|
+
|
|
407
|
+
module.py文件:
|
|
408
|
+
|
|
409
|
+
```Python
|
|
410
|
+
import torch
|
|
411
|
+
import torch.nn as nn
|
|
412
|
+
import torch.nn.functional as F
|
|
413
|
+
|
|
414
|
+
from msprobe.pytorch import PrecisionDebugger
|
|
415
|
+
|
|
416
|
+
# 定义网络
|
|
417
|
+
class ModuleOP(nn.Module):
|
|
418
|
+
def __init__(self) -> None:
|
|
419
|
+
super().__init__()
|
|
420
|
+
self.linear_1 = nn.Linear(in_features=8, out_features=4)
|
|
421
|
+
self.linear_2 = nn.Linear(in_features=4, out_features=2)
|
|
422
|
+
|
|
423
|
+
def forward(self, x):
|
|
424
|
+
PrecisionDebugger.start()
|
|
425
|
+
x1 = self.linear_1(x)
|
|
426
|
+
PrecisionDebugger.stop()
|
|
427
|
+
x2 = self.linear_2(x1)
|
|
428
|
+
r1 = F.relu(x2)
|
|
429
|
+
return r1
|
|
430
|
+
|
|
431
|
+
```
|
|
432
|
+
|
|
433
|
+
### 2.5 推理模型采集指定token_range
|
|
434
|
+
|
|
435
|
+
```Python
|
|
436
|
+
from vllm import LLM, SamplingParams
|
|
437
|
+
from msprobe.pytorch import PrecisionDebugger, seed_all
|
|
438
|
+
# 在模型训练开始前固定随机性
|
|
439
|
+
seed_all()
|
|
440
|
+
# 请勿将PrecisionDebugger的初始化流程插入到循环代码中
|
|
441
|
+
debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_path")
|
|
442
|
+
# 模型定义及初始化等操作
|
|
443
|
+
prompts = ["Hello, my name is"]
|
|
444
|
+
sampling_params = SamplingParams(temprature=0.8, top_p=0.95)
|
|
445
|
+
llm = LLM(model='...')
|
|
446
|
+
model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.get_model()
|
|
447
|
+
# 开启数据dump, 指定采集推理模型逐字符循环推理中的第1~3次
|
|
448
|
+
debugger.start(model=model, token_range=[1,3])
|
|
449
|
+
# 推理模型生成的逻辑
|
|
450
|
+
output = llm.generate(prompts, sampling_params=sampling_params)
|
|
451
|
+
# 关闭数据dump并落盘
|
|
452
|
+
debugger.stop()
|
|
453
|
+
debugger.step()
|
|
454
|
+
```
|
|
455
|
+
|
|
324
456
|
## 3 dump 结果文件介绍
|
|
325
457
|
|
|
326
458
|
训练结束后,工具将 dump 的数据保存在 dump_path 参数指定的目录下。目录结构示例如下:
|
|
@@ -334,8 +466,8 @@ if __name__ == "__main__":
|
|
|
334
466
|
| | | | ├── Functional.linear.5.backward.output.pt # 命名格式为{api_type}.{api_name}.{API调用次数}.{forward/backward}.{input/output}.{参数序号}, 其中,“参数序号”表示该API的第n个输入或输出,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该API的第1个参数的第1个元素。
|
|
335
467
|
| | | | ...
|
|
336
468
|
| | | | ├── Module.conv1.Conv2d.forward.0.input.0.pt # 命名格式为{Module}.{module_name}.{class_name}.{forward/backward}.{调用次数}.{input/output}.{参数序号}, 其中,“参数序号”表示该Module的第n个参数,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该Module的第1个参数的第1个元素。
|
|
337
|
-
| | | | ├── Module.conv1.
|
|
338
|
-
| | | | └── Module.conv1.
|
|
469
|
+
| | | | ├── Module.conv1.Conv2d.forward.0.parameters.bias.pt # 模块参数数据:命名格式为{Module}.{module_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。
|
|
470
|
+
| | | | └── Module.conv1.Conv2d.parameters_grad.weight.pt # 模块参数梯度数据:命名格式为{Module}.{module_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。
|
|
339
471
|
| | | | # 当dump时传入的model参数为List[torch.nn.Module]或Tuple[torch.nn.Module]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为{Module}.{index}.*,*表示以上三种模块级数据的命名格式,例如:Module.0.conv1.Conv2d.forward.0.input.0.pt。
|
|
340
472
|
│ | | ├── dump.json
|
|
341
473
|
│ | | ├── stack.json
|
|
@@ -355,7 +487,7 @@ if __name__ == "__main__":
|
|
|
355
487
|
```
|
|
356
488
|
* `rank`:设备 ID,每张卡的数据保存在对应的 `rank{ID}` 目录下。非分布式场景下没有 rank ID,目录名称为 rank。
|
|
357
489
|
* `dump_tensor_data`:保存采集到的张量数据。
|
|
358
|
-
* `dump.json`: 保存API或Module前反向数据的统计量信息。包含dump数据的API名称或Module名称,各数据的dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置summary_mode="md5"时的CRC-32数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#1-
|
|
490
|
+
* `dump.json`: 保存API或Module前反向数据的统计量信息。包含dump数据的API名称或Module名称,各数据的dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置summary_mode="md5"时的CRC-32数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#1-PyTorch场景下的dump.json文件)。
|
|
359
491
|
* `stack.json`:API/Module的调用栈信息。
|
|
360
492
|
* `construct.json`:分层分级结构,level为L1时,construct.json内容为空。
|
|
361
493
|
|
|
@@ -366,12 +498,14 @@ dump 过程中,pt 文件在对应算子或者模块被执行后就会落盘,
|
|
|
366
498
|
|
|
367
499
|
pt 文件保存的前缀和 PyTorch 对应关系如下:
|
|
368
500
|
|
|
369
|
-
| 前缀
|
|
370
|
-
|
|
501
|
+
| 前缀 | Torch模块 |
|
|
502
|
+
|-------------|---------------------|
|
|
371
503
|
| Tensor | torch.Tensor |
|
|
372
504
|
| Torch | torch |
|
|
373
505
|
| Functional | torch.nn.functional |
|
|
374
|
-
| NPU | NPU 亲和算子
|
|
506
|
+
| NPU | NPU 亲和算子 |
|
|
375
507
|
| VF | torch._VF |
|
|
376
508
|
| Aten | torch.ops.aten |
|
|
377
509
|
| Distributed | torch.distributed |
|
|
510
|
+
| MindSpeed | mindspeed.ops |
|
|
511
|
+
|