mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
- mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +101 -237
- msprobe/{config/config.json → config.json} +49 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +59 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +341 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +478 -283
- msprobe/core/common/log.py +76 -69
- msprobe/core/common/utils.py +385 -616
- msprobe/core/common_config.py +85 -71
- msprobe/core/compare/acc_compare.py +299 -298
- msprobe/core/compare/check.py +95 -95
- msprobe/core/compare/compare_cli.py +49 -49
- msprobe/core/compare/highlight.py +223 -222
- msprobe/core/compare/multiprocessing_compute.py +149 -149
- msprobe/core/compare/npy_compare.py +295 -295
- msprobe/core/compare/utils.py +430 -429
- msprobe/core/data_dump/data_collector.py +154 -144
- msprobe/core/data_dump/data_processor/base.py +314 -293
- msprobe/core/data_dump/data_processor/factory.py +59 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
- msprobe/core/data_dump/json_writer.py +96 -116
- msprobe/core/data_dump/scope.py +178 -178
- msprobe/core/grad_probe/constant.py +70 -70
- msprobe/core/grad_probe/grad_compare.py +171 -175
- msprobe/core/grad_probe/utils.py +64 -52
- msprobe/docs/01.installation.md +89 -0
- msprobe/docs/02.config_introduction.md +165 -0
- msprobe/docs/03.config_examples.md +247 -0
- msprobe/docs/04.acl_config_examples.md +76 -0
- msprobe/docs/05.data_dump_PyTorch.md +198 -0
- msprobe/docs/06.data_dump_MindSpore.md +243 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
- msprobe/docs/FAQ_PyTorch.md +177 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +34 -34
- msprobe/mindspore/common/const.py +106 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +81 -57
- msprobe/mindspore/compare/distributed_compare.py +75 -75
- msprobe/mindspore/compare/ms_compare.py +219 -117
- msprobe/mindspore/compare/ms_graph_compare.py +348 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +66 -74
- msprobe/mindspore/debugger/precision_debugger.py +126 -107
- msprobe/mindspore/dump/dump_tool_factory.py +35 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
- msprobe/mindspore/dump/jit_dump.py +72 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
- msprobe/mindspore/free_benchmark/common/config.py +12 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
- msprobe/mindspore/free_benchmark/common/utils.py +71 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
- msprobe/mindspore/grad_probe/global_context.py +90 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +378 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +3 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
- msprobe/pytorch/bench_functions/__init__.py +15 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
- msprobe/pytorch/bench_functions/linear.py +12 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
- msprobe/pytorch/bench_functions/rms_norm.py +15 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
- msprobe/pytorch/bench_functions/swiglu.py +55 -55
- msprobe/pytorch/common/__init__.py +2 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +20 -31
- msprobe/pytorch/common/parse_json.py +39 -39
- msprobe/pytorch/common/utils.py +305 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +34 -33
- msprobe/pytorch/compare/pt_compare.py +50 -40
- msprobe/pytorch/debugger/debugger_config.py +95 -95
- msprobe/pytorch/debugger/precision_debugger.py +125 -125
- msprobe/pytorch/free_benchmark/__init__.py +8 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +37 -37
- msprobe/pytorch/free_benchmark/common/params.py +129 -129
- msprobe/pytorch/free_benchmark/common/utils.py +102 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
- msprobe/pytorch/free_benchmark/main.py +105 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
- msprobe/pytorch/function_factory.py +76 -75
- msprobe/pytorch/functional/dump_module.py +39 -39
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/api_registry.py +161 -161
- msprobe/pytorch/hook_module/hook_module.py +120 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +30 -29
- msprobe/pytorch/hook_module/wrap_aten.py +110 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
- msprobe/pytorch/hook_module/wrap_functional.py +105 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
- msprobe/pytorch/hook_module/wrap_torch.py +86 -86
- msprobe/pytorch/hook_module/wrap_vf.py +62 -62
- msprobe/pytorch/module_processer.py +138 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +236 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
- msprobe/pytorch/online_dispatch/utils.py +130 -146
- msprobe/pytorch/parse.py +4 -4
- msprobe/pytorch/parse_tool/cli.py +32 -32
- msprobe/pytorch/parse_tool/lib/compare.py +260 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
- msprobe/pytorch/parse_tool/lib/utils.py +316 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +188 -187
- msprobe/pytorch/service.py +246 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/pytorch/common/utils.py
CHANGED
|
@@ -1,300 +1,305 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
|
-
# -*- coding: utf-8 -*-
|
|
3
|
-
"""
|
|
4
|
-
# Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
-
# you may not use this file except in compliance with the License.
|
|
7
|
-
# You may obtain a copy of the License at
|
|
8
|
-
#
|
|
9
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
-
#
|
|
11
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
-
# See the License for the specific language governing permissions and
|
|
15
|
-
# limitations under the License.
|
|
16
|
-
"""
|
|
17
|
-
import
|
|
18
|
-
import os
|
|
19
|
-
import random
|
|
20
|
-
import stat
|
|
21
|
-
import
|
|
22
|
-
import
|
|
23
|
-
import
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
from
|
|
27
|
-
from msprobe.core.common.
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
is_gpu =
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
return masked_select_func(input_tensor
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
if indices.dtype == torch.
|
|
66
|
-
indices
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
"""
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
random.seed(seed)
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
torch.
|
|
117
|
-
torch.
|
|
118
|
-
torch.backends.cudnn.
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
"""
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
torch.
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
pt = torch.load(pt_path
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
import io
|
|
18
|
+
import os
|
|
19
|
+
import random
|
|
20
|
+
import stat
|
|
21
|
+
import torch
|
|
22
|
+
import torch.distributed as dist
|
|
23
|
+
import numpy as np
|
|
24
|
+
from functools import wraps
|
|
25
|
+
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
26
|
+
from msprobe.core.common.log import logger
|
|
27
|
+
from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
|
|
28
|
+
check_file_or_directory_path, check_path_before_create)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Backend detection: is_gpu is True when torch_npu cannot be imported and
# False when the import succeeds. The rest of this module branches on it.
try:
    import torch_npu
except ImportError:
    is_gpu = True
else:
    is_gpu = False


# True for torch >= 2.1; torch_device_guard in this module then returns
# functions unwrapped. NOTE(review): relies on torch.__version__ comparing as a
# version (torch's TorchVersion) rather than as a plain string - confirm for
# the supported torch range.
torch_without_guard_version = torch.__version__ >= '2.1'


# The torch_npu device-guard decorator is only imported when running on NPU
# with torch < 2.1, the only case in which torch_device_guard uses it.
if not is_gpu and not torch_without_guard_version:
    from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard

# Distributed API names listed for NPU handling; not referenced in the code
# visible here - presumably consumed by other modules (TODO confirm).
npu_distributed_api = ['isend', 'irecv']
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def parameter_adapter(func):
    """Decorate a hooked operator so some argument patterns are rewritten first.

    ``func`` is called as ``func(self, *args, **kwargs)``, where ``self`` exposes
    an ``op_name_`` attribute naming the operator being executed.

    Two operators get special handling.

    ``__getitem__`` with a tensor index in ``args[1]``:
      * uint8 masks are converted to bool masks;
      * a bool mask with the same shape as the indexed tensor is served with
        ``masked_select`` (computed in float32 for bfloat16 input);
      * any other bool mask is turned into index tuples with ``nonzero`` and
        passed to the raw ``Tensor.__getitem__``;
      * a non-bool index of rank <= 1 is converted to a Python list. A rank-2
        index is applied row by row and the results are stacked. Higher-rank
        indices are applied per first-dimension slice and stacked.

    ``__eq__`` against ``None`` returns ``False`` without calling ``func``.

    Every other call is forwarded to ``func`` unchanged.
    """

    def handle_masked_select(input_tensor, indices):
        # Use the raw torch function (looked up by name) instead of the public API.
        masked_select_func = getattr(torch._C._VariableFunctionsClass, "masked_select")
        if input_tensor.dtype == torch.bfloat16:
            # masked_select raises on NPU when the input dtype is bfloat16
            # (reported as unsupported), so compute in float32 and cast back.
            return masked_select_func(input_tensor.to(torch.float32), indices).to(torch.bfloat16)
        return masked_select_func(input_tensor, indices)

    @wraps(func)
    def inner(self, *args, **kwargs):
        if self.op_name_ == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor):
            input_tensor = args[0]
            indices = args[1]
            if indices.dtype == torch.uint8:
                indices = indices.bool()
            if indices.dtype == torch.bool:
                if indices.shape == input_tensor.shape:
                    return handle_masked_select(input_tensor, indices)
                indices = getattr(torch._C._VariableFunctionsClass, "nonzero")(indices, as_tuple=True)
                return getattr(torch._C._TensorBase, "__getitem__")(input_tensor, indices)
            # Non-bool (integer) index tensor.
            if indices.dim() <= 1:
                return func(self, input_tensor, indices.tolist())
            stack_func = getattr(torch._C._VariableFunctionsClass, "stack")
            if indices.dim() == 2:
                result = [func(self, input_tensor, index) for index in indices.tolist()]
                return stack_func(result, 0)
            res = [input_tensor[tensor_index] for tensor_index in indices]
            return stack_func(res, 0)
        # Check len(args) so an __eq__ call with fewer than two positional
        # arguments falls through to func instead of raising IndexError.
        if self.op_name_ == "__eq__" and len(args) > 1 and args[1] is None:
            return False
        return func(self, *args, **kwargs)

    return inner
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def torch_device_guard(func):
    """Wrap ``func`` with torch_npu's device guard when it is needed.

    When ``is_gpu`` is true (torch_npu not importable) or torch is >= 2.1,
    ``func`` is returned unchanged. Otherwise ``func`` is wrapped with
    ``torch_npu.utils.device_guard.torch_device_guard``. That guard presumably
    parses args/kwargs for matching ``torch.device`` objects; confirm against
    torch_npu.
    """
    if is_gpu or torch_without_guard_version:
        return func

    # wraps() copies func's __name__/__qualname__/__doc__ onto wrapper, so the
    # guarded callable does not show up as "wrapper" in introspection and logs.
    @torch_npu_device_guard
    @wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def get_rank_if_initialized():
    """Return this process's rank in torch.distributed.

    Raises:
        DistributedNotInitializedError: when torch.distributed has not been
            initialized.
    """
    if not torch.distributed.is_initialized():
        raise DistributedNotInitializedError("torch distributed environment is not initialized")
    return torch.distributed.get_rank()
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def seed_all(seed=1234, mode=False):
    """Seed every RNG used here and configure deterministic execution.

    Args:
        seed: seed applied to Python's ``random``, the ``PYTHONHASHSEED``
            environment variable, NumPy and torch. It is also applied to the
            device backend: CUDA when ``is_gpu`` is true, NPU otherwise.
        mode: forwarded to ``torch.use_deterministic_algorithms``.
    """
    random.seed(seed)
    # Setting PYTHONHASHSEED here only affects child interpreters started later;
    # the current process's hash seed is fixed at interpreter startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(mode)
    if is_gpu:
        torch.cuda.manual_seed_all(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        # The cuDNN switch is named ``enabled``. Assigning ``enable`` (the old
        # spelling) only created an unused attribute and left cuDNN active.
        torch.backends.cudnn.enabled = False
        torch.backends.cudnn.benchmark = False
    else:
        torch_npu.npu.manual_seed_all(seed)
        torch_npu.npu.manual_seed(seed)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class Const:
    """
    Namespace class of shared constants: separators, dump/summary mode names,
    file-writing flags and size/length limits, task and level names, dtype
    lists and dtype/API conversion tables.
    """
    # Separators and punctuation.
    SEP = "."
    # Recognised model file extensions.
    MODEL_TYPE = ['.onnx', '.pb', '.om']
    # Comma-separated list of (optionally negative) integers.
    # NOTE(review): no trailing `$`, so re.match also accepts strings that merely
    # start with such a list - confirm this is intended.
    DIM_PATTERN = r"^(-?[0-9]+)(,-?[0-9]+)*"
    SEMICOLON = ";"
    COLON = ":"
    EQUAL = "="
    COMMA = ","
    DOT = "."
    DUMP_RATIO_MAX = 100
    # Spelling "SUMMERY" kept as-is: external code may reference this name.
    SUMMERY_DATA_NUMS = 256
    # Machine epsilon of Python float (double precision).
    FLOAT_EPSILON = np.finfo(float).eps
    SUPPORT_DUMP_MODE = ['api', 'acl']
    ON = 'ON'
    OFF = 'OFF'
    # Keys/labels for the parts of an API call.
    KWARGS = 'kwargs'
    INPUT = 'input'
    OUTPUT = 'output'
    BACKWARD = 'backward'
    FORWARD = 'forward'
    PRE_FORWARD = "pre_forward"
    INPUT_ARGS = 'input_args'
    INPUT_KWARGS = 'input_kwargs'
    GRAD_INPUT = 'grad_input'
    GRAD_OUTPUT = 'grad_output'
    START = "start"
    STOP = "stop"
    MAX = 'Max'
    MIN = 'Min'

    # dump mode
    ALL = "all"
    LIST = "list"
    RANGE = "range"
    STACK = "stack"
    ACL = "acl"
    API_LIST = "api_list"
    API_STACK = "api_stack"
    DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK]
    AUTO = "auto"
    ONLINE_DUMP_MODE = [ALL, LIST, AUTO, OFF]
    SUMMARY = "summary"
    MD5 = "md5"
    SUMMARY_MODE = [ALL, SUMMARY, MD5]

    # os.open flags: create-or-open for writing, and the truncating variant.
    WRITE_FLAGS = os.O_WRONLY | os.O_CREAT
    OVERWRITE_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    # Owner read/write permission bits.
    WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR

    PKL_SUFFIX = ".pkl"
    NUMPY_SUFFIX = ".npy"
    # Sizes in bytes.
    ONE_GB = 1 * 1024 * 1024 * 1024
    TEN_GB = 10 * 1024 * 1024 * 1024
    # Characters allowed in a path: letters, digits, '_', '.', '/', '-'.
    FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
    # Length limits for file names and directory paths (presumably characters).
    FILE_NAME_LENGTH = 255
    DIRECTORY_LENGTH = 4096
    DISTRIBUTED_PREFIX_LENGTH = 60
    SUMMARY_COLUMN_NUM = 6
    STACK_COLUMN_NUM = 2
    # env dump path
    ASCEND_WORK_PATH = "ASCEND_WORK_PATH"
    DUMP_DIR = "dump_data"
    DATA = "data"

    # String values representing an enabled/disabled environment switch.
    ENV_ENABLE = "1"
    ENV_DISABLE = "0"

    # Largest seed value; matches the upper bound numpy.random.seed accepts.
    MAX_SEED_VALUE = 2**32 - 1

    # torch.distributed collective API names; presumably the ones that write
    # their result into a tensor argument (TODO confirm against callers).
    INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
                    "_reduce_scatter_base", "_all_gather_base", "all_to_all_single"]

    # Task names and dump levels.
    TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"]
    LEVEL_LIST = ["L0", "L1", "L2", "mix"]
    STATISTICS = "statistics"
    TENSOR = "tensor"
    OVERFLOW_CHECK = "overflow_check"
    FREE_BENCHMARK = "free_benchmark"

    ATTR_NAME_PREFIX = "wrap_"

    # Python/NumPy scalar types grouped by kind.
    FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16]
    BOOL_TYPE = [bool, np.uint8]
    INT_TYPE = [np.int32, np.int64]
    NPU = 'NPU'
    DISTRIBUTED = 'Distributed'

    # Maps a floating dtype to the next higher-precision dtype.
    RAISE_PRECISION = {
        torch.float16: torch.float32,
        torch.bfloat16: torch.float32,
        torch.float32: torch.float64
    }
    # Named conversions: name -> [source dtype string, target dtype string].
    CONVERT = {
        "int32_to_int64": ["torch.int32", "torch.int64"],
    }

    # Named conversion -> API names it applies to.
    CONVERT_API = {
        "int32_to_int64": ["cross_entropy"]
    }
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def get_tensor_rank(in_feat, out_feat):
    """Return a rank id for a hooked call.

    If torch.distributed is initialized, its rank is returned. Otherwise the
    device index of the first tensor found in ``in_feat`` is used, falling back
    to ``out_feat``. The search descends into the first element of non-empty
    lists/tuples. CPU tensors and non-tensor values provide no index, so the
    result is ``None`` when neither argument provides one.
    """
    if dist.is_initialized():
        return dist.get_rank()

    def get_tensor_rank_single(x):
        # Recurse into the first element of a non-empty list/tuple; report the
        # device index of a non-CPU tensor; anything else yields None.
        if isinstance(x, (list, tuple)):
            if len(x) > 0:
                return get_tensor_rank_single(x[0])
        elif isinstance(x, torch.Tensor):
            device = x.device
            if device.type != 'cpu':
                return device.index
        return None

    in_rank = get_tensor_rank_single(in_feat)
    # Compare against None explicitly: device index 0 is a valid rank but is
    # falsy, so a truthiness test would wrongly fall back to out_feat.
    if in_rank is not None:
        return in_rank
    return get_tensor_rank_single(out_feat)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def get_rank_id():
    """Return this process's torch.distributed rank, or 0 when distributed is not initialized."""
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def print_rank_0(message):
    """Log ``message`` at info level; under torch.distributed only rank 0 logs it."""
    # get_rank() is only evaluated when distributed is initialized (short-circuit).
    should_log = (not dist.is_initialized()) or dist.get_rank() == 0
    if should_log:
        logger.info(message)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def load_pt(pt_path, to_cpu=False):
    """Load a file written by ``torch.save`` after validating its path.

    Args:
        pt_path: path to the file. It is resolved with ``os.path.realpath`` and
            checked by ``check_file_or_directory_path`` before loading.
        to_cpu: when True, storages are mapped to the CPU via ``map_location``.

    Returns:
        The deserialized object.

    Raises:
        RuntimeError: wrapping any exception raised while loading.
    """
    resolved_path = os.path.realpath(pt_path)
    check_file_or_directory_path(resolved_path)
    # SECURITY NOTE(review): torch.load is pickle-based. When weights_only is
    # False (the default before torch 2.6), loading an untrusted file can run
    # arbitrary code. Only load files from trusted sources.
    try:
        load_kwargs = {"map_location": torch.device("cpu")} if to_cpu else {}
        loaded = torch.load(resolved_path, **load_kwargs)
    except Exception as err:
        raise RuntimeError(f"load pt file {resolved_path} failed") from err
    return loaded
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def save_pt(tensor, filepath):
    """Serialize ``tensor`` to ``filepath`` with ``torch.save`` and restrict its mode.

    The path is resolved with ``os.path.realpath`` and validated by
    ``check_path_before_create`` first. After a successful save the file mode
    is set to ``FileCheckConst.DATA_FILE_AUTHORITY``.

    Raises:
        RuntimeError: wrapping any exception raised by ``torch.save``. An error
            listing likely causes is logged first.
    """
    target_path = os.path.realpath(filepath)
    check_path_before_create(target_path)
    failure_hint = ("Save pt file failed, please check according possible error causes: "
                    "1. out of disk space or disk error, "
                    "2. no permission to write files, etc.")
    try:
        torch.save(tensor, target_path)
    except Exception as err:
        logger.error(failure_hint)
        raise RuntimeError(f"save pt file {target_path} failed") from err
    else:
        change_mode(target_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def save_api_data(api_data):
    """Serialize ``api_data`` with ``torch.save`` into an in-memory buffer.

    Args:
        api_data: any object ``torch.save`` can serialize.

    Returns:
        io.BytesIO: the buffer holding the serialized bytes. It is not rewound;
        use ``getvalue()`` (or ``seek(0)`` before reading).

    Raises:
        RuntimeError: wrapping any exception raised during serialization.
    """
    # Creating an empty BytesIO cannot fail, so only torch.save is guarded.
    io_buff = io.BytesIO()
    try:
        torch.save(api_data, io_buff)
    except Exception as e:
        raise RuntimeError("save api_data to io_buff failed") from e
    return io_buff
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def load_api_data(api_data_bytes):
    """Deserialize bytes produced by ``torch.save`` (e.g. via ``save_api_data``).

    Args:
        api_data_bytes: serialized bytes, e.g. ``save_api_data(obj).getvalue()``.

    Returns:
        The deserialized object, with storages mapped to the CPU.

    Raises:
        RuntimeError: wrapping any exception raised during deserialization.
    """
    # SECURITY NOTE(review): torch.load is pickle-based. When weights_only is
    # False (the default before torch 2.6), untrusted bytes can run arbitrary
    # code. Only pass data received from trusted peers.
    try:
        api_data = torch.load(io.BytesIO(api_data_bytes), map_location="cpu")
    except Exception as e:
        raise RuntimeError("load api_data from bytes failed") from e
    return api_data
|