mindstudio-probe 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
- mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
- mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
- mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
- mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
- mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
- msprobe/README.md +182 -0
- msprobe/__init__.py +0 -0
- msprobe/config/README.md +397 -0
- msprobe/config/config.json +28 -0
- msprobe/config/img/free_benchmark.png +0 -0
- msprobe/core/common/const.py +241 -0
- msprobe/core/common/exceptions.py +88 -0
- msprobe/core/common/file_check.py +265 -0
- msprobe/core/common/log.py +55 -0
- msprobe/core/common/utils.py +516 -0
- msprobe/core/common_config.py +58 -0
- msprobe/core/data_dump/data_collector.py +140 -0
- msprobe/core/data_dump/data_processor/base.py +245 -0
- msprobe/core/data_dump/data_processor/factory.py +61 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
- msprobe/core/data_dump/json_writer.py +116 -0
- msprobe/core/data_dump/scope.py +178 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/debugger/__init__.py +0 -0
- msprobe/mindspore/debugger/debugger_config.py +51 -0
- msprobe/mindspore/debugger/precision_debugger.py +32 -0
- msprobe/mindspore/doc/dump.md +65 -0
- msprobe/mindspore/dump/__init__.py +0 -0
- msprobe/mindspore/dump/api_kbk_dump.py +55 -0
- msprobe/mindspore/dump/dump_tool_factory.py +38 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
- msprobe/mindspore/ms_config.py +78 -0
- msprobe/mindspore/overflow_check/__init__.py +0 -0
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
- msprobe/mindspore/task_handler_factory.py +21 -0
- msprobe/msprobe.py +67 -0
- msprobe/pytorch/__init__.py +4 -0
- msprobe/pytorch/advisor/advisor.py +124 -0
- msprobe/pytorch/advisor/advisor_const.py +59 -0
- msprobe/pytorch/advisor/advisor_result.py +58 -0
- msprobe/pytorch/api_accuracy_checker/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
- msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
- msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
- msprobe/pytorch/common/__init__.py +2 -0
- msprobe/pytorch/common/compare_script.template +14 -0
- msprobe/pytorch/common/log.py +32 -0
- msprobe/pytorch/common/parse_json.py +37 -0
- msprobe/pytorch/common/utils.py +224 -0
- msprobe/pytorch/compare/acc_compare.py +1024 -0
- msprobe/pytorch/compare/distributed_compare.py +111 -0
- msprobe/pytorch/compare/highlight.py +100 -0
- msprobe/pytorch/compare/mapping.yaml +607 -0
- msprobe/pytorch/compare/match.py +36 -0
- msprobe/pytorch/compare/npy_compare.py +244 -0
- msprobe/pytorch/debugger/__init__.py +0 -0
- msprobe/pytorch/debugger/debugger_config.py +86 -0
- msprobe/pytorch/debugger/precision_debugger.py +95 -0
- msprobe/pytorch/doc/FAQ.md +193 -0
- msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
- msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +182 -0
- msprobe/pytorch/doc/dump.md +207 -0
- msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
- msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
- msprobe/pytorch/doc/img/cpu_info.png +0 -0
- msprobe/pytorch/doc/img/module_compare.png +0 -0
- msprobe/pytorch/doc/parse_tool.md +286 -0
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
- msprobe/pytorch/doc/run_overflow_check.md +25 -0
- msprobe/pytorch/doc/在精度比对.md +90 -0
- msprobe/pytorch/free_benchmark/__init__.py +8 -0
- msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/common/constant.py +67 -0
- msprobe/pytorch/free_benchmark/common/counter.py +72 -0
- msprobe/pytorch/free_benchmark/common/enums.py +37 -0
- msprobe/pytorch/free_benchmark/common/params.py +129 -0
- msprobe/pytorch/free_benchmark/common/utils.py +98 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
- msprobe/pytorch/free_benchmark/main.py +102 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
- msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
- msprobe/pytorch/functional/__init__.py +0 -0
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +39 -0
- msprobe/pytorch/hook_module/__init__.py +1 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -0
- msprobe/pytorch/hook_module/hook_module.py +109 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
- msprobe/pytorch/hook_module/utils.py +29 -0
- msprobe/pytorch/hook_module/wrap_aten.py +100 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
- msprobe/pytorch/hook_module/wrap_functional.py +108 -0
- msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
- msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
- msprobe/pytorch/hook_module/wrap_torch.py +88 -0
- msprobe/pytorch/hook_module/wrap_vf.py +64 -0
- msprobe/pytorch/module_processer.py +98 -0
- msprobe/pytorch/online_dispatch/__init__.py +20 -0
- msprobe/pytorch/online_dispatch/compare.py +236 -0
- msprobe/pytorch/online_dispatch/dispatch.py +274 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
- msprobe/pytorch/online_dispatch/single_compare.py +391 -0
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
- msprobe/pytorch/online_dispatch/utils.py +187 -0
- msprobe/pytorch/parse.py +4 -0
- msprobe/pytorch/parse_tool/__init__.py +0 -0
- msprobe/pytorch/parse_tool/cli.py +32 -0
- msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
- msprobe/pytorch/parse_tool/lib/compare.py +259 -0
- msprobe/pytorch/parse_tool/lib/config.py +51 -0
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
- msprobe/pytorch/parse_tool/lib/utils.py +367 -0
- msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
- msprobe/pytorch/pt_config.py +93 -0
- msprobe/pytorch/service.py +167 -0
- msprobe/test/core_ut/common/test_utils.py +345 -0
- msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
- msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
- msprobe/test/core_ut/data_dump/test_scope.py +151 -0
- msprobe/test/core_ut/test_common_config.py +152 -0
- msprobe/test/core_ut/test_file_check.py +218 -0
- msprobe/test/core_ut/test_log.py +109 -0
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
- msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
- msprobe/test/mindspore_ut/test_ms_config.py +69 -0
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
- msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
- msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
- msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
- msprobe/test/pytorch_ut/test_pt_config.py +69 -0
- msprobe/test/pytorch_ut/test_service.py +59 -0
- msprobe/test/resources/advisor.txt +3 -0
- msprobe/test/resources/compare_result_20230703104808.csv +9 -0
- msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
- msprobe/test/resources/config.yaml +3 -0
- msprobe/test/resources/npu_test.pkl +8 -0
- msprobe/test/run_test.sh +30 -0
- msprobe/test/run_ut.py +58 -0
- msprobe/test/test_module_processer.py +64 -0
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
import logging
|
|
18
|
+
from msprobe.core.common.exceptions import FileCheckException
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ParseException(Exception):
    """Error raised by the parse tool, identified by a numeric error code.

    Attributes:
        code: normally one of the ``PARSE_*`` class constants below.
        error_info: optional human-readable detail; empty string by default.
    """

    PARSE_INVALID_PATH_ERROR = 0
    PARSE_NO_FILE_ERROR = 1
    PARSE_NO_MODULE_ERROR = 2
    PARSE_INVALID_DATA_ERROR = 3
    PARSE_INVALID_FILE_FORMAT_ERROR = 4
    PARSE_UNICODE_ERROR = 5
    PARSE_JSONDECODE_ERROR = 6
    PARSE_MSACCUCMP_ERROR = 7
    PARSE_LOAD_NPY_ERROR = 8
    PARSE_INVALID_PARAM_ERROR = 9

    def __init__(self, code, error_info=""):
        # Forward both values so ``self.args`` is populated. pickle rebuilds an
        # exception as ``cls(*args)``; with empty args that call failed because
        # ``code`` is required.
        super().__init__(code, error_info)
        self.error_info = error_info
        self.code = code

    def __str__(self):
        """Render the code and, when present, the detail message."""
        if self.error_info:
            return "[code {}] {}".format(self.code, self.error_info)
        return "[code {}]".format(self.code)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def catch_exception(func):
    """Decorate a command handler so expected failures are logged, not raised.

    OSError, ParseException and FileCheckException raised by *func* are logged
    on the root logger and the wrapper returns None. Any other exception
    propagates unchanged. On success, *func*'s return value is passed through.
    """
    import functools

    @functools.wraps(func)  # keep the handler's __name__/__doc__ (help text, introspection)
    def inner(*args, **kwargs):
        log = logging.getLogger()
        # With exactly two positional arguments (e.g. a method called with
        # self plus one argument) the last one is echoed in the OSError message.
        line = args[-1] if len(args) == 2 else ""
        try:
            return func(*args, **kwargs)
        except OSError:
            log.error("%s: command not found", line)
        except (ParseException, FileCheckException):
            log.error("Command execution failed")
        return None

    return inner
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
import argparse
|
|
18
|
+
import os
|
|
19
|
+
from collections import namedtuple
|
|
20
|
+
|
|
21
|
+
from msprobe.pytorch.parse_tool.lib.config import Const
|
|
22
|
+
from msprobe.pytorch.parse_tool.lib.utils import Util
|
|
23
|
+
from msprobe.pytorch.parse_tool.lib.compare import Compare
|
|
24
|
+
from msprobe.pytorch.parse_tool.lib.visualization import Visualization
|
|
25
|
+
from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception, ParseException
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ParseTool:
    """Command handlers of the parse tool.

    Each ``do_*`` handler validates its inputs through ``Util`` and delegates
    the actual work to ``Compare`` or ``Visualization``. Every public handler
    is wrapped in ``catch_exception``, so OSError, ParseException and
    FileCheckException raised inside it are logged instead of propagated.
    """

    def __init__(self):
        self.util = Util()
        self.compare = Compare()
        self.visual = Visualization()

    def _resolve_msaccucmp_path(self, raw_path):
        """Return the msaccucmp script path to use, after validating it.

        *raw_path* is quote-stripped when given; otherwise
        ``Const.MS_ACCU_CMP_PATH`` is used. The chosen path is checked with
        ``Util.check_path_valid`` and ``Util.check_executable_file``.
        """
        msaccucmp_path = self.util.path_strip(raw_path) if raw_path else Const.MS_ACCU_CMP_PATH
        self.util.check_path_valid(msaccucmp_path)
        self.util.check_executable_file(msaccucmp_path)
        return msaccucmp_path

    @catch_exception
    def prepare(self):
        """Create ``Const.DATA_ROOT_DIR`` if it does not exist yet."""
        self.util.create_dir(Const.DATA_ROOT_DIR)

    @catch_exception
    def do_vector_compare(self, args):
        """Compare two dump directories with msaccucmp (``npu_vs_npu_compare``).

        Args:
            args: parsed arguments exposing ``my_dump_path``,
                ``golden_dump_path``, ``output_path`` and ``msaccucmp_path``.
                Results go to ``output_path``, or ``Const.COMPARE_DIR`` when
                it is empty.

        Raises:
            ParseException: PARSE_INVALID_PATH_ERROR if either dump path is not
                a directory (logged by ``catch_exception``).
        """
        result_dir = args.output_path if args.output_path else Const.COMPARE_DIR
        my_dump_path = args.my_dump_path
        golden_dump_path = args.golden_dump_path
        if not os.path.isdir(my_dump_path) or not os.path.isdir(golden_dump_path):
            self.util.log.error("Please enter a directory not a file")
            raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
        msaccucmp_path = self._resolve_msaccucmp_path(args.msaccucmp_path)
        self.compare.npu_vs_npu_compare(my_dump_path, golden_dump_path, result_dir, msaccucmp_path)

    @catch_exception
    def do_convert_dump(self, argv=None):
        """Convert dump file(s) to .npy via ``Compare.convert_dump_to_npy``.

        Args:
            argv: argument list for ``-n/--name`` (required),
                ``-f/--format``, ``-out/--output_path`` and
                ``-cmp_path/--msaccucmp_path``. ``None`` makes argparse read
                ``sys.argv``.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '-n', '--name', dest='path', default=None, required=True, help='dump file or dump file directory')
        parser.add_argument(
            '-f', '--format', dest='format', default=None, required=False, help='target format')
        parser.add_argument(
            '-out', '--output_path', dest='output_path', required=False, default=None, help='output path')
        parser.add_argument(
            "-cmp_path", "--msaccucmp_path", dest="msaccucmp_path", default=None,
            help="<Optional> the msaccucmp.py file path", required=False)
        args = parser.parse_args(argv)
        self.util.check_path_valid(args.path)
        self.util.check_files_in_path(args.path)
        msaccucmp_path = self._resolve_msaccucmp_path(args.msaccucmp_path)
        if args.format:
            self.util.check_str_param(args.format)
        self.compare.convert_dump_to_npy(args.path, args.format, args.output_path, msaccucmp_path)

    @catch_exception
    def do_print_data(self, argv=None):
        """print tensor data"""
        parser = argparse.ArgumentParser()
        parser.add_argument('-n', '--name', dest='path', default=None, required=True, help='File name')
        args = parser.parse_args(argv)
        self.visual.print_npy_data(args.path)

    @catch_exception
    def do_parse_pkl(self, argv=None):
        """Show the entry for API ``-n/--name`` from the PKL file ``-f/--file``."""
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '-f', '--file', dest='file_name', default=None, required=True, help='PKL file path')
        parser.add_argument(
            '-n', '--name', dest='api_name', default=None, required=True, help='API name')
        args = parser.parse_args(argv)
        self.visual.parse_pkl(args.file_name, args.api_name)

    @catch_exception
    def do_compare_data(self, argv):
        """compare two tensor"""
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "-m", "--my_dump_path", dest="my_dump_path", default=None,
            help="<Required> my dump path, the data compared with golden data",
            required=True
        )
        parser.add_argument(
            "-g", "--golden_dump_path", dest="golden_dump_path", default=None,
            help="<Required> the golden dump data path",
            required=True
        )
        parser.add_argument('-p', '--print', dest='count', default=20, type=int, help='print err data num')
        parser.add_argument('-s', '--save', dest='save', action='store_true', help='save data in txt format')
        # The help texts were previously swapped between these two options.
        parser.add_argument('-al', '--atol', dest='atol', default=0.001, type=float, help='set atol')
        parser.add_argument('-rl', '--rtol', dest='rtol', default=0.001, type=float, help='set rtol')
        args = parser.parse_args(argv)
        self.util.check_path_valid(args.my_dump_path)
        self.util.check_path_valid(args.golden_dump_path)
        self.util.check_path_format(args.my_dump_path, Const.NPY_SUFFIX)
        self.util.check_path_format(args.golden_dump_path, Const.NPY_SUFFIX)
        compare_data_args = namedtuple(
            'compare_data_args', ['my_dump_path', 'golden_dump_path', 'save', 'rtol', 'atol', 'count'])
        compare_data_args.__new__.__defaults__ = (False, 0.001, 0.001, 20)
        res = compare_data_args(args.my_dump_path, args.golden_dump_path, args.save, args.rtol, args.atol, args.count)
        self.compare.compare_data(res)

    @catch_exception
    def do_compare_converted_dir(self, args):
        """compare two dir

        Args:
            args: parsed arguments exposing ``my_dump_path``,
                ``golden_dump_path`` and ``output_path``. The output directory
                defaults to ``Const.BATCH_COMPARE_DIR`` and is created (mode
                0o750) when missing.

        Raises:
            ParseException: PARSE_INVALID_PATH_ERROR when both paths are the same.
        """
        my_dump_dir = self.util.path_strip(args.my_dump_path)
        golden_dump_dir = self.util.path_strip(args.golden_dump_path)
        if my_dump_dir == golden_dump_dir:
            self.util.log.error("My directory path and golden directory path is same. Please check parameter"
                                " '-m' and '-g'.")
            # Pass a real error code; the message previously landed in ``code``.
            raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR,
                                 "My directory path and golden directory path is same.")
        output_path = self.util.path_strip(args.output_path) if args.output_path else Const.BATCH_COMPARE_DIR
        if not os.path.isdir(output_path):
            os.makedirs(output_path, mode=0o750)
        self.compare.compare_converted_dir(my_dump_dir, golden_dump_dir, output_path)

    @catch_exception
    def do_convert_api_dir(self, argv=None):
        """Convert an API dump directory to .npy files via ``Compare.convert_api_dir_to_npy``.

        Args:
            argv: argument list for ``-m/--my_dump_path`` (required),
                ``-out/--output_path`` (default: a timestamped directory under
                ``Const.BATCH_DUMP_CONVERT_DIR``) and ``-asc/--msaccucmp_path``.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "-m", "--my_dump_path", dest="my_dump_path", default=None,
            help="<Required> my dump path, the data need to convert to npy files.",
            required=True
        )
        parser.add_argument(
            '-out', '--output_path', dest='output_path', required=False, default=None, help='output path')
        parser.add_argument(
            "-asc", "--msaccucmp_path", dest="msaccucmp_path", default=None,
            help="<Optional> the msaccucmp.py file path", required=False)
        args = parser.parse_args(argv)
        self.util.check_path_valid(args.my_dump_path)
        self.util.check_files_in_path(args.my_dump_path)
        output_path = self.util.path_strip(args.output_path) if args.output_path else \
            os.path.join(Const.BATCH_DUMP_CONVERT_DIR, self.util.localtime_str())
        msaccucmp_path = self._resolve_msaccucmp_path(args.msaccucmp_path)
        self.compare.convert_api_dir_to_npy(args.my_dump_path, None, output_path, msaccucmp_path)
