mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +131 -237
- msprobe/__init__.py +16 -1
- msprobe/{config/config.json → config.json} +47 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +58 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +402 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +523 -283
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +86 -69
- msprobe/core/common/utils.py +371 -616
- msprobe/core/common_config.py +78 -71
- msprobe/core/compare/acc_compare.py +472 -298
- msprobe/core/compare/check.py +180 -95
- msprobe/core/compare/compare_cli.py +69 -49
- msprobe/core/compare/highlight.py +259 -222
- msprobe/core/compare/multiprocessing_compute.py +174 -149
- msprobe/core/compare/npy_compare.py +310 -295
- msprobe/core/compare/utils.py +464 -429
- msprobe/core/data_dump/data_collector.py +153 -144
- msprobe/core/data_dump/data_processor/base.py +337 -293
- msprobe/core/data_dump/data_processor/factory.py +76 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
- msprobe/core/data_dump/json_writer.py +117 -116
- msprobe/core/data_dump/scope.py +194 -178
- msprobe/core/grad_probe/constant.py +74 -70
- msprobe/core/grad_probe/grad_compare.py +170 -175
- msprobe/core/grad_probe/utils.py +77 -52
- msprobe/docs/01.installation.md +99 -0
- msprobe/docs/02.config_introduction.md +137 -0
- msprobe/docs/03.config_examples.md +237 -0
- msprobe/docs/04.acl_config_examples.md +78 -0
- msprobe/docs/05.data_dump_PyTorch.md +326 -0
- msprobe/docs/06.data_dump_MindSpore.md +285 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
- msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
- msprobe/docs/FAQ.md +189 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +2 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +58 -34
- msprobe/mindspore/common/const.py +108 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +97 -57
- msprobe/mindspore/compare/distributed_compare.py +62 -75
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +357 -117
- msprobe/mindspore/compare/ms_graph_compare.py +364 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +69 -74
- msprobe/mindspore/debugger/precision_debugger.py +150 -107
- msprobe/mindspore/dump/dump_tool_factory.py +50 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
- msprobe/mindspore/dump/jit_dump.py +96 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
- msprobe/mindspore/free_benchmark/common/config.py +27 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
- msprobe/mindspore/free_benchmark/common/utils.py +85 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
- msprobe/mindspore/grad_probe/global_context.py +100 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +297 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +23 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +30 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
- msprobe/pytorch/bench_functions/linear.py +27 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
- msprobe/pytorch/bench_functions/rms_norm.py +30 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
- msprobe/pytorch/bench_functions/swiglu.py +70 -55
- msprobe/pytorch/common/__init__.py +17 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +33 -32
- msprobe/pytorch/common/parse_json.py +54 -39
- msprobe/pytorch/common/utils.py +310 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +49 -33
- msprobe/pytorch/compare/pt_compare.py +82 -40
- msprobe/pytorch/debugger/debugger_config.py +108 -95
- msprobe/pytorch/debugger/precision_debugger.py +173 -125
- msprobe/pytorch/free_benchmark/__init__.py +23 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +65 -37
- msprobe/pytorch/free_benchmark/common/params.py +144 -129
- msprobe/pytorch/free_benchmark/common/utils.py +118 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
- msprobe/pytorch/free_benchmark/main.py +120 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
- msprobe/pytorch/function_factory.py +91 -75
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +166 -161
- msprobe/pytorch/hook_module/hook_module.py +118 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +28 -29
- msprobe/pytorch/hook_module/wrap_aten.py +111 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
- msprobe/pytorch/hook_module/wrap_functional.py +104 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
- msprobe/pytorch/hook_module/wrap_torch.py +84 -86
- msprobe/pytorch/hook_module/wrap_vf.py +60 -62
- msprobe/pytorch/module_processer.py +153 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +235 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
- msprobe/pytorch/online_dispatch/utils.py +127 -146
- msprobe/pytorch/parse.py +19 -4
- msprobe/pytorch/parse_tool/cli.py +31 -32
- msprobe/pytorch/parse_tool/lib/compare.py +259 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
- msprobe/pytorch/parse_tool/lib/utils.py +320 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +317 -187
- msprobe/pytorch/service.py +311 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/mindspore/common/log.py
CHANGED
|
@@ -1,38 +1,38 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
import os
|
|
17
|
-
import time
|
|
18
|
-
import sys
|
|
19
|
-
|
|
20
|
-
from msprobe.mindspore.common.utils import get_rank_if_initialized
|
|
21
|
-
from msprobe.core.common.log import BaseLogger
|
|
22
|
-
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
class MindsporeLogger(BaseLogger):
|
|
26
|
-
def __init__(self):
|
|
27
|
-
super().__init__()
|
|
28
|
-
|
|
29
|
-
def get_rank(self):
|
|
30
|
-
try:
|
|
31
|
-
current_rank = get_rank_if_initialized()
|
|
32
|
-
except DistributedNotInitializedError:
|
|
33
|
-
current_rank = None
|
|
34
|
-
|
|
35
|
-
return current_rank
|
|
36
|
-
|
|
37
|
-
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import time
|
|
18
|
+
import sys
|
|
19
|
+
|
|
20
|
+
from msprobe.mindspore.common.utils import get_rank_if_initialized
|
|
21
|
+
from msprobe.core.common.log import BaseLogger
|
|
22
|
+
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MindsporeLogger(BaseLogger):
|
|
26
|
+
def __init__(self):
|
|
27
|
+
super().__init__()
|
|
28
|
+
|
|
29
|
+
def get_rank(self):
|
|
30
|
+
try:
|
|
31
|
+
current_rank = get_rank_if_initialized()
|
|
32
|
+
except DistributedNotInitializedError:
|
|
33
|
+
current_rank = None
|
|
34
|
+
|
|
35
|
+
return current_rank
|
|
36
|
+
|
|
37
|
+
|
|
38
38
|
logger = MindsporeLogger()
|
|
@@ -1,57 +1,97 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
import
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
from msprobe.core.common.
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import random
|
|
18
|
+
|
|
19
|
+
import mindspore as ms
|
|
20
|
+
|
|
21
|
+
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
22
|
+
from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy
|
|
23
|
+
from msprobe.core.common.log import logger
|
|
24
|
+
from msprobe.core.common.const import Const
|
|
25
|
+
from msprobe.core.common.utils import CompareException, check_seed_all
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_rank_if_initialized():
|
|
29
|
+
if ms.communication.GlobalComm.INITED:
|
|
30
|
+
return ms.communication.get_rank()
|
|
31
|
+
else:
|
|
32
|
+
raise DistributedNotInitializedError("mindspore distributed environment is not initialized")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def convert_bf16_to_fp32(tensor):
|
|
36
|
+
if tensor.dtype == ms.bfloat16:
|
|
37
|
+
tensor = tensor.to(ms.float32)
|
|
38
|
+
return tensor
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def save_tensor_as_npy(tensor, file_path):
|
|
42
|
+
if not path_len_exceeds_limit(file_path):
|
|
43
|
+
tensor = convert_bf16_to_fp32(tensor)
|
|
44
|
+
saved_tensor = tensor.asnumpy()
|
|
45
|
+
save_npy(saved_tensor, file_path)
|
|
46
|
+
else:
|
|
47
|
+
logger.warning(f'The file path {file_path} length exceeds limit.')
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def convert_to_int(value):
|
|
51
|
+
try:
|
|
52
|
+
return int(value)
|
|
53
|
+
except Exception:
|
|
54
|
+
return -1
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def list_lowest_level_directories(root_dir):
|
|
58
|
+
check_path_exists(root_dir)
|
|
59
|
+
lowest_level_dirs = []
|
|
60
|
+
|
|
61
|
+
def recurse_dirs(current_dir, depth=0):
|
|
62
|
+
if depth > Const.MAX_DEPTH:
|
|
63
|
+
logger.error(f'The directory {current_dir} has more than {Const.MAX_DEPTH} levels.')
|
|
64
|
+
raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
|
|
65
|
+
for entry in os.listdir(current_dir):
|
|
66
|
+
full_path = os.path.join(current_dir, entry)
|
|
67
|
+
if os.path.isdir(full_path):
|
|
68
|
+
if any(os.path.isdir(os.path.join(full_path, subentry)) for subentry in os.listdir(full_path)):
|
|
69
|
+
recurse_dirs(full_path, depth=depth+1)
|
|
70
|
+
else:
|
|
71
|
+
lowest_level_dirs.append(full_path)
|
|
72
|
+
|
|
73
|
+
recurse_dirs(root_dir)
|
|
74
|
+
return lowest_level_dirs
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def seed_all(seed=1234, mode=False):
|
|
78
|
+
check_seed_all(seed, mode)
|
|
79
|
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
|
80
|
+
ms.set_seed(seed)
|
|
81
|
+
random.seed(seed)
|
|
82
|
+
ms.set_context(deterministic="ON" if mode else "OFF")
|
|
83
|
+
os.environ['HCCL_DETERMINISTIC'] = str(mode)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class MsprobeStep(ms.train.Callback):
|
|
87
|
+
|
|
88
|
+
def __init__(self, debugger):
|
|
89
|
+
super(MsprobeStep, self).__init__()
|
|
90
|
+
self.debugger = debugger
|
|
91
|
+
|
|
92
|
+
def on_train_step_begin(self, run_context):
|
|
93
|
+
self.debugger.start()
|
|
94
|
+
|
|
95
|
+
def on_train_step_end(self, run_context):
|
|
96
|
+
self.debugger.stop()
|
|
97
|
+
self.debugger.step()
|
|
@@ -1,75 +1,62 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
ms_comparator = MSComparator()
|
|
64
|
-
ms_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare,
|
|
65
|
-
md5_compare=md5_compare, **kwargs)
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
def ms_graph_compare(inputs, outputs):
|
|
69
|
-
try:
|
|
70
|
-
create_directory(outputs)
|
|
71
|
-
except (CompareException, FileCheckException) as error:
|
|
72
|
-
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
73
|
-
return
|
|
74
|
-
msComparator = GraphMSComparator(inputs, outputs)
|
|
75
|
-
msComparator.compare_core()
|
|
1
|
+
import os
|
|
2
|
+
from msprobe.core.common.utils import CompareException, check_compare_param, \
|
|
3
|
+
check_configuration_param, task_dumppath_get
|
|
4
|
+
from msprobe.core.common.file_utils import create_directory
|
|
5
|
+
from msprobe.core.common.exceptions import FileCheckException
|
|
6
|
+
from msprobe.mindspore.common.log import logger
|
|
7
|
+
from msprobe.mindspore.compare.ms_compare import MSComparator
|
|
8
|
+
from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
|
|
9
|
+
from msprobe.mindspore.compare.ms_graph_compare import GraphMSComparator
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
|
|
13
|
+
if kwargs.get('suffix'):
|
|
14
|
+
logger.error("Argument 'suffix' is not supported for compare_distributed.")
|
|
15
|
+
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
16
|
+
stack_mode = kwargs.get('stack_mode', False)
|
|
17
|
+
auto_analyze = kwargs.get('auto_analyze', True)
|
|
18
|
+
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
19
|
+
# get the ranks and match by order
|
|
20
|
+
npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
|
|
21
|
+
bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
|
|
22
|
+
if len(npu_ranks) != len(bench_ranks):
|
|
23
|
+
logger.error('The number of ranks in the two runs are different. '
|
|
24
|
+
'Unable to match the ranks. Please use another folder to compare '
|
|
25
|
+
'or use compare() api and manually match the ranks.')
|
|
26
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
27
|
+
for nr, br in zip(npu_ranks, bench_ranks):
|
|
28
|
+
npu_data_dir = os.path.join(npu_dump_dir, nr)
|
|
29
|
+
bench_data_dir = os.path.join(bench_dump_dir, br)
|
|
30
|
+
npu_path = extract_json(npu_data_dir, stack_json=False)
|
|
31
|
+
bench_path = extract_json(bench_data_dir, stack_json=False)
|
|
32
|
+
stack_path = extract_json(npu_data_dir, stack_json=True)
|
|
33
|
+
|
|
34
|
+
dump_result_param = {
|
|
35
|
+
'npu_json_path': npu_path,
|
|
36
|
+
'bench_json_path': bench_path,
|
|
37
|
+
'stack_json_path': stack_path,
|
|
38
|
+
'is_print_compare_log': True
|
|
39
|
+
}
|
|
40
|
+
try:
|
|
41
|
+
summary_compare, md5_compare = task_dumppath_get(dump_result_param)
|
|
42
|
+
check_configuration_param(stack_mode, auto_analyze, fuzzy_match,
|
|
43
|
+
dump_result_param.get('is_print_compare_log', True))
|
|
44
|
+
create_directory(output_path)
|
|
45
|
+
check_compare_param(dump_result_param, output_path,
|
|
46
|
+
summary_compare=summary_compare, md5_compare=md5_compare)
|
|
47
|
+
except (CompareException, FileCheckException) as error:
|
|
48
|
+
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
49
|
+
raise CompareException(error.code) from error
|
|
50
|
+
ms_comparator = MSComparator()
|
|
51
|
+
ms_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}',
|
|
52
|
+
summary_compare=summary_compare, md5_compare=md5_compare, **kwargs)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def ms_graph_compare(inputs, outputs):
|
|
56
|
+
try:
|
|
57
|
+
create_directory(outputs)
|
|
58
|
+
except (CompareException, FileCheckException) as error:
|
|
59
|
+
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
60
|
+
return
|
|
61
|
+
ms_comparator = GraphMSComparator(inputs, outputs)
|
|
62
|
+
ms_comparator.compare_core()
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
from msprobe.core.common.const import Const
|
|
4
|
+
from msprobe.core.common.log import logger
|
|
5
|
+
from msprobe.core.common.utils import CompareException
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Trie:
    """Prefix-tree node keyed by the separator-delimited parts of a name.

    Every node keeps its type name, the call counts of the names that end at
    it, its children indexed by name part, a flag telling whether any name
    ends here, and the kind (``node_type``) of the last entry inserted here.
    """

    def __init__(self, type_name=None, has_data=False):
        self.type_name = type_name
        self.has_data = has_data
        self.node_type = None
        self.call_count_list = []
        self.children = {}

    def __repr__(self):
        call_total = len(self.call_count_list)
        return f"Node(type_name={self.type_name}, has_data={self.has_data}, call number={call_total})"

    def insert(self, word, word_type="func"):
        """Add ``word`` to the tree, starting from this node.

        ``word`` is split on ``Const.SEP`` into
        ``<prefix parts...>, type_name, call_count``. Example from the
        original notes:

            Cell.network_with_loss.language_model.encoder.layers.1.attention.out_proj.RowParallelLinear.1
            prefix_name_list: Cell.network_with_loss.language_model.encoder.layers.1.attention
            node_name: out_proj
            type_name: RowParallelLinear
            call_count: 1

        Args:
            word: full separator-delimited name.
            word_type: value stored as ``node_type`` on the final node.

        Raises:
            CompareException: with ``INDEX_OUT_OF_BOUNDS_ERROR`` when ``word``
                has fewer than two parts.
        """
        segments = word.split(Const.SEP)
        if len(segments) < 2:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)

        # The last two parts are the type name and the call count; everything
        # before them is the path walked through the tree.
        *prefix_segments, type_name, call_count = segments

        current = self
        for segment in prefix_segments:
            if segment not in current.children:
                current.children[segment] = Trie()
            current = current.children[segment]
            # Intermediate nodes default their type name to their own part.
            if current.type_name is None:
                current.type_name = segment

        # The node reached last carries the data of this word.
        current.type_name = type_name
        current.has_data = True
        current.call_count_list.append(call_count)
        current.node_type = word_type
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class DFSConverter:
    """Depth-first walker that pairs original names with converted names.

    ``mapping`` is looked up by a node's ``type_name`` and is expected to
    yield a dict translating child names; children without an entry keep
    their own name. Results accumulate in ``self.result``.
    """

    def __init__(self, mapping, max_depth=100):
        self.mapping = mapping
        self.max_depth = max_depth
        self.result = {}

    def traverse_and_collect(self, node, path="", mapping_path="", depth=0):
        """Walk ``node`` and its descendants, filling ``self.result``.

        Args:
            node: tree node exposing ``type_name``, ``has_data``,
                ``call_count_list``, ``node_type`` and ``children``; ``None``
                ends the walk immediately.
            path: dotted original name accumulated so far.
            mapping_path: dotted converted name accumulated so far.
            depth: current recursion depth.

        Returns:
            ``self.result``, mapping original full names to converted ones.

        Raises:
            CompareException: with ``RECURSION_LIMIT_ERROR`` when ``depth``
                exceeds ``max_depth``.
        """
        if depth > self.max_depth:
            logger.error("The converted data depth is too large, please check the data")
            raise CompareException(CompareException.RECURSION_LIMIT_ERROR)

        if node is None:
            return self.result

        node_type_name = node.type_name
        if node.has_data:
            is_cell = node.node_type == "Cell"
            for call_index in node.call_count_list:
                # Cell entries end with the call count only; every other kind
                # also carries the type name in front of the count.
                tail = f".{call_index}" if is_cell else f".{node_type_name}.{call_index}"
                self.result[f"{path}{tail}"] = f"{mapping_path}{tail}"

        child_renames = self.mapping.get(node_type_name, {})
        for key, child in node.children.items():
            mapped_key = child_renames.get(key, key)
            self.traverse_and_collect(
                child,
                f"{path}.{key}" if path else key,
                f"{mapping_path}.{mapped_key}" if mapping_path else mapped_key,
                depth + 1,
            )

        return self.result
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def get_mapping_list(ms_tree, mapping):
    """Return ``(ms_name, pt_name)`` pairs collected from ``ms_tree``.

    The tree is walked with a ``DFSConverter`` built from ``mapping``; in each
    converted name a leading ``Cell`` is rewritten to ``Module``.
    """
    collected = DFSConverter(mapping).traverse_and_collect(ms_tree)
    return [
        (ms_name, re.sub(r"^Cell", "Module", pt_name))
        for ms_name, pt_name in collected.items()
    ]
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def get_prefix_mapping(scope_list):
    """Map each layer name without its class part to the full layer name.

    Only entries whose ``origin_data`` starts with ``Cell`` or ``Module`` are
    considered. For those, the second-to-last part of the key is removed to
    form the lookup name, which then maps back to the untouched key.

    Raises:
        CompareException: with ``INDEX_OUT_OF_BOUNDS_ERROR`` when a considered
            key has fewer than two separator-delimited parts.
    """
    prefix_to_full = {}
    for full_name, info in scope_list.items():
        if info.get("origin_data").startswith(("Cell", "Module")):
            segments = full_name.split(Const.SEP)
            if len(segments) < 2:
                logger.error('result dataframe elements can not be access.')
                raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
            # Keep everything except the second-to-last part.
            shortened = Const.SEP.join(segments[:-2] + segments[-1:])
            prefix_to_full[shortened] = full_name
    return prefix_to_full
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def get_layer_mapping(ms_scope_list, pt_scope_list, mapping):
    """Build the list of ``(ms_name, pt_name)`` pairs used to align layers.

    Steps:
        1. Map every MindSpore layer prefix to its full name, e.g.
           ``Cell.network_with_loss.language_model.embedding.3`` ->
           ``Cell.network_with_loss.language_model.embedding.Embedding.3``.
        2. Insert every MindSpore scope key into a ``Trie`` and convert the
           tree into candidate name pairs with ``get_mapping_list``.
        3. Map every PyTorch layer prefix to its full name, e.g.
           ``Module.network_with_loss.language_model.embedding.3`` ->
           ``Module.network_with_loss.language_model.embedding.Embedding.3``.
        4. Resolve each candidate pair to its final names.

    Returns:
        A list of ``(final_ms_name, final_pt_name)`` tuples.
    """

    def drop_last_part(origin_name):
        # Removes the trailing forward/backward part of an origin name.
        return Const.SEP.join(origin_name.split(Const.SEP)[:-1])

    ms_prefix2fullname = get_prefix_mapping(ms_scope_list)

    ms_tree = Trie(type_name="Cell")
    for scope_key, scope_info in ms_scope_list.items():
        # The first part of origin_data is used as the node type.
        ms_tree.insert(scope_key, scope_info.get('origin_data').split(Const.SEP)[0])
    msname2ptname = get_mapping_list(ms_tree, mapping)

    pt_prefix2fullname = get_prefix_mapping(pt_scope_list)

    final_mapping = []
    for ms_name, pt_name in msname2ptname:
        if ms_name in ms_prefix2fullname:
            # Cell entry: both sides are looked up in the prefix tables; the
            # PyTorch side is None when no counterpart exists.
            final_mapping.append((ms_prefix2fullname.get(ms_name),
                                  pt_prefix2fullname.get(pt_name, None)))
        elif ms_name in ms_scope_list:
            # Function entry: use the origin names without their last part.
            ms_origin = drop_last_part(ms_scope_list.get(ms_name)['origin_data'])
            pt_entry = pt_scope_list.get(pt_name, None)
            if not pt_entry:
                continue
            final_mapping.append((ms_origin, drop_last_part(pt_entry['origin_data'])))
        else:
            final_mapping.append((ms_name, pt_name))

    return final_mapping
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
from msprobe.core.common.const import Const
|
|
2
|
+
from msprobe.core.common.log import logger
|
|
3
|
+
|
|
4
|
+
def find_regard_scope(lines, start_sign, end_sign):
    """Locate the start and end positions of a region inside ``lines``.

    Scanning stops at the first line that contains ``end_sign`` but not
    ``start_sign``. The start position is the index of the last line seen
    before that point which contains ``start_sign``.

    Returns:
        ``(start_pos, end_pos)``; either value is ``-1`` when not found.
    """
    first, last = -1, -1
    for position, text in enumerate(lines):
        if start_sign in text:
            first = position
            continue
        if end_sign in text:
            last = position
            break
    return first, last
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def find_stack_func_list(lines):
    """Extract function names from stack lines, in reversed order.

    Each line is split on commas. A line is dropped when its file element
    contains any entry of ``Const.FILE_SKIP_LIST`` or its function element
    contains any entry of ``Const.FUNC_SKIP_LIST``. From the remaining lines
    the whitespace-separated token at ``Const.STACK_FUNC_ELE_INDEX`` of the
    function element is collected.

    Returns:
        The collected names with the order of ``lines`` reversed.
    """
    collected = []
    for entry in lines:
        fields = entry.split(',')

        source_file = fields[Const.STACK_FILE_INDEX]
        if not any(token in source_file for token in Const.FILE_SKIP_LIST):
            call_site = fields[Const.STACK_FUNC_INDEX]
            if not any(token in call_site for token in Const.FUNC_SKIP_LIST):
                collected.append(call_site.split()[Const.STACK_FUNC_ELE_INDEX])

    # Flip the order of the collected names before returning them.
    return list(reversed(collected))
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_duplicated_name(components):
    """Return ``components`` with the name part repeated.

    The part at ``Const.CONSTRUCT_NAME_INDEX`` is duplicated in place, e.g.
    ``Functional.add.X ward`` style keys become ``Functional.add.add.X ward``
    style keys. When ``components`` has fewer than three parts, or the part at
    that index is all digits, a warning is logged and the input list itself is
    returned unchanged.
    """
    if len(components) >= 3 and not components[Const.CONSTRUCT_NAME_INDEX].isdigit():
        name_idx = Const.CONSTRUCT_NAME_INDEX
        # Everything up to and including the name, then the name again
        # followed by the rest.
        return components[:name_idx + 1] + components[name_idx:]

    logger.warning("key in construct.json is shorter than 3 parts or not name valid.")
    return components
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def modify_mapping_with_stack(stack, construct):
    """Rebuild scope names from ``construct`` with the help of ``stack``.

    Args:
        stack: mapping looked up by construct key; each value is treated as a
            list of stack-trace lines (presumably the content of stack.json;
            confirm against the caller).
        construct: mapping from a node key to its parent scope name, or a
            falsy value when the node has no parent.

    Returns:
        A dict mapping each rebuilt name to a dict holding the original key
        (``Const.ORIGIN_DATA``), the parent scope (``Const.SCOPE``) and the
        dotted call-stack function names or ``None`` (``Const.STACK``).
        An empty dict is returned when either argument is empty or ``None``.
    """
    if not stack or not construct:
        return {}

    # Whether the data comes from MindSpore: some key mentions "Cell".
    is_ms = any("Cell" in ii for ii in construct)
    # The adjusted mapping that is returned.
    final_pres = {}
    # Work out which parent each key belongs to.
    for key in construct:
        key_components = key.split(Const.SEP)
        code_list = stack.get(key, None)
        parent_node = construct.get(key, None)
        # Names without a standard prefix are rewritten to a standard one.
        if not key.startswith(("Module", "Cell")):
            # Without a parent scope name, the top-level name defaults to
            # Module or Cell.
            if not parent_node:
                # Replace the leading part with the standard Module or Cell.
                key_components[0] = "Cell" if is_ms else "Module"
                # Repeat the node name as its type, e.g. add.add
                # (the original note places "add" at position -3).
                duplicated_components = get_duplicated_name(key_components)
                modified_key = Const.SEP.join(duplicated_components)

                modified_key = modified_key.replace(".forward", "").replace(".backward", "")
                final_pres[modified_key] = {Const.ORIGIN_DATA: key, Const.SCOPE: None, Const.STACK: None}
                continue
            parent = parent_node.split(Const.SEP)
            if len(parent) < 4:
                logger.info("Parent name in construct.json is not valid")
                continue
            parent_idx = Const.NAME_FIRST_POSSIBLE_INDEX if not \
                parent[Const.NAME_FIRST_POSSIBLE_INDEX].isdigit() else Const.NAME_SECOND_POSSIBLE_INDEX
            parent_name = parent[parent_idx]

            if code_list:
                # {name}.Class.count_number.X ward Or {name}.Class.count_number.X ward.ele_number
                if parent_name.endswith('s'):
                    parent_name = parent_name[:-1]
                if len(key_components) < 3:
                    logger.info("The length of key in construct is less than 3, please check")
                    continue
                # {name}.count_number.X ward
                func_name = key_components[-3]
                start_pos, end_pos = find_regard_scope(code_list, func_name, parent_name)

                # Take the stack lines inside the located range.
                regard_scope = code_list[start_pos:end_pos]

                func_stack_list = find_stack_func_list(regard_scope)
            else:
                func_stack_list = []
            # Assembly: the parent's parts up to and including its node name,
            # then the call-stack function names, then the original key with
            # its node name repeated (the original note describes the last
            # piece as key_components[1:-2] + key_components[-3:]).
            final_res_key = Const.SEP.join(parent[:parent_idx + 1] + func_stack_list +
                                           key_components[1:Const.CONSTRUCT_NAME_INDEX + 1] +
                                           key_components[Const.CONSTRUCT_NAME_INDEX:])
            # Remove a trailing direction part. str.strip() must not be used
            # here: it treats its argument as a set of characters and would
            # also eat matching characters from both ends of the name itself.
            for direction_suffix in (".forward", ".backward"):
                if final_res_key.endswith(direction_suffix):
                    final_res_key = final_res_key[:-len(direction_suffix)]
        else:
            final_res_key = Const.SEP.join(key_components[:-2] + [key_components[-1]])
            func_stack_list = []
        final_pres[final_res_key] = {Const.ORIGIN_DATA: key, Const.SCOPE: parent_node,
                                     Const.STACK: Const.SEP.join(func_stack_list) if func_stack_list else None}
    return final_pres
|