mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +131 -237
- msprobe/__init__.py +16 -1
- msprobe/{config/config.json → config.json} +47 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +58 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +402 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +523 -283
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +86 -69
- msprobe/core/common/utils.py +371 -616
- msprobe/core/common_config.py +78 -71
- msprobe/core/compare/acc_compare.py +472 -298
- msprobe/core/compare/check.py +180 -95
- msprobe/core/compare/compare_cli.py +69 -49
- msprobe/core/compare/highlight.py +259 -222
- msprobe/core/compare/multiprocessing_compute.py +174 -149
- msprobe/core/compare/npy_compare.py +310 -295
- msprobe/core/compare/utils.py +464 -429
- msprobe/core/data_dump/data_collector.py +153 -144
- msprobe/core/data_dump/data_processor/base.py +337 -293
- msprobe/core/data_dump/data_processor/factory.py +76 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
- msprobe/core/data_dump/json_writer.py +117 -116
- msprobe/core/data_dump/scope.py +194 -178
- msprobe/core/grad_probe/constant.py +74 -70
- msprobe/core/grad_probe/grad_compare.py +170 -175
- msprobe/core/grad_probe/utils.py +77 -52
- msprobe/docs/01.installation.md +99 -0
- msprobe/docs/02.config_introduction.md +137 -0
- msprobe/docs/03.config_examples.md +237 -0
- msprobe/docs/04.acl_config_examples.md +78 -0
- msprobe/docs/05.data_dump_PyTorch.md +326 -0
- msprobe/docs/06.data_dump_MindSpore.md +285 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
- msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
- msprobe/docs/FAQ.md +189 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +2 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +58 -34
- msprobe/mindspore/common/const.py +108 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +97 -57
- msprobe/mindspore/compare/distributed_compare.py +62 -75
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +357 -117
- msprobe/mindspore/compare/ms_graph_compare.py +364 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +69 -74
- msprobe/mindspore/debugger/precision_debugger.py +150 -107
- msprobe/mindspore/dump/dump_tool_factory.py +50 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
- msprobe/mindspore/dump/jit_dump.py +96 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
- msprobe/mindspore/free_benchmark/common/config.py +27 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
- msprobe/mindspore/free_benchmark/common/utils.py +85 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
- msprobe/mindspore/grad_probe/global_context.py +100 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +297 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +23 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +30 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
- msprobe/pytorch/bench_functions/linear.py +27 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
- msprobe/pytorch/bench_functions/rms_norm.py +30 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
- msprobe/pytorch/bench_functions/swiglu.py +70 -55
- msprobe/pytorch/common/__init__.py +17 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +33 -32
- msprobe/pytorch/common/parse_json.py +54 -39
- msprobe/pytorch/common/utils.py +310 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +49 -33
- msprobe/pytorch/compare/pt_compare.py +82 -40
- msprobe/pytorch/debugger/debugger_config.py +108 -95
- msprobe/pytorch/debugger/precision_debugger.py +173 -125
- msprobe/pytorch/free_benchmark/__init__.py +23 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +65 -37
- msprobe/pytorch/free_benchmark/common/params.py +144 -129
- msprobe/pytorch/free_benchmark/common/utils.py +118 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
- msprobe/pytorch/free_benchmark/main.py +120 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
- msprobe/pytorch/function_factory.py +91 -75
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +166 -161
- msprobe/pytorch/hook_module/hook_module.py +118 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +28 -29
- msprobe/pytorch/hook_module/wrap_aten.py +111 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
- msprobe/pytorch/hook_module/wrap_functional.py +104 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
- msprobe/pytorch/hook_module/wrap_torch.py +84 -86
- msprobe/pytorch/hook_module/wrap_vf.py +60 -62
- msprobe/pytorch/module_processer.py +153 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +235 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
- msprobe/pytorch/online_dispatch/utils.py +127 -146
- msprobe/pytorch/parse.py +19 -4
- msprobe/pytorch/parse_tool/cli.py +31 -32
- msprobe/pytorch/parse_tool/lib/compare.py +259 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
- msprobe/pytorch/parse_tool/lib/utils.py +320 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +317 -187
- msprobe/pytorch/service.py +311 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/core/common/utils.py
CHANGED
|
@@ -1,616 +1,371 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
|
-
# -*- coding: utf-8 -*-
|
|
3
|
-
"""
|
|
4
|
-
# Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
-
# you may not use this file except in compliance with the License.
|
|
7
|
-
# You may obtain a copy of the License at
|
|
8
|
-
#
|
|
9
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
-
#
|
|
11
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
-
# See the License for the specific language governing permissions and
|
|
15
|
-
# limitations under the License.
|
|
16
|
-
"""
|
|
17
|
-
import collections
|
|
18
|
-
import os
|
|
19
|
-
import re
|
|
20
|
-
import
|
|
21
|
-
import
|
|
22
|
-
import
|
|
23
|
-
import
|
|
24
|
-
|
|
25
|
-
from
|
|
26
|
-
from
|
|
27
|
-
import
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
def
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
if
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
if
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
raise
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
if
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
raise
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
return datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
def format_value(value):
    """Round ``value`` to 12 decimal places and return it as a float.

    The value is rendered with fixed-point formatting (12 digits after the
    decimal point) and parsed back, so the result is always a ``float``.
    """
    fixed_point_text = f"{value:.12f}"
    return float(fixed_point_text)
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
def check_seed_all(seed, mode):
    """Validate the arguments passed to ``seed_all``.

    :param seed: random seed; must be an ``int`` within ``[0, Const.MAX_SEED_VALUE]``.
        (``bool`` is a subclass of ``int`` in Python, so ``True``/``False`` pass this check.)
    :param mode: must be a ``bool``.
    :raises CompareException: ``INVALID_PARAM_ERROR`` when any check fails; the reason is
        logged first.
    """
    if not isinstance(seed, int):
        logger.error("Seed must be integer.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    if not 0 <= seed <= Const.MAX_SEED_VALUE:
        logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    if not isinstance(mode, bool):
        logger.error("seed_all mode must be bool.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
def get_process_rank(model):
    """Infer the rank id from the device holding the model's first parameter.

    :param model: object whose ``parameters()`` yields objects with a ``.device``
        (e.g. a ``torch.nn.Module``).
    :return: ``(rank, found)``. Returns ``(device.index, True)`` when the first parameter
        is on a non-CPU device. Otherwise returns ``(0, False)`` after logging a warning;
        this covers a model with no parameters or a model on the CPU.
    """
    logger.info("Rank id is not provided. Trying to get the rank id of the model.")
    try:
        device = next(model.parameters()).device
    except StopIteration:
        logger.warning('There is no parameter in the model. Fail to get rank id.')
        return 0, False

    if device.type != 'cpu':
        return device.index, True

    # A CPU-resident model carries no device index, so the rank cannot be inferred.
    logger.warning("Warning: the debugger is unable to get the rank id. "
                   "This may cause the dumpped data to be corrupted in the "
                   "case of distributed training. (You may ignore this if you are using only one card.) "
                   "Transfer the model to npu or gpu before register_hook() to avoid this warning.")
    return 0, False
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
def generate_compare_script(dump_path, pkl_file_path, dump_switch_mode):
    """Write a ``compare_data.py`` script into the directory of ``pkl_file_path``.

    The script text comes from ``compare_script.template``, which is looked up next to
    this module. It is filled with %-formatting using, in order:
    ``pkl_file_path``, ``dump_path``, and the string ``"True"``/``"False"``
    (``"True"`` when ``dump_switch_mode == Const.API_STACK``).

    An ``OSError`` while opening or writing either file is logged, not raised. In that
    case the success message is not emitted.

    NOTE(review): this release ships ``compare_script.template`` under
    ``msprobe/pytorch/common``. Confirm that a copy also exists next to this module.
    """
    template_path = os.path.join(os.path.dirname(__file__), "compare_script.template")
    pkl_dir = os.path.dirname(pkl_file_path)
    compare_script_path = os.path.join(pkl_dir, "compare_data.py")
    is_api_stack = "True" if dump_switch_mode == Const.API_STACK else "False"

    try:
        with FileOpen(template_path, 'r') as ftemp, \
                os.fdopen(os.open(compare_script_path, Const.WRITE_FLAGS, Const.WRITE_MODES), 'w+') as fout:
            code_temp = ftemp.read()
            fout.write(code_temp % (pkl_file_path, dump_path, is_api_stack))
    except OSError:
        logger.error(f"Failed to open file. Please check file {template_path} or path {pkl_dir}.")
        # Previously execution fell through and reported success after a failure.
        return

    logger.info(f"Generate compare script successfully which is {compare_script_path}.")
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
def check_file_valid(file_path):
    """Validate ``file_path`` before it is read.

    Each failure is logged and raises
    ``CompareException(CompareException.INVALID_PATH_ERROR)``. A path fails when it:

    * is a symbolic link;
    * resolves to a path longer than ``Const.DIRECTORY_LENGTH``, or has a base name
      longer than ``Const.FILE_NAME_LENGTH``;
    * resolves to a path that does not match ``Const.FILE_PATTERN``;
    * is a regular file ending in ``Const.PKL_SUFFIX`` and larger than ``Const.ONE_GB``;
    * is a regular file ending in ``Const.NUMPY_SUFFIX`` and larger than ``Const.TEN_GB``.
    """
    if os.path.islink(file_path):
        logger.error('The file path {} is a soft link.'.format(file_path))
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    resolved = os.path.realpath(file_path)
    path_too_long = len(resolved) > Const.DIRECTORY_LENGTH
    name_too_long = len(os.path.basename(file_path)) > Const.FILE_NAME_LENGTH
    if path_too_long or name_too_long:
        logger.error('The file path length exceeds limit.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    if not re.match(Const.FILE_PATTERN, resolved):
        logger.error('The file path {} contains special characters.'.format(file_path))
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    # Size limits only apply to existing regular files.
    if not os.path.isfile(file_path):
        return

    size_in_bytes = os.path.getsize(file_path)
    if file_path.endswith(Const.PKL_SUFFIX) and size_in_bytes > Const.ONE_GB:
        logger.error('The file {} size is greater than 1GB.'.format(file_path))
        raise CompareException(CompareException.INVALID_PATH_ERROR)
    if file_path.endswith(Const.NUMPY_SUFFIX) and size_in_bytes > Const.TEN_GB:
        logger.error('The file {} size is greater than 10GB.'.format(file_path))
        raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
def check_path_before_create(path):
    """Validate a path that is about to be created.

    Logs and raises ``CompareException(CompareException.INVALID_PATH_ERROR)`` in two cases:

    * the resolved path is longer than ``Const.DIRECTORY_LENGTH``, or the base name is
      longer than ``Const.FILE_NAME_LENGTH``;
    * the resolved path does not match ``Const.FILE_PATTERN``.
    """
    resolved = os.path.realpath(path)
    exceeds_length = (len(resolved) > Const.DIRECTORY_LENGTH
                      or len(os.path.basename(path)) > Const.FILE_NAME_LENGTH)
    if exceeds_length:
        logger.error('The file path length exceeds limit.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    if not re.match(Const.FILE_PATTERN, resolved):
        logger.error('The file path {} contains special characters.'.format(path))
        raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
def check_inplace_op(prefix):
    """Return whether ``prefix`` names a distributed op listed in ``Const.INPLACE_LIST``.

    Prefixes longer than ``Const.DISTRIBUTED_PREFIX_LENGTH`` are rejected immediately.
    Otherwise the op name is the text captured from the first
    ``Distributed.<op>.<digit>`` occurrence. When there is no such occurrence, ``None``
    is looked up instead.
    """
    if len(prefix) > Const.DISTRIBUTED_PREFIX_LENGTH:
        return False
    first_match = re.search(r"Distributed\.(.+?)\.\d", prefix)
    op_name = first_match.group(1) if first_match else None
    return op_name in Const.INPLACE_LIST
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
def md5_find(data):
    """Return True if any entry in a dump's ``data`` section carries an ``'md5'`` field.

    ``data`` maps an API/op key to a dict of entries. An entry is either a list of
    details, each checked individually, or a single detail. A detail matches when it is
    truthy and ``'md5' in detail`` is true. Details that are empty, ``None``, or do not
    support ``in`` (e.g. numbers) are treated as non-matching instead of raising
    ``TypeError``.

    :param data: the ``data`` mapping of a dump json.
    :return: ``True`` on the first matching detail, ``False`` otherwise.
    """
    for api_entries in data.values():
        for entry in api_entries.values():
            details = entry if isinstance(entry, list) else (entry,)
            if any(_has_md5(detail) for detail in details):
                return True
    return False


def _has_md5(detail):
    """Return True when ``detail`` is non-empty and contains ``'md5'``."""
    if not detail:
        return False
    try:
        return 'md5' in detail
    except TypeError:
        # Scalars (ints, floats, bools) do not support membership tests; not an md5 record.
        return False
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
def task_dumppath_get(input_param):
    """Determine the compare mode from two dump json files and record their data dirs.

    ``input_param['npu_json_path']`` and ``input_param['bench_json_path']`` must both be
    set. The ``task`` fields of the two json files must be equal. The mode follows from
    that task:

    * ``Const.TENSOR`` gives ``(False, False)``.
    * ``Const.STATISTICS`` gives md5 compare ``(False, True)`` when ``md5_find`` finds an
      md5 entry in the NPU ``data`` section. Otherwise it gives summary compare
      ``(True, False)``.
    * Any other task is an error.

    On success, ``input_param`` gains two keys, ``'npu_dump_data_dir'`` and
    ``'bench_dump_data_dir'``. Each is ``Const.DUMP_TENSOR_DATA`` joined to the
    directory of the corresponding json file.

    :return: ``(summary_compare, md5_compare)``
    :raises CompareException: ``INVALID_PATH_ERROR`` if a path is missing;
        ``INVALID_TASK_ERROR`` if the tasks differ or are not comparable.
    """
    npu_path = input_param.get("npu_json_path", None)
    bench_path = input_param.get("bench_json_path", None)
    if not (npu_path and bench_path):
        logger.error("Please check the json path is valid.")
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    with FileOpen(npu_path, 'r') as npu_f:
        npu_json_data = json.load(npu_f)
    with FileOpen(bench_path, 'r') as bench_f:
        bench_json_data = json.load(bench_f)

    task = npu_json_data['task']
    if task != bench_json_data['task']:
        logger.error("Please check the dump task is consistent.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)

    if task == Const.TENSOR:
        summary_compare, md5_compare = False, False
    elif task == Const.STATISTICS:
        md5_compare = md5_find(npu_json_data['data'])
        summary_compare = not md5_compare
    else:
        logger.error("Compare is not required for overflow_check or free_benchmark.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)

    input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
    input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
    return summary_compare, md5_compare
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
def get_header_index(header_name, summary_compare=False):
    """Return the column index of ``header_name`` in the compare-result header.

    :param header_name: column title to look up.
    :param summary_compare: use ``CompareConst.SUMMARY_COMPARE_RESULT_HEADER`` when True,
        otherwise ``CompareConst.COMPARE_RESULT_HEADER``.
    :raises CompareException: ``INVALID_PARAM_ERROR`` if the column is not in the header.
    """
    header = (CompareConst.SUMMARY_COMPARE_RESULT_HEADER if summary_compare
              else CompareConst.COMPARE_RESULT_HEADER)
    if header_name in header:
        return header.index(header_name)
    logger.error(f"{header_name} not in data name")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
def convert_tuple(data):
    """Return ``data`` unchanged if it is already a tuple, otherwise wrap it in a 1-tuple."""
    if isinstance(data, tuple):
        return data
    return (data,)
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
def write_csv(data, filepath, mode="a+"):
    """Write the rows in ``data`` to ``filepath`` as CSV encoded with ``utf-8-sig``.

    :param data: iterable of rows, passed to ``csv.writer.writerows``.
    :param filepath: target CSV file.
    :param mode: file open mode (appends by default).

    If the file did not exist before the call, its permissions are set to
    ``FileCheckConst.DATA_FILE_AUTHORITY`` after writing.
    """
    file_is_new = not os.path.exists(filepath)
    with FileOpen(filepath, mode, encoding='utf-8-sig') as csv_file:
        csv.writer(csv_file).writerows(data)
    if file_is_new:
        change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
def load_npy(filepath):
    """Load a numpy file after validating its path.

    :param filepath: path checked with ``check_file_or_directory_path`` before loading.
    :return: the object returned by ``np.load``.
    :raises RuntimeError: if ``np.load`` fails; the original error is chained.
    """
    check_file_or_directory_path(filepath)
    try:
        return np.load(filepath)
    except Exception as e:
        logger.error(f"The numpy file failed to load. Please check the path: {filepath}.")
        raise RuntimeError(f"Load numpy file {filepath} failed.") from e
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
def save_npy(data, filepath):
    """Save ``data`` with ``np.save`` and restrict the written file's permissions.

    ``np.save`` appends ``.npy`` to a string path that lacks that suffix. The suffix is
    therefore added here first. This ensures that ``check_path_before_create`` and
    ``change_mode`` operate on the file that is actually written. Previously,
    ``change_mode`` targeted a path that did not exist whenever the suffix was missing.

    :raises RuntimeError: if ``np.save`` fails; the original error is chained.
    """
    filepath = os.path.realpath(filepath)
    if not filepath.endswith(".npy"):
        filepath += ".npy"
    check_path_before_create(filepath)
    try:
        np.save(filepath, data)
    except Exception as e:
        logger.error(f"The numpy file failed to save. Please check the path: {filepath}.")
        raise RuntimeError(f"Save numpy file {filepath} failed.") from e
    change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
554
|
-
|
|
555
|
-
def save_npy_to_txt(self, data, dst_file='', align=0):
    """Write a numpy array to a space-delimited text file, one row per line.

    Args:
        self: Owner object exposing ``log.info`` and ``log.error``; this
            module-level function receives it explicitly.
        data: numpy array to write; it is flattened before writing.
        dst_file: Destination text file. If it already exists nothing is written.
        align: Number of values per output row. ``0`` means use the last
            dimension of ``data`` (1 for a 0-d array or a zero-length last
            dimension). A non-zero ``align`` that does not divide
            ``data.size`` causes zero-padding up to the next multiple.
    """
    if os.path.exists(dst_file):
        self.log.info("Dst file %s exists, will not save new one.", dst_file)
        return
    shape = data.shape
    data = data.flatten()
    if align == 0:
        # reshape((-1, 0)) is invalid, so fall back to one value per row when
        # the last dimension is empty.
        align = shape[-1] if len(shape) > 0 and shape[-1] > 0 else 1
    elif data.size % align != 0:
        pad_array = np.zeros((align - data.size % align,))
        data = np.append(data, pad_array)
    check_path_before_create(dst_file)
    try:
        np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
    except Exception as e:
        # Best-effort: log the failure. Skip the permission change because the
        # file may not have been created.
        self.log.error("An unexpected error occurred: %s when savetxt to %s" % (str(e), dst_file))
        return
    change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
572
|
-
|
|
573
|
-
def get_json_contents(file_path):
    """Read ``file_path`` and parse it as a top-level JSON object.

    Raises:
        CompareException: INVALID_FILE_ERROR if the content is not valid JSON
            or the top-level value is not a dict.
    """
    raw_content = get_file_content_bytes(file_path)
    try:
        parsed = json.loads(raw_content)
    except ValueError as error:
        logger.error('Failed to load json.')
        raise CompareException(CompareException.INVALID_FILE_ERROR) from error
    if isinstance(parsed, dict):
        return parsed
    logger.error('Json file content is not a dictionary!')
    raise CompareException(CompareException.INVALID_FILE_ERROR)
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
def get_file_content_bytes(file):
    """Return the full content of ``file``, opened in binary mode via ``FileOpen``."""
    with FileOpen(file, 'rb') as binary_handle:
        content = binary_handle.read()
    return content
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
def load_yaml(yaml_path):
    """Validate ``yaml_path`` with ``FileChecker`` and parse it via ``yaml.safe_load``.

    Raises:
        RuntimeError: if opening or parsing the file fails (original error chained).
    """
    checked_path = FileChecker(
        yaml_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.YAML_SUFFIX
    ).common_check()
    try:
        with FileOpen(checked_path, "r") as yaml_file:
            parsed = yaml.safe_load(yaml_file)
    except Exception as err:
        logger.error(f"The yaml file failed to load. Please check the path: {checked_path}.")
        raise RuntimeError(f"Load yaml file {checked_path} failed.") from err
    return parsed
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
def save_workbook(workbook, file_path):
    """Save a workbook to the given file path and restrict its permissions.

    Args:
        workbook: Workbook object to save (anything with a ``save(path)`` method).
        file_path: Destination path; resolved with ``os.path.realpath``.

    Raises:
        CompareException: WRITE_FILE_ERROR if ``workbook.save`` fails.
    """
    resolved_path = os.path.realpath(file_path)
    check_path_before_create(resolved_path)
    try:
        workbook.save(resolved_path)
    except Exception as err:
        logger.error(f'Save result file "{os.path.basename(resolved_path)}" failed')
        raise CompareException(CompareException.WRITE_FILE_ERROR) from err
    change_mode(resolved_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
import collections
|
|
18
|
+
import os
|
|
19
|
+
import re
|
|
20
|
+
import subprocess
|
|
21
|
+
import time
|
|
22
|
+
import json
|
|
23
|
+
from datetime import datetime, timezone
|
|
24
|
+
|
|
25
|
+
from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json)
|
|
26
|
+
from msprobe.core.common.const import Const, CompareConst
|
|
27
|
+
from msprobe.core.common.log import logger
|
|
28
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Immutable record holding a device ``type`` and its ``index``.
device = collections.namedtuple('device', 'type index')
# Name prefixes; their consumers are outside this chunk. NOTE(review): confirm usage.
prefixes = ['api_stack', 'list', 'range', 'acl']
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MsprobeBaseException(Exception):
    """Root of the msprobe exception hierarchy.

    Each instance carries an integer ``code`` (normally one of the class
    constants below) and an optional ``error_info`` message. ``str()`` of the
    exception is exactly ``error_info``.
    """

    # Error codes
    NONE_ERROR = 0
    INVALID_PATH_ERROR = 1
    OPEN_FILE_ERROR = 2
    CLOSE_FILE_ERROR = 3
    READ_FILE_ERROR = 4
    WRITE_FILE_ERROR = 5
    INVALID_FILE_ERROR = 6
    PERMISSION_ERROR = 7
    INDEX_OUT_OF_BOUNDS_ERROR = 8
    NO_DUMP_FILE_ERROR = 9
    INVALID_DATA_ERROR = 10
    INVALID_PARAM_ERROR = 11
    INVALID_DUMP_RATIO = 12
    INVALID_DUMP_FILE = 13
    UNKNOWN_ERROR = 14
    INVALID_DUMP_MODE = 15
    PARSE_FILE_ERROR = 16
    INVALID_COMPARE_MODE = 17
    OVER_SIZE_FILE_ERROR = 18
    INVALID_SUMMARY_MODE = 19
    INVALID_TASK_ERROR = 20
    DETACH_ERROR = 21
    INVALID_OBJECT_TYPE_ERROR = 22
    INVALID_CHAR_ERROR = 23
    RECURSION_LIMIT_ERROR = 24
    INVALID_ATTRIBUTE_ERROR = 25
    OUTPUT_HOOK_ERROR = 26
    INPUT_HOOK_ERROR = 27
    FUNCTION_CALL_ERROR = 28
    FORWARD_DATA_COLLECTION_ERROR = 29
    BACKWARD_DATA_COLLECTION_ERROR = 30

    def __init__(self, code, error_info: str = ""):
        # Deliberately pass no args to Exception; the message lives in error_info.
        super().__init__()
        self.code = code
        self.error_info = error_info

    def __str__(self):
        return self.error_info
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class CompareException(MsprobeBaseException):
    """Error raised by accuracy-compare routines; codes come from ``MsprobeBaseException``."""

    def __init__(self, code, error_info: str = ""):
        super().__init__(code, error_info)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class DumpException(MsprobeBaseException):
    """Error raised by data-dump routines.

    Unlike the base class, ``str()`` prefixes the message with the error code.
    """

    def __init__(self, code, error_info: str = ""):
        super().__init__(code, error_info)

    def __str__(self):
        return "Dump Error Code {}: {}".format(self.code, self.error_info)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def check_compare_param(input_param, output_path, summary_compare=False, md5_compare=False):
    """Validate the inputs of a compare run before it starts.

    Args:
        input_param: Dict holding ``npu_json_path``, ``bench_json_path`` and
            ``stack_json_path``, plus ``npu_dump_data_dir`` and
            ``bench_dump_data_dir`` when tensor data is compared.
        output_path: Output directory path (must be a str).
        summary_compare: Summary mode; the dump-data dirs are not checked.
        md5_compare: MD5 mode; the dump-data dirs are not checked.

    Raises:
        CompareException: INVALID_PARAM_ERROR on wrong argument types; other
            errors come from the path/JSON checks.
    """
    if not isinstance(input_param, dict):
        logger.error(f"Invalid input parameter 'input_param', the expected type dict but got {type(input_param)}.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    if not isinstance(output_path, str):
        logger.error(f"Invalid input parameter 'output_path', the expected type str but got {type(output_path)}.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)

    for json_key in ("npu_json_path", "bench_json_path", "stack_json_path"):
        check_file_or_directory_path(input_param.get(json_key), False)
    # Only a full tensor comparison reads the dumped tensor directories.
    if not (summary_compare or md5_compare):
        for dir_key in ("npu_dump_data_dir", "bench_dump_data_dir"):
            check_file_or_directory_path(input_param.get(dir_key), True)
    check_file_or_directory_path(output_path, True)

    with FileOpen(input_param.get("npu_json_path"), "r") as npu_json:
        with FileOpen(input_param.get("bench_json_path"), "r") as bench_json:
            with FileOpen(input_param.get("stack_json_path"), "r") as stack_json:
                check_json_file(input_param, npu_json, bench_json, stack_json)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True):
    """Ensure every compare configuration flag is a real ``bool``.

    Raises:
        CompareException: INVALID_PARAM_ERROR on the first non-bool flag.
    """
    for flag in (stack_mode, auto_analyze, fuzzy_match, is_print_compare_log):
        if isinstance(flag, bool):
            continue
        logger.error(f"Invalid input parameter, {flag} which should be only bool type.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _check_json(json_file_handle, file_name):
|
|
133
|
+
tensor_line = json_file_handle.readline()
|
|
134
|
+
if not tensor_line:
|
|
135
|
+
logger.error("dump file {} have empty line!".format(file_name))
|
|
136
|
+
raise CompareException(CompareException.INVALID_DUMP_FILE)
|
|
137
|
+
json_file_handle.seek(0, 0)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def check_json_file(input_param, npu_json, bench_json, stack_json):
    """Run ``_check_json`` on the NPU, bench and stack JSON handles, in that order."""
    handle_and_key = (
        (npu_json, "npu_json_path"),
        (bench_json, "bench_json_path"),
        (stack_json, "stack_json_path"),
    )
    for handle, path_key in handle_and_key:
        _check_json(handle, input_param.get(path_key))
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def check_regex_prefix_format_valid(prefix):
    """Validate a regex prefix string.

    Args:
        prefix (str): The prefix to validate.

    Raises:
        ValueError: if ``prefix`` is longer than ``Const.REGEX_PREFIX_MAX_LENGTH``
            or ``re.match`` with ``Const.REGEX_PREFIX_PATTERN`` fails on it.
    """
    max_length = Const.REGEX_PREFIX_MAX_LENGTH
    prefix_length = len(prefix)
    if prefix_length > max_length:
        raise ValueError(f"Maximum length of prefix is {max_length}, while current length "
                         f"is {prefix_length}")
    if re.match(Const.REGEX_PREFIX_PATTERN, prefix) is None:
        raise ValueError(f"prefix contains invalid characters, prefix pattern {Const.REGEX_PREFIX_PATTERN}")
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def execute_command(cmd):
    """Run ``cmd`` without a shell, echoing its merged stdout/stderr line by line.

    Args:
        cmd: Argument sequence for ``subprocess.Popen`` (``shell=False``); it
            is also space-joined for the failure log, so elements must be str.

    Raises:
        CompareException: INVALID_DATA_ERROR if the command exits non-zero.
    """
    logger.info('Execute command:%s' % (cmd,))
    with subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as process:
        # Read to EOF instead of polling: output still buffered in the pipe
        # when the child exits would otherwise be lost.
        for raw_line in process.stdout:
            line = raw_line.strip()
            if line:
                print(line.decode(errors="replace"))
        process.wait()
    if process.returncode != 0:
        logger.error('Failed to execute command:%s' % " ".join(cmd))
        raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def add_time_as_suffix(name):
    """Return ``<name>_<YYYYmmddHHMMSS>.csv`` stamped with the current local time."""
    stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
    return f"{name}_{stamp}.csv"
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def add_time_with_xlsx(name):
    """Return ``<name>_<YYYYmmddHHMMSS>.xlsx`` stamped with the current local time."""
    stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
    return f"{name}_{stamp}.xlsx"
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def add_time_with_yaml(name):
    """Return ``<name>_<YYYYmmddHHMMSS>.yaml`` stamped with the current local time."""
    stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
    return f"{name}_{stamp}.yaml"
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def get_time():
    """Return the current UTC time formatted as ``YYYYmmdd_HHMMSS``."""
    utc_now = datetime.now(tz=timezone.utc)
    return utc_now.strftime("%Y%m%d_%H%M%S")
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def format_value(value):
    """Return ``value`` rounded to 12 decimal places via fixed-point string formatting."""
    return float(f"{value:.12f}")
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def md5_find(data):
    """Return True if any entry in a dump's ``data`` mapping carries an ``md5`` key.

    Args:
        data: Mapping of op name -> mapping of field name -> value. Each value
            is expected to be either a list of per-item dicts or a single dict
            (presumably input/output descriptors — confirm against the dump format).

    Only dicts are inspected for the ``md5`` key. ``None``, strings and other
    scalars are skipped instead of raising TypeError (``'md5' in None``) or
    matching an ``md5`` substring by accident.
    """
    for op_info in data.values():
        for field_value in op_info.values():
            if isinstance(field_value, list):
                if any(isinstance(item, dict) and 'md5' in item for item in field_value):
                    return True
            elif isinstance(field_value, dict) and 'md5' in field_value:
                return True
    return False
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def struct_json_get(input_param, framework):
    """Load ``stack.json`` and ``construct.json`` from the directory of a dump json.

    For ``Const.PT_FRAMEWORK`` the directory of ``bench_json_path`` is used;
    for ``Const.MS_FRAMEWORK`` the directory of ``npu_json_path``.

    Returns:
        Tuple ``(stack, construct)`` as returned by ``load_json``.

    Raises:
        CompareException: INVALID_PARAM_ERROR for an unknown framework,
            INVALID_PATH_ERROR if the selected json path is missing/empty.
    """
    if framework not in (Const.PT_FRAMEWORK, Const.MS_FRAMEWORK):
        logger.error("Error framework found.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    side = "bench" if framework == Const.PT_FRAMEWORK else "npu"

    json_path = input_param.get(f"{side}_json_path", None)
    if not json_path:
        logger.error("Please check the json path is valid.")
        raise CompareException(CompareException.INVALID_PATH_ERROR)
    dump_dir = os.path.dirname(json_path)
    check_file_or_directory_path(dump_dir, True)

    stack = load_json(os.path.join(dump_dir, "stack.json"))
    construct = load_json(os.path.join(dump_dir, "construct.json"))
    return stack, construct
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def task_dumppath_get(input_param):
    """Derive the compare mode from the dump task and record the dump-data dirs.

    Reads ``npu_json_path`` and ``bench_json_path`` from ``input_param`` and
    requires both files to have the same ``task``. It then sets
    ``input_param['npu_dump_data_dir']`` and ``input_param['bench_dump_data_dir']``
    to ``Const.DUMP_TENSOR_DATA`` inside each json's directory.

    Returns:
        ``(summary_compare, md5_compare)``: both False for a tensor task. For a
        statistics task, md5 mode is used when ``md5_find`` detects md5 data;
        otherwise summary mode is used.

    Raises:
        CompareException: INVALID_PATH_ERROR for missing paths,
            INVALID_TASK_ERROR for mismatched or unsupported tasks.
    """
    npu_path = input_param.get("npu_json_path", None)
    bench_path = input_param.get("bench_json_path", None)
    if not (npu_path and bench_path):
        logger.error("Please check the json path is valid.")
        raise CompareException(CompareException.INVALID_PATH_ERROR)
    with FileOpen(npu_path, 'r') as npu_f:
        npu_json_data = json.load(npu_f)
    with FileOpen(bench_path, 'r') as bench_f:
        bench_json_data = json.load(bench_f)

    task = npu_json_data['task']
    if task != bench_json_data['task']:
        logger.error("Please check the dump task is consistent.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)
    if task == Const.TENSOR:
        summary_compare, md5_compare = False, False
    elif task == Const.STATISTICS:
        md5_compare = md5_find(npu_json_data['data'])
        summary_compare = not md5_compare
    else:
        logger.error("Compare is not required for overflow_check or free_benchmark.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)

    input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
    input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
    return summary_compare, md5_compare
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def get_header_index(header_name, summary_compare=False):
    """Look up the position of ``header_name`` in the compare-result header.

    Args:
        header_name: Column name to find.
        summary_compare: Use the summary header instead of the full header.

    Raises:
        CompareException: INVALID_PARAM_ERROR when the column is absent.
    """
    columns = (CompareConst.SUMMARY_COMPARE_RESULT_HEADER if summary_compare
               else CompareConst.COMPARE_RESULT_HEADER)
    if header_name not in columns:
        logger.error(f"{header_name} not in data name")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    return columns.index(header_name)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def convert_tuple(data):
    """Normalize ``data`` to a tuple: tuples pass through, anything else becomes ``(data,)``."""
    return data if type(data) is tuple or isinstance(data, tuple) else tuple([data])
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def check_op_str_pattern_valid(string, op_name=None, stack=False):
    """Raise if a string matches the blacklist checked by ``is_invalid_pattern``.

    Non-str inputs are ignored.

    Args:
        string: Text to check.
        op_name: Owning op name, used in the error message.
        stack: True when ``string`` is stack information.

    Raises:
        CompareException: INVALID_CHAR_ERROR on a blacklist match.
    """
    if not isinstance(string, str) or not is_invalid_pattern(string):
        return
    if stack:
        message = f"stack info of {op_name} contains special characters, please check!"
    elif op_name:
        message = f"data info of {op_name} contains special characters, please check!"
    else:
        message = "api name contains special characters, please check!"
    logger.error(message)
    raise CompareException(CompareException.INVALID_CHAR_ERROR)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def is_invalid_pattern(string):
    """Return the ``re.search`` match of ``Const.STRING_BLACKLIST`` in ``string``, or None."""
    return re.search(Const.STRING_BLACKLIST, string)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def print_tools_ends_info():
    """Log a star-framed banner containing ``Const.TOOL_ENDS_SUCCESSFULLY``."""
    banner_text = Const.TOOL_ENDS_SUCCESSFULLY
    width = len(banner_text) + Const.FILL_CHAR_NUMS
    border = '*' * width
    logger.info(border)
    logger.info("*" + banner_text.center(width - 2) + "*")
    logger.info(border)
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def get_step_or_rank_from_string(step_or_rank, obj):
    """Expand a ``"start-end"`` string into the inclusive list of integers.

    Args:
        step_or_rank: String of two decimal numbers joined by ``Const.HYPHEN``.
        obj: Name of the parameter being parsed, used in error messages.

    Returns:
        ``list(range(start, end + 1))``.

    Raises:
        MsprobeException: INVALID_PARAM_ERROR for a malformed string, a bound
            outside ``Const.STEP_RANK_MAXIMUM_RANGE``, or ``start > end``.
    """
    parts = step_or_rank.split(Const.HYPHEN)
    if len(parts) != 2:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               f'The string parameter for {obj} only supports formats like "3-5". Now string parameter for {obj} is "{step_or_rank}".')
    try:
        start, end = int(parts[0]), int(parts[1])
    except (ValueError, IndexError) as e:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               "The hyphen(-) must start and end with decimal numbers.") from e

    lower, upper = Const.STEP_RANK_MAXIMUM_RANGE[0], Const.STEP_RANK_MAXIMUM_RANGE[1]
    if not (lower <= start <= upper and lower <= end <= upper):
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               f"The boundaries must fall within the range of [{Const.STEP_RANK_MAXIMUM_RANGE[0]}, {Const.STEP_RANK_MAXIMUM_RANGE[1]}].")
    if start > end:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               f'For the hyphen(-) in {obj}, the left boundary ({start}) cannot be greater than the right boundary ({end}).')
    return list(range(start, end + 1))
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def get_real_step_or_rank(step_or_rank_input, obj):
    """Normalize a user-supplied step/rank list into a sorted list of unique ints.

    Args:
        step_or_rank_input: ``None`` or a list whose elements are non-negative
            ints or ``"start-end"`` range strings.
        obj: ``Const.STEP`` or ``Const.RANK``; also used in error messages.

    Returns:
        Sorted, de-duplicated list of ints; ``[]`` when the input is ``None``.

    Raises:
        MsprobeException: INVALID_PARAM_ERROR for an unsupported ``obj``, a
            non-list input, or any element that is not a valid int or range
            string. Previously, out-of-range ints and hyphen-less strings were
            silently dropped.
    """
    if obj not in [Const.STEP, Const.RANK]:
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                               f"Only support parsing {[Const.STEP, Const.RANK]}, the current parsing object is {obj}.")
    if step_or_rank_input is None:
        return []
    if not isinstance(step_or_rank_input, list):
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"{obj} is invalid, it should be a list")
    lower, upper = Const.STEP_RANK_MAXIMUM_RANGE[0], Const.STEP_RANK_MAXIMUM_RANGE[1]
    real_step_or_rank = set()
    for element in step_or_rank_input:
        # bool is a subclass of int, but True/False are not meaningful steps or ranks.
        if isinstance(element, bool) or not isinstance(element, (int, str)):
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                   f"{obj} element {element} must be an integer or string.")
        if isinstance(element, int):
            if element < 0:
                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                       f"Each element of {obj} must be non-negative, currently it is {element}.")
            if not lower <= element <= upper:
                raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                       f"Each element of {obj} must fall within the range of [{lower}, {upper}], "
                                       f"currently it is {element}.")
            real_step_or_rank.add(element)
        elif Const.HYPHEN in element:
            real_step_or_rank.update(get_step_or_rank_from_string(element, obj))
        else:
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                   f'The string parameter for {obj} only supports formats like "3-5". '
                                   f'Now string parameter for {obj} is "{element}".')
    return sorted(real_step_or_rank)
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def check_seed_all(seed, mode):
    """Validate the arguments of ``seed_all``.

    Args:
        seed: Integer seed in ``[0, Const.MAX_SEED_VALUE]``.
        mode: Must be a ``bool``.

    Raises:
        MsprobeException: INVALID_PARAM_ERROR if ``seed`` is not an int (bools
            are rejected), is out of range, or ``mode`` is not a bool.
    """
    # bool is a subclass of int; True/False must not be accepted as seeds.
    if isinstance(seed, int) and not isinstance(seed, bool):
        if seed < 0 or seed > Const.MAX_SEED_VALUE:
            logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    else:
        logger.error("Seed must be integer.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if not isinstance(mode, bool):
        logger.error("seed_all mode must be bool.")
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|