|
|
@@ -0,0 +1,367 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
import io
|
|
20
|
+
import re
|
|
21
|
+
import sys
|
|
22
|
+
import subprocess
|
|
23
|
+
import hashlib
|
|
24
|
+
import csv
|
|
25
|
+
import time
|
|
26
|
+
import numpy as np
|
|
27
|
+
from collections import namedtuple
|
|
28
|
+
from msprobe.pytorch.parse_tool.lib.config import Const
|
|
29
|
+
from msprobe.pytorch.parse_tool.lib.file_desc import DumpDecodeFileDesc, FileDesc
|
|
30
|
+
from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
|
|
31
|
+
from msprobe.core.common.file_check import change_mode, check_other_user_writable,\
|
|
32
|
+
check_path_executable, check_path_owner_consistent
|
|
33
|
+
from msprobe.core.common.const import FileCheckConst
|
|
34
|
+
from msprobe.core.common.file_check import FileOpen
|
|
35
|
+
from msprobe.core.common.utils import check_file_or_directory_path
|
|
36
|
+
from msprobe.pytorch.common.log import logger
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# rich is optional: when any of its pieces cannot be imported, every rich
# name below is bound to None and callers must fall back to plain output.
try:
    from rich import print as rich_print
    from rich.columns import Columns
    from rich.panel import Panel
    from rich.table import Table
    from rich.traceback import install

    # Only reached when every import above succeeded.
    install()
