mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +131 -237
- msprobe/__init__.py +16 -1
- msprobe/{config/config.json → config.json} +47 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +58 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +402 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +523 -283
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +86 -69
- msprobe/core/common/utils.py +371 -616
- msprobe/core/common_config.py +78 -71
- msprobe/core/compare/acc_compare.py +472 -298
- msprobe/core/compare/check.py +180 -95
- msprobe/core/compare/compare_cli.py +69 -49
- msprobe/core/compare/highlight.py +259 -222
- msprobe/core/compare/multiprocessing_compute.py +174 -149
- msprobe/core/compare/npy_compare.py +310 -295
- msprobe/core/compare/utils.py +464 -429
- msprobe/core/data_dump/data_collector.py +153 -144
- msprobe/core/data_dump/data_processor/base.py +337 -293
- msprobe/core/data_dump/data_processor/factory.py +76 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
- msprobe/core/data_dump/json_writer.py +117 -116
- msprobe/core/data_dump/scope.py +194 -178
- msprobe/core/grad_probe/constant.py +74 -70
- msprobe/core/grad_probe/grad_compare.py +170 -175
- msprobe/core/grad_probe/utils.py +77 -52
- msprobe/docs/01.installation.md +99 -0
- msprobe/docs/02.config_introduction.md +137 -0
- msprobe/docs/03.config_examples.md +237 -0
- msprobe/docs/04.acl_config_examples.md +78 -0
- msprobe/docs/05.data_dump_PyTorch.md +326 -0
- msprobe/docs/06.data_dump_MindSpore.md +285 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
- msprobe/{pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md → docs/18.online_dispatch.md} +89 -90
- msprobe/docs/FAQ.md +189 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +2 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +58 -34
- msprobe/mindspore/common/const.py +108 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +97 -57
- msprobe/mindspore/compare/distributed_compare.py +62 -75
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +357 -117
- msprobe/mindspore/compare/ms_graph_compare.py +364 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +69 -74
- msprobe/mindspore/debugger/precision_debugger.py +150 -107
- msprobe/mindspore/dump/dump_tool_factory.py +50 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
- msprobe/mindspore/dump/jit_dump.py +96 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
- msprobe/mindspore/free_benchmark/common/config.py +27 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
- msprobe/mindspore/free_benchmark/common/utils.py +85 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
- msprobe/mindspore/grad_probe/global_context.py +100 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +297 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +23 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +30 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
- msprobe/pytorch/bench_functions/linear.py +27 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
- msprobe/pytorch/bench_functions/rms_norm.py +30 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
- msprobe/pytorch/bench_functions/swiglu.py +70 -55
- msprobe/pytorch/common/__init__.py +17 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +33 -32
- msprobe/pytorch/common/parse_json.py +54 -39
- msprobe/pytorch/common/utils.py +310 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +49 -33
- msprobe/pytorch/compare/pt_compare.py +82 -40
- msprobe/pytorch/debugger/debugger_config.py +108 -95
- msprobe/pytorch/debugger/precision_debugger.py +173 -125
- msprobe/pytorch/free_benchmark/__init__.py +23 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +65 -37
- msprobe/pytorch/free_benchmark/common/params.py +144 -129
- msprobe/pytorch/free_benchmark/common/utils.py +118 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
- msprobe/pytorch/free_benchmark/main.py +120 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
- msprobe/pytorch/function_factory.py +91 -75
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +166 -161
- msprobe/pytorch/hook_module/hook_module.py +118 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +28 -29
- msprobe/pytorch/hook_module/wrap_aten.py +111 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
- msprobe/pytorch/hook_module/wrap_functional.py +104 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
- msprobe/pytorch/hook_module/wrap_torch.py +84 -86
- msprobe/pytorch/hook_module/wrap_vf.py +60 -62
- msprobe/pytorch/module_processer.py +153 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +235 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
- msprobe/pytorch/online_dispatch/utils.py +127 -146
- msprobe/pytorch/parse.py +19 -4
- msprobe/pytorch/parse_tool/cli.py +31 -32
- msprobe/pytorch/parse_tool/lib/compare.py +259 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
- msprobe/pytorch/parse_tool/lib/utils.py +320 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +317 -187
- msprobe/pytorch/service.py +311 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/pytorch/compare/match.py
CHANGED
|
@@ -1,33 +1,49 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
if
|
|
29
|
-
return
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from msprobe.core.common.utils import CompareException
|
|
18
|
+
from msprobe.core.common.file_utils import load_yaml
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class AtenIrMapping():
|
|
22
|
+
def __init__(self):
|
|
23
|
+
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
24
|
+
yaml_path = os.path.join(cur_path, "mapping.yaml")
|
|
25
|
+
self.aten_mapping = load_yaml(yaml_path)
|
|
26
|
+
|
|
27
|
+
def match(self, op1, op2):
|
|
28
|
+
if "Aten" in op1 and "Aten" not in op2:
|
|
29
|
+
return self.match_op(op1, op2)
|
|
30
|
+
else:
|
|
31
|
+
return self.match_op(op2, op1)
|
|
32
|
+
|
|
33
|
+
def match_op(self, aten_op, torch_op):
|
|
34
|
+
try:
|
|
35
|
+
aten_op_raw_name_overload = '_'.join(aten_op.split("_")[1:-3])
|
|
36
|
+
aten_op_raw_name = aten_op_raw_name_overload.split('.')[0]
|
|
37
|
+
torch_op_raw_name = '_'.join(torch_op.split("_")[1:-3]).lower()
|
|
38
|
+
except IndexError as e:
|
|
39
|
+
err_msg = f"Dump op name format error: {aten_op}, {torch_op}. Your dump data may be corrupted."
|
|
40
|
+
raise CompareException.INVALID_DATA_ERROR(err_msg) from e
|
|
41
|
+
matching_op = self.aten_mapping.get(aten_op_raw_name)
|
|
42
|
+
if matching_op is None:
|
|
43
|
+
return False
|
|
44
|
+
if matching_op.lower() == torch_op_raw_name:
|
|
45
|
+
return True
|
|
46
|
+
return False
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
graph_mapping = AtenIrMapping()
|
|
@@ -1,40 +1,82 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os.path
|
|
17
|
+
import torch
|
|
18
|
+
from msprobe.core.common.const import FileCheckConst
|
|
19
|
+
from msprobe.pytorch.common.log import logger
|
|
20
|
+
from msprobe.core.common.exceptions import FileCheckException
|
|
21
|
+
from msprobe.core.compare.acc_compare import Comparator
|
|
22
|
+
from msprobe.core.common.utils import check_configuration_param, task_dumppath_get, check_compare_param, \
|
|
23
|
+
CompareException
|
|
24
|
+
from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
|
|
25
|
+
from msprobe.pytorch.common.utils import load_pt
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class PTComparator (Comparator):
|
|
29
|
+
def __init__(self, data_mapping=None):
|
|
30
|
+
self.frame_name = PTComparator.__name__
|
|
31
|
+
self.data_mapping = data_mapping
|
|
32
|
+
if isinstance(self.data_mapping, str) or self.data_mapping is None:
|
|
33
|
+
self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
|
|
34
|
+
elif isinstance(self.data_mapping, dict):
|
|
35
|
+
self.data_mapping_dict = self.data_mapping
|
|
36
|
+
else:
|
|
37
|
+
raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
|
|
38
|
+
f"{type(self.data_mapping)}")
|
|
39
|
+
|
|
40
|
+
def load_mapping_file(self, mapping_file):
|
|
41
|
+
if isinstance(mapping_file, str):
|
|
42
|
+
mapping_dict = load_yaml(mapping_file)
|
|
43
|
+
else:
|
|
44
|
+
mapping_dict = {}
|
|
45
|
+
return mapping_dict
|
|
46
|
+
|
|
47
|
+
def read_npy_data(self, dir_path, file_name):
|
|
48
|
+
data_path = os.path.join(dir_path, file_name)
|
|
49
|
+
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
50
|
+
FileCheckConst.PT_SUFFIX, False)
|
|
51
|
+
data_path = path_checker.common_check()
|
|
52
|
+
try:
|
|
53
|
+
data_value = load_pt(data_path,
|
|
54
|
+
to_cpu=True).detach() # detach because numpy can not process gradient information
|
|
55
|
+
except RuntimeError as e:
|
|
56
|
+
# 这里捕获 load_pt 中抛出的异常
|
|
57
|
+
logger.error(f"Failed to load the .pt file at {data_path}.")
|
|
58
|
+
raise CompareException(CompareException.INVALID_FILE_ERROR) from e
|
|
59
|
+
except AttributeError as e:
|
|
60
|
+
# 这里捕获 detach 方法抛出的异常
|
|
61
|
+
logger.error(f"Failed to detach the loaded tensor.")
|
|
62
|
+
raise CompareException(CompareException.DETACH_ERROR) from e
|
|
63
|
+
if data_value.dtype == torch.bfloat16:
|
|
64
|
+
data_value = data_value.to(torch.float32)
|
|
65
|
+
data_value = data_value.numpy()
|
|
66
|
+
return data_value
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False, **kwargs):
|
|
70
|
+
try:
|
|
71
|
+
summary_compare, md5_compare = task_dumppath_get(input_param)
|
|
72
|
+
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
|
|
73
|
+
create_directory(output_path)
|
|
74
|
+
check_compare_param(input_param, output_path, summary_compare, md5_compare)
|
|
75
|
+
data_mapping = kwargs.get('data_mapping', None)
|
|
76
|
+
except (CompareException, FileCheckException) as error:
|
|
77
|
+
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
78
|
+
raise CompareException(error.code) from error
|
|
79
|
+
pt_comparator = PTComparator(data_mapping)
|
|
80
|
+
pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
|
|
81
|
+
auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
|
|
82
|
+
md5_compare=md5_compare)
|
|
@@ -1,95 +1,108 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
self.
|
|
26
|
-
|
|
27
|
-
if
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
self.
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
self.
|
|
44
|
-
self.
|
|
45
|
-
self.
|
|
46
|
-
self.
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
def
|
|
92
|
-
if self.
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
20
|
+
from msprobe.pytorch.common.log import logger
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class DebuggerConfig:
|
|
24
|
+
def __init__(self, common_config, task_config, task, dump_path, level):
|
|
25
|
+
self.dump_path = dump_path if dump_path else common_config.dump_path
|
|
26
|
+
self.task = task or common_config.task or Const.STATISTICS
|
|
27
|
+
self.rank = common_config.rank if common_config.rank else []
|
|
28
|
+
self.step = common_config.step if common_config.step else []
|
|
29
|
+
self.level = level or common_config.level or "L1"
|
|
30
|
+
self.enable_dataloader = common_config.enable_dataloader
|
|
31
|
+
self.scope = task_config.scope if task_config.scope else []
|
|
32
|
+
self.list = task_config.list if task_config.list else []
|
|
33
|
+
self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
|
|
34
|
+
self.backward_input_list = task_config.backward_input if task_config.backward_input else []
|
|
35
|
+
self.backward_input = {}
|
|
36
|
+
self.acl_config = common_config.acl_config if common_config.acl_config else ""
|
|
37
|
+
self.is_forward_acl_dump = True
|
|
38
|
+
self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
|
|
39
|
+
self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
|
|
40
|
+
self.framework = Const.PT_FRAMEWORK
|
|
41
|
+
|
|
42
|
+
if self.task == Const.FREE_BENCHMARK:
|
|
43
|
+
self.fuzz_device = task_config.fuzz_device
|
|
44
|
+
self.handler_type = task_config.handler_type
|
|
45
|
+
self.pert_mode = task_config.pert_mode
|
|
46
|
+
self.fuzz_level = task_config.fuzz_level
|
|
47
|
+
self.fuzz_stage = task_config.fuzz_stage
|
|
48
|
+
self.preheat_config = {
|
|
49
|
+
"if_preheat": task_config.if_preheat,
|
|
50
|
+
"preheat_step": task_config.preheat_step,
|
|
51
|
+
"max_sample": task_config.max_sample
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
self.online_run_ut = False
|
|
55
|
+
if self.task == Const.TENSOR:
|
|
56
|
+
# dump api tensor and collaborate with online run_ut
|
|
57
|
+
self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
|
|
58
|
+
self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
|
|
59
|
+
self.tls_path = task_config.tls_path if task_config.tls_path else ""
|
|
60
|
+
self.host = task_config.host if task_config.host else ""
|
|
61
|
+
self.port = task_config.port if task_config.port else -1
|
|
62
|
+
|
|
63
|
+
self.check()
|
|
64
|
+
|
|
65
|
+
if self.level == "L2":
|
|
66
|
+
if not self.scope or not isinstance(self.scope, list) or len(self.scope) != 1:
|
|
67
|
+
raise ValueError("scope must be configured as a list with one api name")
|
|
68
|
+
if isinstance(self.scope[0], str) and Const.BACKWARD in self.scope[0] and not self.backward_input_list:
|
|
69
|
+
raise ValueError("backward_input must be configured when scope contains 'backward'")
|
|
70
|
+
if Const.BACKWARD in self.scope[0]:
|
|
71
|
+
self.is_forward_acl_dump = False
|
|
72
|
+
for index, scope_spec in enumerate(self.scope):
|
|
73
|
+
self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD)
|
|
74
|
+
self.backward_input[self.scope[index]] = self.backward_input_list[index]
|
|
75
|
+
|
|
76
|
+
def check_kwargs(self):
|
|
77
|
+
if self.task and self.task not in Const.TASK_LIST:
|
|
78
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
79
|
+
f"The task <{self.task}> is not in the {Const.TASK_LIST}.")
|
|
80
|
+
if self.level and self.level not in Const.LEVEL_LIST:
|
|
81
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
82
|
+
f"The level <{self.level}> is not in the {Const.LEVEL_LIST}.")
|
|
83
|
+
if not self.dump_path:
|
|
84
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
85
|
+
f"The dump_path not found.")
|
|
86
|
+
|
|
87
|
+
def check(self):
|
|
88
|
+
self.check_kwargs()
|
|
89
|
+
return True
|
|
90
|
+
|
|
91
|
+
def check_model(self, instance, start_model):
|
|
92
|
+
if self.level not in ["L0", "mix"]:
|
|
93
|
+
if instance.model is not None or start_model is not None:
|
|
94
|
+
logger.warning_on_rank_0(
|
|
95
|
+
f"The current level is not L0 or mix level, so the model parameters will not be used.")
|
|
96
|
+
return
|
|
97
|
+
if start_model is None:
|
|
98
|
+
if instance.model is None:
|
|
99
|
+
logger.error_on_rank_0(
|
|
100
|
+
f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' argument.")
|
|
101
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
|
|
102
|
+
return
|
|
103
|
+
if isinstance(start_model, torch.nn.Module):
|
|
104
|
+
instance.model = start_model
|
|
105
|
+
else:
|
|
106
|
+
logger.error_on_rank_0(f"The 'model' parameter of start must be a torch.nn.Module type.")
|
|
107
|
+
raise MsprobeException(
|
|
108
|
+
MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")
|