mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
* **验证低精度可疑 API**,确认升精度后是否对模型 Loss 有影响。
|
|
14
14
|
* 该工具的约束
|
|
15
15
|
* 仅支持 MindSpore 动态图场景。支持的 API 类型为 ops、Tensor、mint 和 mint.nn.functional 类的非 inplace 计算 API,不支持 Primitive 和 Jit 类 API。
|
|
16
|
+
* 仅支持 输入输出向量为浮点数类型(BF16、FP16、FP32、FP64)的 API 比对。
|
|
16
17
|
* 建议配置白名单(设置 list),控制对少量 API 进行无标杆比对。比对 API 越多,性能和显存损耗越大。
|
|
17
18
|
|
|
18
19
|
## 2 工具实现原理
|
|
@@ -79,40 +80,57 @@ D-->config.json配置
|
|
|
79
80
|
用户需根据自己的使用场景,对照[工具实现原理](#2-工具实现原理)中几个关键步骤进行配置
|
|
80
81
|
#### 3.2.1 确定比对范围
|
|
81
82
|
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
83
|
+
<table>
|
|
84
|
+
<tr><th>参数</th><th>是否必选</th><th>可配置项</th><th>适用场景</th></tr>
|
|
85
|
+
<tr><td>list</td><td>否</td><td>自定义</td><td>需要通过指定 API 名来限制比对API个数 如:["mindspore.ops.bmm"] 会只对mindspore.ops.bmm API进行比对。</td></tr>
|
|
86
|
+
<tr><td rowspan="2">fuzz_stage</td><td rowspan="2">否</td><td>"forward"(默认)</td><td>需要进行 API <b>前向</b>计算的精度问题排查或验证。</td></tr>
|
|
87
|
+
<tr><td>"backward"</td><td>需要进行 API <b>反向</b>计算的精度问题排查,不支持反向验证(前向验证包括反向)。</td></tr>
|
|
88
|
+
</table>
|
|
86
89
|
|
|
87
|
-
#### 3.2.2
|
|
90
|
+
#### 3.2.2 选择扰动因子
|
|
88
91
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
92
|
+
<table>
|
|
93
|
+
<tr><th>参数</th><th>是否必选</th><th>可配置项</th><th>适用场景</th></tr>
|
|
94
|
+
<tr><td rowspan="5">pert_mode</td><td rowspan="5">否</td><td>"improve_precision" (默认)</td><td>(常用)(可做验证) API 可能在<b>低精度</b>下有精度问题,扰动因子会将输入的低精度向量升精度。</td></tr>
|
|
95
|
+
<tr><td>"add_noise"</td><td>API 可能在<b>轻微扰动</b>下暴露精度问题,扰动因子会为输入向量增加一个极小值。</td></tr>
|
|
96
|
+
<tr><td>"bit_noise"</td><td>API 可能在<b>轻微扰动</b>下暴露精度问题,扰动因子会翻转输入向量的最后一个比特位。不支持BF16类型向量。</td></tr>
|
|
97
|
+
<tr><td>"no_change"</td><td>API 可能存在<b>数值稳定性</b>精度问题,扰动因子会复制原始输入。</td></tr>
|
|
98
|
+
<tr><td>"change_value"</td><td>API 可能存在<b>大数吃小数</b>问题,扰动因子会交换输入向量的首尾值。</td></tr>
|
|
99
|
+
</table>
|
|
96
100
|
|
|
97
|
-
#### 3.2.3
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
101
|
+
#### 3.2.3 选择处理方式
|
|
102
|
+
|
|
103
|
+
<table>
|
|
104
|
+
<tr><th>参数</th><th>是否必选</th><th>可配置项</th><th>适用场景</th></tr>
|
|
105
|
+
<tr><td rowspan="2">handler_type</td><td rowspan="2">否</td><td>"check"(默认)</td><td>要做精度问题 API 排查,输出扰动前后不符合精度标准的 API,支持所有扰动因子。</td></tr>
|
|
106
|
+
<tr><td>"fix"</td><td>要做可疑 API 验证,用扰动后输出替换原始输出,仅支持 "improve_precision" 扰动因子。</td></tr>
|
|
107
|
+
</table>
|
|
102
108
|
|
|
103
109
|
### 3.3 在模型脚本中开启工具
|
|
104
110
|
|
|
105
111
|
通过 PrecisionDebugger 统一接口开启工具,示例如下:
|
|
106
112
|
|
|
107
113
|
```python
|
|
108
|
-
|
|
114
|
+
import mindspore as ms
|
|
115
|
+
ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend")
|
|
116
|
+
|
|
117
|
+
# 其他模块的导入
|
|
118
|
+
# ...
|
|
109
119
|
|
|
120
|
+
from msprobe.mindspore import PrecisionDebugger
|
|
110
121
|
debugger = PrecisionDebugger(config_path='./config.json')
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
122
|
+
|
|
123
|
+
# 模型、损失函数的定义以及初始化等操作
|
|
124
|
+
# ...
|
|
125
|
+
model = Network()
|
|
126
|
+
|
|
127
|
+
# 数据集迭代的地方往往是模型开始训练的地方
|
|
128
|
+
for step, (input_ids, input_position, attention_mask) in enumerate(dataset.create_tuple_iterator()):
|
|
129
|
+
debugger.start() # 一般在训练循环开头启动工具
|
|
130
|
+
# 单步训练逻辑
|
|
131
|
+
# ...
|
|
132
|
+
debugger.stop() # 一般在训练循环末尾结束工具
|
|
133
|
+
debugger.step() # 在训练循环的最后需要重置工具,非循环场景不需要
|
|
116
134
|
```
|
|
117
135
|
|
|
118
136
|
### 3.4 查看精度风险算子
|
|
@@ -121,20 +139,21 @@ check 模式下,若存在不符合精度标准的 API,则工具会在 dump_p
|
|
|
121
139
|
|
|
122
140
|
| 字段 | 说明 |
|
|
123
141
|
| ------------ | ---------------------------------------------------------------------------------------- |
|
|
124
|
-
| rank | Rank ID,int
|
|
125
|
-
| pert_mode | 扰动因子的类型,string
|
|
126
|
-
| stage | 前/反向,string
|
|
127
|
-
| step | 迭代数,int
|
|
128
|
-
| api_name | API 名称,string
|
|
129
|
-
| max_rel | 输出对比最大相对误差,float
|
|
130
|
-
| dtype | 输入的 dtype,string
|
|
131
|
-
| shape | 输入的 shape,tuple
|
|
132
|
-
| output_index | 如果输出为列表或元组,其中一个元素检测不一致,则会有该元素的 index,否则为空,int
|
|
142
|
+
| rank | Rank ID,int 类型。 |
|
|
143
|
+
| pert_mode | 扰动因子的类型,string 类型。 |
|
|
144
|
+
| stage | 前/反向,string 类型。 |
|
|
145
|
+
| step | 迭代数,int 类型。 |
|
|
146
|
+
| api_name | API 名称,string 类型。 |
|
|
147
|
+
| max_rel | 输出对比最大相对误差,float 类型。 |
|
|
148
|
+
| dtype | 输入的 dtype,string 类型。 |
|
|
149
|
+
| shape | 输入的 shape,tuple 类型。 |
|
|
150
|
+
| output_index | 如果输出为列表或元组,其中一个元素检测不一致,则会有该元素的 index,否则为空,int 类型。 |
|
|
133
151
|
|
|
134
152
|
无标杆比对使用的精度标准如下:
|
|
135
153
|
|
|
136
|
-
| 输出dtype
|
|
137
|
-
|
|
|
138
|
-
| mindspore.
|
|
139
|
-
| mindspore.
|
|
140
|
-
|
|
|
154
|
+
| 输出dtype | 相对误差阈值 |
|
|
155
|
+
| ------------------ | ------------ |
|
|
156
|
+
| mindspore.bfloat16 | 0.004 |
|
|
157
|
+
| mindspore.float16 | 0.002 |
|
|
158
|
+
| mindspore.float32 | 0.0002 |
|
|
159
|
+
| mindspore.float64 | 0.0002 |
|
msprobe/docs/17.grad_probe.md
CHANGED
|
@@ -5,11 +5,11 @@
|
|
|
5
5
|
- 将模型权重的梯度数据导出。这种功能可以将模型权重的梯度值以统计量的形式采集出来,用以分析问题。
|
|
6
6
|
- 将两份梯度数据进行相似度对比。在有标杆问题中,可以确认训练过程中精度问题出现的step,以及抓取反向过程中的问题。
|
|
7
7
|
|
|
8
|
-
工具支持PyTorch版本:2.0/2.1/2.2;支持MindSpore版本:r2.3。
|
|
8
|
+
工具支持PyTorch版本:2.0/2.1/2.2;支持MindSpore版本:r2.3。暂不支持deepspeed的zero1、zero2、zero3。
|
|
9
9
|
|
|
10
10
|
## 工具特性
|
|
11
11
|
|
|
12
|
-
-
|
|
12
|
+
- 使用便捷,仅需在训练流程里插入少量代码
|
|
13
13
|
- 可以精准定位问题出现的step
|
|
14
14
|
|
|
15
15
|
## 使用方式
|
|
@@ -39,8 +39,8 @@
|
|
|
39
39
|
|--------------------------------|-----------------------------------|-----------------|----------|
|
|
40
40
|
| task | 填为"grad_probe"。 | str | 是 |
|
|
41
41
|
| dump_path | 输出目录。如果不存在就会创建一个新目录。 | str | 是 |
|
|
42
|
-
| rank | rank id列表,在多卡场景下,表示需要导出梯度数据的进程的rank id。列表为空就表示导出所有rank
|
|
43
|
-
| step | step列表,表示需要导出数据的step列表。列表为空就表示导出所有step
|
|
42
|
+
| rank | rank id列表,在多卡场景下,表示需要导出梯度数据的进程的rank id。列表为空就表示导出所有rank的数据。默认为空。采集特定 rank 时,须指定为训练脚本中存在的 rank_id,可逐个配置,也可以指定范围。<br/> **配置示例**:"rank": [0, 1 , 2, "4-6"]。(MindSpore静态图模式下,当前暂不支持指定rank功能) | list[Union[int, str]] | 否 |
|
|
43
|
+
| step | step列表,表示需要导出数据的step列表。列表为空就表示导出所有step的数据。默认为空。采集特定 step 时,须指定为训练脚本中存在的 step,可逐个配置,也可以指定范围。<br/> **配置示例**:"step": [0, 1 , 2, "4-6"]。(MindSpore静态图模式下,当前暂不支持指定step功能) | list[Union[int, str]] | 否 |
|
|
44
44
|
| grad_level | 输出级别。决定导出数据的详细程度,级别越大导出数据越详细。可取值:L0, L1, L2。默认L1。|str | 否 |
|
|
45
45
|
| param_list | 权重名称列表,表示需要监控的权重。列表为空就表示监控所有权重。默认为空。 | List[str] | 否 |
|
|
46
46
|
| bounds | 区间列表,用来划分区间以统计数值的分布。需要保证由数据小到大排列,并且列表中的元素需要在int64取值范围内。可以使用默认值[-1, 0, 1]。 | List[float, int] | 否 |
|
|
@@ -100,11 +100,10 @@ debugger.stop()
|
|
|
100
100
|
│ ├── step{step}
|
|
101
101
|
│ │ ├── {param_name}.npy
|
|
102
102
|
```
|
|
103
|
-
+ {timestamp}:梯度工具导出数据的时候会在output_path下生成一个时间戳目录,然后在这个时间戳目录下输出结果。
|
|
104
103
|
+ rank_{rank_id}:在分布式场景下,会记录卡的rank_id。非分布式场景下,如果是CPU则记录进程号,如果是CPU或GPU则记录卡号
|
|
105
104
|
+ grad_summary_{step}.csv:会分step记录每一步的梯度数据统计值。
|
|
106
105
|
+ step_{step}:这个目录下会存放该step的梯度的方向数据。
|
|
107
|
-
+ {param_name}.
|
|
106
|
+
+ {param_name}.npy:模型参数的梯度方向数据。
|
|
108
107
|
|
|
109
108
|
**grad_summary_{step}.csv**
|
|
110
109
|
|