except ImportError:
    install = Panel = Table = Columns = rich_print = None
    logger.warning(
        "Failed to import rich, Some features may not be available. Please run 'pip install rich' to fix it.")
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class Util:
|
|
58
|
+
def __init__(self):
    """Set up root logging and cache the handles used by the helper methods.

    ``logging.basicConfig`` uses ``Const.LOG_LEVEL``; it has no effect if the
    root logger already has handlers.
    """
    self.ms_accu_cmp = None
    log_format = "%(asctime)s (%(process)d) -[%(levelname)s]%(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"
    logging.basicConfig(level=Const.LOG_LEVEL, format=log_format, datefmt=date_format)
    self.log = logging.getLogger()
    # Interpreter used to run helper scripts (see check_msaccucmp).
    self.python = sys.executable
|
|
67
|
+
|
|
68
|
+
@staticmethod
def print(content):
    """Print *content* with rich, or with the builtin print if rich is missing.

    ``rich_print`` is None when the optional rich import failed. Calling it
    then raised TypeError, which also broke ``print_panel``'s no-rich
    fallback.
    """
    if rich_print is None:
        # Explicit builtins lookup: this method itself is named ``print``.
        import builtins
        builtins.print(content)
        return
    rich_print(content)
|
|
71
|
+
|
|
72
|
+
@staticmethod
def path_strip(path):
    """Remove surrounding quote characters from *path*.

    Single quotes are stripped from both ends first, then double quotes, so
    ``'"x"'`` becomes ``x`` while ``"'x'"`` keeps its inner single quotes.
    """
    unquoted = path.strip("'")
    unquoted = unquoted.strip('"')
    return unquoted
|
|
75
|
+
|
|
76
|
+
@staticmethod
def check_executable_file(path):
    """Run the owner-consistency, other-user-writable and executable checks on *path*.

    The checks come from ``msprobe.core.common.file_check`` and run in that
    order. NOTE(review): they presumably raise on failure; confirm in file_check.
    """
    for check in (check_path_owner_consistent, check_other_user_writable, check_path_executable):
        check(path)
|
|
81
|
+
|
|
82
|
+
def get_subdir_count(self, directory):
    """Return the number of immediate sub-directories of *directory*.

    Only the top level is counted (the first ``os.walk`` step). A missing or
    unreadable directory yields 0, because ``os.walk`` then produces nothing.

    This is a regular method (``self`` is unused). It was a ``@staticmethod``
    that still declared ``self``, so instance calls such as
    ``util.get_subdir_count(path)`` raised TypeError.
    """
    for _, dirs, _ in os.walk(directory):
        return len(dirs)
    return 0
|
|
89
|
+
|
|
90
|
+
def get_subfiles_count(self, directory):
    """Return the total number of files under *directory*, recursing into sub-directories.

    A missing directory yields 0. This is a regular method (``self`` is
    unused); as a ``@staticmethod`` that declared ``self``, instance calls
    raised TypeError.
    """
    return sum(len(files) for _, _, files in os.walk(directory))
|
|
96
|
+
|
|
97
|
+
def get_sorted_subdirectories_names(self, directory):
    """Return the sorted names (not paths) of the immediate sub-directories of *directory*.

    This is a regular method (``self`` is unused); as a ``@staticmethod`` that
    declared ``self``, instance calls raised TypeError.
    """
    return sorted(
        entry for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )
|
|
105
|
+
|
|
106
|
+
def get_sorted_files_names(self, directory):
    """Return the sorted names of the regular files directly inside *directory*.

    This is a regular method (``self`` is unused); as a ``@staticmethod`` that
    declared ``self``, instance calls raised TypeError.
    """
    return sorted(
        entry for entry in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, entry))
    )
|
|
114
|
+
|
|
115
|
+
def check_npy_files_valid_in_dir(self, dir_path):
    """Return True if every entry directly under *dir_path* has a ``.npy`` extension.

    Each entry is first passed to ``check_file_or_directory_path``
    (presumably raising on an invalid path; confirm in core.common.utils),
    then its extension is inspected. The scan stops at the first
    non-``.npy`` entry. An empty directory yields True.

    This is a regular method (``self`` is unused); as a ``@staticmethod`` that
    declared ``self``, instance calls raised TypeError.
    """
    for file_name in os.listdir(dir_path):
        file_path = os.path.join(dir_path, file_name)
        check_file_or_directory_path(file_path)
        if os.path.splitext(file_path)[1] != '.npy':
            return False
    return True
|
|
124
|
+
|
|
125
|
+
def get_md5_for_numpy(self, obj):
    """Return the hex MD5 digest of ``obj.tobytes()``.

    Only the raw byte content is hashed, so arrays with identical bytes but a
    different shape or dtype produce the same digest. This is a regular method
    (``self`` is unused); as a ``@staticmethod`` that declared ``self``,
    instance calls raised TypeError.
    """
    return hashlib.md5(obj.tobytes()).hexdigest()
|
|
130
|
+
|
|
131
|
+
def write_csv(self, data, filepath):
    """Append the rows in *data* to the CSV file *filepath*.

    If the file did not exist before the call, its permissions are set to
    ``FileCheckConst.DATA_FILE_AUTHORITY`` after writing. This is a regular
    method (``self`` is unused); as a ``@staticmethod`` that declared
    ``self``, instance calls raised TypeError.
    """
    is_new_file = not os.path.exists(filepath)
    with FileOpen(filepath, 'a') as csv_file:
        csv.writer(csv_file).writerows(data)
    if is_new_file:
        change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
141
|
+
|
|
142
|
+
def deal_with_dir_or_file_inconsistency(self, output_path):
    """Remove *output_path* if it exists, then always raise ParseException.

    Raises:
        ParseException: unconditionally, with code PARSE_INVALID_PATH_ERROR
            and the message "Inconsistent directory structure or file."

    This is a regular method (``self`` is unused); as a ``@staticmethod`` that
    declared ``self``, instance calls raised TypeError.
    """
    if os.path.exists(output_path):
        os.remove(output_path)
    # The message previously went into ParseException's ``code`` slot.
    raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR,
                         "Inconsistent directory structure or file.")
|
|
147
|
+
|
|
148
|
+
def deal_with_value_if_has_zero(self, data):
    """Return *data* with every element equal to 0 replaced by machine epsilon.

    The purpose is to avoid division by zero when *data* is used as a divisor.

    If ``data.dtype`` is in ``Const.FLOAT_TYPE``, *data* is modified in place
    with that dtype's epsilon and returned. Any other dtype is first converted
    to a new ``float`` (float64) array, which then uses float64's epsilon.

    This is a regular method (``self`` is unused); as a ``@staticmethod`` that
    declared ``self``, instance calls raised TypeError.
    """
    if data.dtype in Const.FLOAT_TYPE:
        zero_mask = data == 0
        # Add eps where the value is 0 to prevent division by zero.
        data[zero_mask] += np.finfo(data.dtype).eps
        return data
    # Adding a float eps to an int array would raise, so convert to float first.
    as_float = data.astype(float)
    as_float[as_float == 0] += np.finfo(float).eps
    return as_float
|
|
160
|
+
|
|
161
|
+
def dir_contains_only(self, path, endfix):
    """Return True if every file anywhere under *path* ends with *endfix*.

    Sub-directories are searched recursively, but directory names are not
    checked. *endfix* may be a string or a tuple of strings (``str.endswith``
    semantics). A missing or empty tree yields True.

    This is a regular method (``self`` is unused); as a ``@staticmethod`` that
    declared ``self``, instance calls raised TypeError.
    """
    return all(
        file_name.endswith(endfix)
        for _, _, files in os.walk(path)
        for file_name in files
    )
|
|
168
|
+
|
|
169
|
+
def localtime_str(self):
    """Return the current local time formatted as ``%Y%m%d%H%M%S`` (14 digits).

    This is a regular method (``self`` is unused) so that the instance call
    ``self.util.localtime_str()`` made by ParseTool.do_convert_api_dir works.
    As a ``@staticmethod`` that declared ``self``, that call raised TypeError.
    """
    return time.strftime("%Y%m%d%H%M%S", time.localtime())
|
|
172
|
+
|
|
173
|
+
def change_filemode_safe(self, path):
    """Set *path*'s permissions to ``FileCheckConst.DATA_FILE_AUTHORITY`` via ``change_mode``.

    This is a regular method (``self`` is unused); as a ``@staticmethod`` that
    declared ``self``, instance calls raised TypeError.
    """
    change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
176
|
+
|
|
177
|
+
@staticmethod
def _gen_npu_dump_convert_file_info(name, match, dir_path):
    """Build a DumpDecodeFileDesc from a regex *match* on a converted dump file name.

    Groups used:
    - group 1 -> ``op_type``
    - group 2 -> ``op_name``
    - group 3 -> ``task_id`` (int)
    - 4th-from-last group -> third positional argument (int)
    - 3rd-from-last group -> ``anchor_type``
    - 2nd-from-last group -> ``anchor_idx`` (int)
    """
    groups = match.groups()
    from_end_fourth = int(groups[-4])
    return DumpDecodeFileDesc(
        name,
        dir_path,
        from_end_fourth,
        op_name=match.group(2),
        op_type=match.group(1),
        task_id=int(match.group(3)),
        anchor_type=groups[-3],
        anchor_idx=int(groups[-2]),
    )
|
|
182
|
+
|
|
183
|
+
@staticmethod
def _gen_numpy_file_info(name, math, dir_path):
    """Build a plain FileDesc for *name* in *dir_path*.

    The second argument is ignored. NOTE(review): ``math`` looks like a typo
    for ``match``. It is kept as-is so keyword callers keep working, and
    presumably mirrors ``_gen_npu_dump_convert_file_info``'s signature.
    """
    file_desc = FileDesc(name, dir_path)
    return file_desc
|
|
186
|
+
|
|
187
|
+
def execute_command(self, cmd):
    """Run *cmd* without a shell and return its exit code.

    The command string is split on runs of whitespace into an argv list. No
    shell quoting is interpreted, so individual arguments cannot contain spaces.

    Args:
        cmd: command line string.

    Returns:
        int: the child's return code, or -1 when *cmd* is None, empty or blank.
    """
    if not cmd or not cmd.strip():
        self.log.error("Command is empty")
        return -1
    self.log.debug("[RUN CMD]: %s", cmd)
    # str.split() collapses repeated whitespace; split(" ") would produce empty
    # argv entries that subprocess cannot execute.
    argv = cmd.split()
    complete_process = subprocess.run(argv, shell=False)
    return complete_process.returncode
|
|
195
|
+
|
|
196
|
+
def print_panel(self, content, title='', fit=True):
    """Print *content* inside a rich ``Panel``; fall back to plain printing when ``Panel`` is falsy.

    ``Panel`` is presumably None when rich could not be imported (TODO confirm
    at the module import site).

    Args:
        content: renderable to print.
        title: panel title.
        fit: True uses ``Panel.fit`` (non-expanding panel); False uses ``Panel``.
    """
    if Panel:
        make_panel = Panel.fit if fit else Panel
        self.print(make_panel(content, title=title))
    else:
        self.print(content)
|
|
204
|
+
|
|
205
|
+
def check_msaccucmp(self, target_file):
    """Verify *target_file* is a runnable msaccucmp script and return it unchanged.

    The base name must equal ``Const.MS_ACCU_CMP_FILE_NAME``, and running
    ``<self.python> <target_file> --help`` must exit with status 0.

    Raises:
        ParseException(PARSE_MSACCUCMP_ERROR): if either check fails.
    """
    base_name = os.path.split(target_file)[-1]
    if base_name != Const.MS_ACCU_CMP_FILE_NAME:
        self.log.error(
            "Check msaccucmp failed in dir %s. This is not a correct msaccucmp file" % target_file)
        raise ParseException(ParseException.PARSE_MSACCUCMP_ERROR)
    help_run = subprocess.run([self.python, target_file, "--help"], stdout=subprocess.PIPE)
    if help_run.returncode != 0:
        self.log.error("Check msaccucmp failed in dir %s" % target_file)
        self.log.error("Please specify a valid msaccucmp.py path or install the cann package")
        raise ParseException(ParseException.PARSE_MSACCUCMP_ERROR)
    self.log.info("Check [%s] success.", target_file)
    return target_file
|
|
219
|
+
|
|
220
|
+
def create_dir(self, path):
    """Create directory *path* (and parents) unless it already exists.

    The path is passed through ``self.path_strip`` and validated with
    ``self.check_path_name`` before creation. New directories get mode
    ``FileCheckConst.DATA_DIR_AUTHORITY``.

    Raises:
        ParseException(PARSE_INVALID_PATH_ERROR): if ``os.makedirs`` fails.
    """
    target = self.path_strip(path)
    if not os.path.exists(target):
        self.check_path_name(target)
        try:
            os.makedirs(target, mode=FileCheckConst.DATA_DIR_AUTHORITY)
        except OSError as err:
            self.log.error("Failed to create %s.", target)
            raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) from err
|
|
230
|
+
|
|
231
|
+
def gen_npy_info_txt(self, source_data):
    """Return a one-line ``[Shape] [Dtype] [Max] [Min] [Mean]`` description of *source_data*.

    The statistics come from ``self.npy_info``. %-formatting is kept on
    purpose: it uses ``str()`` on each value, whereas ``format()`` renders
    numpy float32 scalars with extra digits.
    """
    shape, dtype, max_data, min_data, mean = self.npy_info(source_data)
    template = '[Shape: %s] [Dtype: %s] [Max: %s] [Min: %s] [Mean: %s]'
    return template % (shape, dtype, max_data, min_data, mean)
|
|
236
|
+
|
|
237
|
+
def save_npy_to_txt(self, data, dst_file='', align=0):
    """Write *data* to *dst_file* as space-separated text (``%g``), one row per line.

    Nothing is written if *dst_file* already exists. With ``align == 0`` each
    row holds ``shape[-1]`` values (1 for a 0-d array). Otherwise each row holds
    *align* values, and the flattened data is zero-padded up to a multiple of
    *align*. The written file's mode is set to
    ``FileCheckConst.DATA_FILE_AUTHORITY``.
    """
    if os.path.exists(dst_file):
        self.log.info("Dst file %s exists, will not save new one.", dst_file)
        return
    shape = data.shape
    flat = data.flatten()
    if align == 0:
        align = shape[-1] if shape else 1
    else:
        remainder = flat.size % align
        if remainder:
            flat = np.append(flat, np.zeros((align - remainder,)))
    np.savetxt(dst_file, flat.reshape((-1, align)), delimiter=' ', fmt='%g')
    change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
250
|
+
|
|
251
|
+
def list_convert_files(self, path, external_pattern=""):
    """Map converted offline-dump file names under *path* to ``DumpDecodeFileDesc`` entries.

    Names must match ``Const.OFFLINE_DUMP_CONVERT_PATTERN``, and also
    *external_pattern* when it is non-empty.
    """
    build_info = self._gen_npu_dump_convert_file_info
    convert_pattern = Const.OFFLINE_DUMP_CONVERT_PATTERN
    return self.list_file_with_pattern(path, convert_pattern, external_pattern, build_info)
|
|
255
|
+
|
|
256
|
+
def list_numpy_files(self, path, extern_pattern=''):
    """Map numpy file names under *path* (matching ``Const.NUMPY_PATTERN``) to ``FileDesc`` entries.

    *extern_pattern*, when non-empty, is an additional regex the name must match.
    """
    numpy_pattern = Const.NUMPY_PATTERN
    build_info = self._gen_numpy_file_info
    return self.list_file_with_pattern(path, numpy_pattern, extern_pattern, build_info)
|
|
259
|
+
|
|
260
|
+
def create_columns(self, content):
    """Wrap *content* in a rich ``Columns`` renderable.

    Raises:
        ParseException(PARSE_NO_MODULE_ERROR): if ``Columns`` is unavailable (falsy).
    """
    if Columns:
        return Columns(content)
    self.log.error("No module named rich, please install it")
    raise ParseException(ParseException.PARSE_NO_MODULE_ERROR)
|
|
265
|
+
|
|
266
|
+
def create_table(self, title, columns):
    """Create a rich ``Table`` titled *title* with one folding column per name in *columns*.

    Raises:
        ParseException(PARSE_NO_MODULE_ERROR): if ``Table`` is unavailable (falsy).
    """
    if Table:
        table = Table(title=title)
        for header in columns:
            # overflow='fold' wraps long cell text instead of truncating it.
            table.add_column(header, overflow='fold')
        return table
    self.log.error("No module named rich, please install it and restart parse tool")
    raise ParseException(ParseException.PARSE_NO_MODULE_ERROR)
|
|
274
|
+
|
|
275
|
+
def check_path_valid(self, path):
    """Validate that *path* exists and is acceptable for the parse tool to read.

    Checks, in order:
      1. the stripped path is non-empty and exists;
      2. it is not a symlink;
      3. it passes ``check_path_name`` (length limits and allowed characters);
      4. for regular files, ``.pkl`` files are at most ``Const.ONE_GB`` and
         ``.npy`` files at most ``Const.TEN_GB``.

    Returns:
        True when all checks pass.

    Raises:
        ParseException(PARSE_INVALID_PATH_ERROR): on the first failed check.
    """
    path = self.path_strip(path)
    if not path or not os.path.exists(path):
        self.log.error("The path %s does not exist." % path)
        raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
    if os.path.islink(path):
        self.log.error('The file path {} is a soft link.'.format(path))
        raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
    # Length and special-character rules live in one place; create_dir uses the same helper.
    self.check_path_name(path)
    if os.path.isfile(path):
        file_size = os.path.getsize(path)
        if path.endswith(Const.PKL_SUFFIX) and file_size > Const.ONE_GB:
            self.log.error('The file {} size is greater than 1GB.'.format(path))
            raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
        if path.endswith(Const.NPY_SUFFIX) and file_size > Const.TEN_GB:
            self.log.error('The file {} size is greater than 10GB.'.format(path))
            raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
    return True
|
|
299
|
+
|
|
300
|
+
def check_files_in_path(self, path):
    """Raise if *path* is an existing directory with no entries; otherwise do nothing.

    Raises:
        ParseException(PARSE_INVALID_PATH_ERROR): for an empty directory.
    """
    is_empty_dir = os.path.isdir(path) and not os.listdir(path)
    if is_empty_dir:
        self.log.error("No files in %s." % path)
        raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
|
|
304
|
+
|
|
305
|
+
def npy_info(self, source_data):
    """Summarize a numpy array.

    Returns:
        namedtuple with fields ``shape, dtype, max, min, mean`` of *source_data*.

    Raises:
        ParseException(PARSE_INVALID_DATA_ERROR): if *source_data* is not an
            ndarray, has object dtype, or is empty.
    """
    if not isinstance(source_data, np.ndarray):
        self.log.error("Invalid data, data is not ndarray")
        raise ParseException(ParseException.PARSE_INVALID_DATA_ERROR)
    arr = source_data
    if arr.dtype == 'object':
        self.log.error("Invalid data, data is object.")
        raise ParseException(ParseException.PARSE_INVALID_DATA_ERROR)
    if not np.size(arr):
        self.log.error("Invalid data, data is empty")
        raise ParseException(ParseException.PARSE_INVALID_DATA_ERROR)
    summary_type = namedtuple('npu_info_result', ['shape', 'dtype', 'max', 'min', 'mean'])
    return summary_type(arr.shape, arr.dtype, arr.max(), arr.min(), arr.mean())
|
|
320
|
+
|
|
321
|
+
def list_file_with_pattern(self, path, pattern, extern_pattern, gen_info_func):
    """Recursively collect files under *path* whose names match *pattern*.

    Args:
        path: root directory; validated with ``self.check_path_valid`` first.
        pattern: regex applied with ``match`` (anchored at the start) to each file name.
        extern_pattern: optional extra regex the name must also match; '' or None disables it.
        gen_info_func: callable ``(name, match, dir_path)`` producing the value stored per file.

    Returns:
        dict mapping file name to ``gen_info_func`` result. Files with the same
        name in different sub-directories overwrite each other; the last one
        walked wins.

    Raises:
        re.error: if *pattern* or *extern_pattern* is not a valid regex (checked
            before walking).
    """
    self.check_path_valid(path)
    re_pattern = re.compile(pattern)
    # Compile the optional filter once instead of once per matching file.
    extern_re = re.compile(extern_pattern) if extern_pattern else None
    file_list = {}
    # NOTE(review): the root is rejected if it is a symlink (check_path_valid),
    # yet followlinks=True still descends into symlinked sub-directories.
    for dir_path, _, file_names in os.walk(path, followlinks=True):
        for name in file_names:
            match = re_pattern.match(name)
            if not match:
                continue
            if extern_re is not None and not extern_re.match(name):
                continue
            file_list[name] = gen_info_func(name, match, dir_path)
    return file_list
|
|
334
|
+
|
|
335
|
+
def check_path_format(self, path, suffix):
    """Ensure *path* is an existing regular file whose name ends with *suffix*.

    Raises:
        ParseException(PARSE_INVALID_FILE_FORMAT_ERROR): file with the wrong suffix.
        ParseException(PARSE_INVALID_PATH_ERROR): a directory, or a path that is neither file nor directory.
    """
    if os.path.isfile(path):
        if path.endswith(suffix):
            return
        self.log.error("%s is not a %s file." % (path, suffix))
        raise ParseException(ParseException.PARSE_INVALID_FILE_FORMAT_ERROR)
    if os.path.isdir(path):
        self.log.error("Please specify a single file path")
        raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
    self.log.error("The file path %s is invalid" % path)
    raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
|
|
346
|
+
|
|
347
|
+
def check_path_name(self, path):
    """Validate *path* against the tool's length and character rules.

    The resolved path must not exceed ``Const.DIRECTORY_LENGTH`` characters, the
    base name must not exceed ``Const.FILE_NAME_LENGTH``, and the resolved path
    must match ``Const.FILE_PATTERN``.

    Raises:
        ParseException(PARSE_INVALID_PATH_ERROR): on any violation.
    """
    resolved = os.path.realpath(path)
    too_long = (len(resolved) > Const.DIRECTORY_LENGTH
                or len(os.path.basename(path)) > Const.FILE_NAME_LENGTH)
    if too_long:
        self.log.error('The file path length exceeds limit.')
        raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
    if re.match(Const.FILE_PATTERN, resolved) is None:
        self.log.error('The file path {} contains special characters.'.format(path))
        raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
|
|
355
|
+
|
|
356
|
+
def check_str_param(self, param):
    """Validate a user-supplied string parameter.

    It must be at most ``Const.FILE_NAME_LENGTH`` characters and match ``Const.FILE_PATTERN``.

    Raises:
        ParseException(PARSE_INVALID_PARAM_ERROR): on any violation.
    """
    exceeds_limit = len(param) > Const.FILE_NAME_LENGTH
    if exceeds_limit:
        self.log.error('The parameter length exceeds limit')
        raise ParseException(ParseException.PARSE_INVALID_PARAM_ERROR)
    if re.match(Const.FILE_PATTERN, param) is None:
        self.log.error('The parameter {} contains special characters.'.format(param))
        raise ParseException(ParseException.PARSE_INVALID_PARAM_ERROR)
|
|
363
|
+
|
|
364
|
+
def is_subdir_count_equal(self, dir1, dir2):
    """Return True if ``self.get_subdir_count`` reports the same count for *dir1* and *dir2*."""
    return self.get_subdir_count(dir1) == self.get_subdir_count(dir2)
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
import json
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from msprobe.pytorch.parse_tool.lib.config import Const
|
|
21
|
+
from msprobe.pytorch.parse_tool.lib.utils import Util
|
|
22
|
+
from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
|
|
23
|
+
from msprobe.core.common.file_check import FileOpen
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Visualization:
    """Parse-tool helpers that print dumped ``.npy`` arrays and ``.pkl`` dump records."""

    def __init__(self):
        self.util = Util()

    def print_npy_summary(self, target_file):
        """Print a preview table and statistics for a ``.npy`` file, then save it as text.

        Up to 16 rows of 8 flattened values are shown next to the
        shape/dtype/max/min/mean summary. The whole array is then written to
        ``<target_file>.txt``; ``Util.save_npy_to_txt`` skips this when that file
        already exists.

        Raises:
            ParseException: PARSE_UNICODE_ERROR if decoding fails;
                PARSE_INVALID_DATA_ERROR if the file is not a loadable
                non-pickled array; plus whatever the ``Util`` helpers raise.
        """
        try:
            # Never unpickle user-supplied files: allow_pickle=True lets a crafted
            # .npy execute arbitrary code. Object arrays (the only ones needing
            # pickle) are rejected by Util.npy_info anyway.
            np_data = np.load(target_file, allow_pickle=False)
        except UnicodeError as e:  # must precede ValueError: UnicodeError subclasses it
            self.util.log.error("%s %s" % ("UnicodeError", str(e)))
            self.util.log.warning("Please check the npy file")
            raise ParseException(ParseException.PARSE_UNICODE_ERROR) from e
        except ValueError as e:
            self.util.log.error("%s %s" % ("ValueError", str(e)))
            self.util.log.warning("Please check the npy file")
            raise ParseException(ParseException.PARSE_INVALID_DATA_ERROR) from e
        table = self.util.create_table('', ['Index', 'Data'])
        flatten_data = np_data.flatten()
        tablesize = 8
        # Preview at most 16 rows of `tablesize` values each.
        for i in range(min(16, int(np.ceil(flatten_data.size / tablesize)))):
            last_idx = min(flatten_data.size, i * tablesize + tablesize)
            table.add_row(str(i * tablesize), ' '.join(flatten_data[i * tablesize: last_idx].astype('str').tolist()))
        summary = ['[yellow]%s[/yellow]' % self.util.gen_npy_info_txt(np_data), 'Path: %s' % target_file,
                   "TextFile: %s.txt" % target_file]
        self.util.print_panel(self.util.create_columns([table, "\n".join(summary)]), target_file)
        self.util.save_npy_to_txt(np_data, target_file + ".txt")

    def print_npy_data(self, file_name):
        """Validate *file_name* as an existing ``.npy`` file and print its summary."""
        file_name = self.util.path_strip(file_name)
        self.util.check_path_valid(file_name)
        self.util.check_path_format(file_name, Const.NPY_SUFFIX)
        return self.print_npy_summary(file_name)

    def parse_pkl(self, path, api_name):
        """Print stack traces and statistics for dump records whose name starts with *api_name*.

        Each non-blank line of the ``.pkl`` file is parsed as JSON.
        - Records that are not a non-empty list led by a string name are skipped.
        - A matching ``[name, frames]`` record whose name contains ``stack_info``
          is printed as a traceback in reverse frame order, but only when it has
          more than 4 frames.
        - A matching record with at least 6 fields, whose sixth field is a list
          of 3 or more values, is printed as a statistics line
          (name, dtype, shape, max, min, mean).

        Raises:
            ParseException(PARSE_JSONDECODE_ERROR): if a non-blank line is not valid JSON.
        """
        path = self.util.path_strip(path)
        self.util.check_path_valid(path)
        self.util.check_path_format(path, Const.PKL_SUFFIX)
        self.util.check_str_param(api_name)
        with FileOpen(path, "r") as pkl_handle:
            title_printed = False
            # iter(readline, '') stops at EOF; the with-block closes the handle.
            for pkl_line in iter(pkl_handle.readline, ''):
                if not pkl_line.strip():
                    # Blank lines (including '\r\n' or whitespace-only) hold no record.
                    continue
                try:
                    msg = json.loads(pkl_line)
                except json.JSONDecodeError as e:
                    self.util.log.error("%s %s in line %s" % ("JSONDecodeError", str(e), pkl_line))
                    self.util.log.warning("Please check the pkl file")
                    raise ParseException(ParseException.PARSE_JSONDECODE_ERROR) from e
                # Only lists led by a string name can be matched against api_name.
                if not isinstance(msg, list) or not msg or not isinstance(msg[0], str):
                    continue
                info_prefix = msg[0]
                if not info_prefix.startswith(api_name):
                    continue
                if "stack_info" in info_prefix and len(msg) == 2:
                    self.util.log.info("\nTrace back({}):".format(msg[0]))
                    if msg[1] and len(msg[1]) > 4:
                        for item in reversed(msg[1]):
                            self.util.log.info(" File \"{}\", line {}, in {}".format(item[0], item[1], item[2]))
                            self.util.log.info(" {}".format(item[3]))
                    continue
                if len(msg) > 5 and isinstance(msg[5], list) and len(msg[5]) >= 3:
                    summary_info = " [{}][dtype: {}][shape: {}][max: {}][min: {}][mean: {}]" \
                        .format(msg[0], msg[3], msg[4], msg[5][0], msg[5][1], msg[5][2])
                    if not title_printed:
                        self.util.log.info("\nStatistic Info:")
                        title_printed = True
                    self.util.log.info(summary_info)
|