mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- msprobe/README.md +46 -16
- msprobe/__init__.py +16 -1
- msprobe/config.json +0 -2
- msprobe/core/advisor/advisor.py +8 -8
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +64 -3
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +54 -9
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +21 -11
- msprobe/core/common/utils.py +153 -167
- msprobe/core/common_config.py +18 -25
- msprobe/core/compare/acc_compare.py +209 -36
- msprobe/core/compare/check.py +102 -17
- msprobe/core/compare/compare_cli.py +21 -1
- msprobe/core/compare/highlight.py +41 -5
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +21 -6
- msprobe/core/compare/utils.py +82 -48
- msprobe/core/data_dump/data_collector.py +31 -32
- msprobe/core/data_dump/data_processor/base.py +45 -22
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
- msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +32 -16
- msprobe/core/grad_probe/constant.py +4 -0
- msprobe/core/grad_probe/grad_compare.py +2 -3
- msprobe/core/grad_probe/utils.py +16 -3
- msprobe/docs/01.installation.md +19 -9
- msprobe/docs/02.config_introduction.md +52 -80
- msprobe/docs/03.config_examples.md +3 -13
- msprobe/docs/04.acl_config_examples.md +11 -9
- msprobe/docs/05.data_dump_PyTorch.md +140 -12
- msprobe/docs/06.data_dump_MindSpore.md +47 -5
- msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/docs/17.grad_probe.md +14 -16
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
- msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
- msprobe/mindspore/cell_processor.py +27 -3
- msprobe/mindspore/common/const.py +2 -0
- msprobe/mindspore/common/utils.py +18 -2
- msprobe/mindspore/compare/distributed_compare.py +9 -22
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +173 -35
- msprobe/mindspore/compare/ms_graph_compare.py +27 -11
- msprobe/mindspore/debugger/debugger_config.py +16 -13
- msprobe/mindspore/debugger/precision_debugger.py +37 -13
- msprobe/mindspore/dump/dump_tool_factory.py +16 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +41 -17
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
- msprobe/mindspore/free_benchmark/common/utils.py +19 -5
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
- msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
- msprobe/mindspore/grad_probe/global_context.py +18 -8
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/service.py +42 -123
- msprobe/pytorch/__init__.py +20 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +20 -5
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +26 -11
- msprobe/pytorch/common/utils.py +40 -35
- msprobe/pytorch/compare/distributed_compare.py +11 -11
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +38 -6
- msprobe/pytorch/debugger/debugger_config.py +52 -39
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/enums.py +28 -0
- msprobe/pytorch/free_benchmark/common/params.py +15 -0
- msprobe/pytorch/free_benchmark/common/utils.py +17 -1
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +10 -11
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +17 -2
- msprobe/pytorch/online_dispatch/compare.py +11 -12
- msprobe/pytorch/online_dispatch/single_compare.py +7 -7
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
- msprobe/pytorch/online_dispatch/utils.py +1 -4
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +9 -10
- msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
- msprobe/pytorch/parse_tool/lib/utils.py +28 -24
- msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
- msprobe/pytorch/pt_config.py +167 -38
- msprobe/pytorch/service.py +97 -32
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
msprobe/core/compare/multiprocessing_compute.py
CHANGED

```diff
@@ -1,9 +1,22 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import multiprocessing
 from dataclasses import dataclass
-from functools import partial
-import numpy as np
 import pandas as pd
+from tqdm import tqdm
 from msprobe.core.common.log import logger
 from msprobe.core.common.utils import CompareException
 from msprobe.core.common.const import CompareConst
@@ -29,11 +42,19 @@ def _handle_multi_process(func, input_parma, result_df, lock):
         except OSError as e:
             logger.error("pool terminate failed")
 
+    progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
+
+    def update_progress(size, progress_lock):
+        with progress_lock:
+            progress_bar.update(size)
+
     for process_idx, df_chunk in enumerate(df_chunks):
         idx = df_chunk_size * process_idx
+        chunk_size = len(df_chunk)
         result = pool.apply_async(func,
                                   args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma),
-                                  error_callback=err_call
+                                  error_callback=err_call,
+                                  callback=update_progress(chunk_size, lock))
         results.append(result)
     final_results = [r.get() for r in results]
     pool.close()
@@ -42,7 +63,7 @@ def _handle_multi_process(func, input_parma, result_df, lock):
 
 
 def _ms_graph_handle_multi_process(func, result_df, mode):
-    process_num = int((multiprocessing.cpu_count() + 1) //
+    process_num = int((multiprocessing.cpu_count() + 1) // 4)
     df_chunk_size = len(result_df) // process_num
     if df_chunk_size > 0:
         df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
@@ -84,7 +105,8 @@ def read_dump_data(result_df):
     except IndexError as e:
         logger.error('result dataframe elements can not be access.')
         raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
-
+
+
 @dataclass
 class ComparisonResult:
     cos_result: list
@@ -116,9 +138,12 @@ def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
             result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
             result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
             result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
-            result_df.loc[process_index, CompareConst.ACCURACY] =
-
-            result_df.loc[process_index, CompareConst.
+            result_df.loc[process_index, CompareConst.ACCURACY] = (
+                check_accuracy(result.cos_result[i], result.max_err_result[i]))
+            result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = (
+                result.one_thousand_err_ratio_result)[i]
+            result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = (
+                result.five_thousand_err_ratio_result)[i]
         return result_df
     except ValueError as e:
         logger.error('result dataframe is not found.')
```
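The `_handle_multi_process` hunk above threads a shared `tqdm` progress bar through `multiprocessing.Pool.apply_async`. Below is a minimal, self-contained sketch of that progress-reporting pattern; the worker `square_chunk`, the chunk sizes, and the callback wiring via `functools.partial` are illustrative assumptions, not the package's own code.

```python
import multiprocessing
from functools import partial

from tqdm import tqdm


def square_chunk(chunk):
    # Stand-in for the real per-chunk compare routine.
    return [x * x for x in chunk]


def run_with_progress(data, chunk_size=4, process_num=2):
    chunks = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]
    progress_bar = tqdm(total=len(data), desc="Compare Process", unit="row", ncols=100)

    def update_progress(_result, size):
        # Runs in the parent process whenever a chunk finishes.
        progress_bar.update(size)

    results = []
    with multiprocessing.Pool(process_num) as pool:
        for chunk in chunks:
            results.append(pool.apply_async(square_chunk, args=(chunk,),
                                            callback=partial(update_progress, size=len(chunk))))
        final = [r.get() for r in results]
    progress_bar.close()
    return [y for part in final for y in part]


if __name__ == "__main__":
    print(run_with_progress(list(range(10))))
```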
msprobe/core/compare/npy_compare.py
CHANGED

```diff
@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import abc
 import numpy as np
 from msprobe.core.common.utils import format_value
@@ -78,10 +93,8 @@ def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None
 
 def npy_data_check(n_value, b_value):
     error_message = ""
-    if n_value
-        error_message += "Dump file not
-    if n_value == "" or b_value == "":
-        error_message += "Dump file not found.\n"
+    if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray):
+        error_message += "Dump file is not ndarray.\n"
 
     # check whether n_value and b_value are empty
     if not error_message and (n_value.size == 0 or b_value.size == 0):
@@ -97,7 +110,8 @@ def npy_data_check(n_value, b_value):
 
     if not error_message:
         n_value, b_value = handle_inf_nan(n_value, b_value)  # check for nan/inf data
-
+        # handle_inf_nan returns 'Nan' or an ndarray; use the type to detect nan/inf data that cannot be handled
+        if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray):
             error_message += "The position of inf or nan in NPU and bench Tensor do not match.\n"
     if error_message == "":
         error_flag = False
@@ -273,7 +287,8 @@ class GetFiveThousandErrRatio(TensorComparisonBasic):
         relative_err = get_relative_err(n_value, b_value)
         if not np.size(relative_err):
             return CompareConst.NAN, ""
-        return format_value(
+        return format_value(
+            np.sum(relative_err < CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD) / np.size(relative_err)), ""
 
 
 class CompareOps:
```
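The `GetFiveThousandErrRatio` hunk completes a ratio computation: the share of elements whose relative error stays below a threshold. A small NumPy sketch of that metric follows; the 0.005 threshold mirrors the `FIVE_THOUSAND_RATIO_THRESHOLD` name and the relative-error definition is an assumption, since `get_relative_err` is not shown in this diff.

```python
import numpy as np


def err_ratio(npu_values, bench_values, threshold=0.005, eps=1e-10):
    """Fraction of elements whose relative error is below `threshold`."""
    npu_values = np.asarray(npu_values, dtype=np.float64)
    bench_values = np.asarray(bench_values, dtype=np.float64)
    relative_err = np.abs(npu_values - bench_values) / (np.abs(bench_values) + eps)
    if relative_err.size == 0:
        return float("nan")
    return float(np.sum(relative_err < threshold) / relative_err.size)


npu = np.array([1.000, 2.003, 3.10, 4.0])
bench = np.array([1.001, 2.000, 3.00, 4.0])
print(err_ratio(npu, bench))  # 0.75: three of the four elements are within 0.5%
```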
msprobe/core/compare/utils.py
CHANGED
```diff
@@ -1,3 +1,17 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import os
 import re
@@ -59,14 +73,18 @@ def check_and_return_dir_contents(dump_dir, prefix):
 
 def rename_api(npu_name, process):
     npu_split = npu_name.split(process)
-
+    try:
+        torch_func_index, in_out = npu_split[0], npu_split[1]
+    except IndexError as error:
+        logger.error(f'{npu_name} can not be split with {process}, please check!')
+        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
     torch_func_split = torch_func_index.rsplit(Const.SEP, 2)
     torch_func = str(torch_func_split[0]) + str(in_out)
     return torch_func
 
 
 def read_op(op_data, op_name):
-    op_parsed_list =
+    op_parsed_list = []
     if Const.FORWARD in op_name:
         if Const.INPUT_ARGS in op_data:
             input_item = op_data[Const.INPUT_ARGS]
@@ -103,16 +121,23 @@ def read_op(op_data, op_name):
     return op_parsed_list
 
 
-def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
+def op_item_parse(item, op_name, index, item_list=None, top_bool=True, depth=0):
+    if depth > Const.MAX_DEPTH:
+        logger.error(f"parse of api/module of {op_name} exceeds the recursion limit.")
+        raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
     if item_list is None:
         item_list = []
     if item is None or (isinstance(item, dict) and not item):
         if not top_bool:
-            tmp = {
-
+            tmp = {
+                'full_op_name': op_name + '.' + str(index), 'Max': None, 'Min': None, 'Mean': None, 'Norm': None,
+                'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'
+            }
         else:
-            tmp = {
-
+            tmp = {
+                'full_op_name': op_name + '.0', 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None,
+                'shape': None, 'md5': None, 'data_name': '-1'
+            }
         item_list.append(tmp)
         return item_list
     if index is None:
@@ -125,7 +150,7 @@ def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
     if isinstance(item, dict):
         if 'type' not in item:
             for kwarg in item:
-                kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None)
+                kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None, depth=depth+1)
                 item_list += kwarg_parsed_list
                 kwarg_parsed_list.clear()
         elif 'dtype' in item:
@@ -171,7 +196,7 @@ def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
             resolve_api_special_parameters(item, full_op_name, item_list)
         else:
             for j, item_spec in enumerate(item):
-                op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False)
+                op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False, depth=depth+1)
     return item_list
 
 
@@ -226,9 +251,10 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals
             b_struct = b_dict[key][index]
             err_msg = ""
             if md5_compare:
-                result_item = [
-
-
+                result_item = [
+                    n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], n_struct[2], b_struct[2],
+                    CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF
+                ]
                 if has_stack and index == 0 and key == "input_struct":
                     result_item.extend(npu_stack_info)
                 else:
@@ -237,15 +263,19 @@
                 continue
 
             if summary_compare:
-                result_item = [
-
+                result_item = [
+                    n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
+                    " ", " ", " ", " ", " ", " ", " ", " "
+                ]
             else:
-                result_item = [
-
+                result_item = [
+                    n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
+                    " ", " ", " ", " ", " "
+                ]
 
-            npu_summary_data = n_dict.get(
+            npu_summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
             result_item.extend(npu_summary_data)
-            bench_summary_data = b_dict.get(
+            bench_summary_data = b_dict.get(CompareConst.SUMMARY)[b_start + index]
             result_item.extend(bench_summary_data)
 
             if summary_compare:
@@ -257,7 +287,7 @@
                     if bench_val != 0:
                         relative = str(abs((diff / bench_val) * 100)) + '%'
                     else:
-                        relative =
+                        relative = CompareConst.N_A
                     result_item[start_idx + i] = diff
                     result_item[start_idx + i + 4] = relative
                     magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
@@ -287,15 +317,19 @@
             n_name = n_dict['op_name'][n_start + index]
             n_struct = n_dict[key][index]
             if md5_compare:
-                result_item = [
-
+                result_item = [
+                    n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
+                    n_struct[2], CompareConst.NAN, CompareConst.NAN
+                ]
                 result.append(result_item)
                 continue
-            result_item = [
-
-
+            result_item = [
+                n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
+                " ", " ", " ", " ", " "
+            ]
+            summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
             result_item.extend(summary_data)
-            summary_data = [CompareConst.NAN for _ in range(len(n_dict.get(
+            summary_data = [CompareConst.NAN for _ in range(len(n_dict.get(CompareConst.SUMMARY)[0]))]
             result_item.extend(summary_data)
 
             err_msg = ""
@@ -313,15 +347,12 @@
 
     n_num = len(n_dict['op_name'])
     b_num = len(b_dict['op_name'])
-    n_num_input = len([name for name in n_dict['op_name'] if Const.INPUT in name])
-    b_num_input = len([name for name in b_dict['op_name'] if Const.INPUT in name])
-
-
-    n_num_output = n_num - n_num_input - n_num_kwarg
-    b_num_output = b_num - b_num_input - b_num_kwarg
+    n_num_input = len([name for name in n_dict['op_name'] if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
+    b_num_input = len([name for name in b_dict['op_name'] if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
+    n_num_output = n_num - n_num_input
+    b_num_output = b_num - b_num_input
     get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
-    get_accuracy_core(n_num_input,
-    get_accuracy_core(n_num_input + n_num_kwarg, n_num_output, b_num_input + b_num_kwarg, b_num_output, 'output_struct')
+    get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, 'output_struct')
 
 
 def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
@@ -331,7 +362,8 @@ def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
     err_msg = CompareConst.NO_BENCH
     accuracy_check_res = CompareConst.N_A
     for index, n_name in enumerate(n_dict["op_name"]):
-
+        name_ele_list = n_name.split(Const.SEP)
+        if "input" in name_ele_list:
             n_struct = n_dict["input_struct"][index]
         else:
             n_struct = n_dict["output_struct"][index_out]
@@ -383,25 +415,28 @@ def merge_tensor(tensor_list, summary_compare, md5_compare):
                 op_dict['stack_info'].append(tensor['full_info'])
                 break
         op_dict["op_name"].append(tensor['full_op_name'])
+        name_ele_list = tensor['full_op_name'].split(Const.SEP)
         if not md5_compare:
-            if
+            if "input" in name_ele_list:
                 op_dict["input_struct"].append((tensor['dtype'], tensor['shape']))
-            elif
+            elif "kwarg" in name_ele_list:
                 op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape']))
-            elif
+            elif "output" in name_ele_list:
                 op_dict["output_struct"].append((tensor['dtype'], tensor['shape']))
         else:
-            if
+            if "input" in name_ele_list:
                 op_dict["input_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
-
+            if "kwarg" in name_ele_list:
                 op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
-            elif
+            elif "output" in name_ele_list:
                 op_dict["output_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
-
         op_dict["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']])
 
         if all_mode_bool:
             op_dict["data_name"].append(tensor['data_name'])
+            data_name = op_dict["data_name"][-1].rsplit(Const.SEP, 1)[0]
+            if data_name != "-1":
+                op_dict["op_name"][-1] = data_name
 
     if not op_dict["kwargs_struct"]:
         del op_dict["kwargs_struct"]
@@ -410,7 +445,7 @@ def merge_tensor(tensor_list, summary_compare, md5_compare):
 
 def _compare_parser(parser):
     parser.add_argument("-i", "--input_path", dest="input_path", type=str,
-                        help="<Required> The compare input path, a dict json.",
+                        help="<Required> The compare input path, a dict json.", required=True)
     parser.add_argument("-o", "--output_path", dest="output_path", type=str,
                         help="<Required> The compare task result out path.", required=True)
     parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
@@ -422,9 +457,8 @@ def _compare_parser(parser):
     parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True,
                         help="<optional> The cell mapping file path.", required=False)
     parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True,
-                        help="<optional> The api mapping file path.", required=False)
-
-
-
-
-
+                        help="<optional> The api mapping file path.", required=False)
+    parser.add_argument("-dm", "--data_mapping", dest="data_mapping", type=str,
+                        help="<optional> The data mapping file path.", required=False)
+    parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str,
+                        help="<optional> The layer mapping file path.", required=False)
```
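Several hunks in `utils.py` add a `depth` parameter so that parsing of nested dump entries fails fast instead of exhausting Python's recursion limit. A minimal sketch of that guard over a nested dump-like structure follows; the `MAX_DEPTH` value, the `RecursionLimitError` class, and the `flatten_items` helper are illustrative and not part of the package's API.

```python
MAX_DEPTH = 10  # illustrative limit; the package keeps its own constant


class RecursionLimitError(Exception):
    pass


def flatten_items(item, name, out=None, depth=0):
    """Collect leaf entries of a nested dump structure into a flat list."""
    if depth > MAX_DEPTH:
        raise RecursionLimitError(f"parsing of {name} exceeds the recursion limit")
    if out is None:
        out = []
    if isinstance(item, dict):
        if "dtype" in item:          # a leaf tensor description
            out.append({name: item})
        else:                        # keyword arguments: recurse per key
            for key, value in item.items():
                flatten_items(value, f"{name}.{key}", out, depth + 1)
    elif isinstance(item, (list, tuple)):
        for idx, sub in enumerate(item):
            flatten_items(sub, f"{name}.{idx}", out, depth + 1)
    return out


example = {"input_args": [{"dtype": "float32"}, [{"dtype": "int64"}]]}
print(flatten_items(example, "Functional.linear.0.forward"))
```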
msprobe/core/data_dump/data_collector.py
CHANGED

```diff
@@ -1,9 +1,24 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 
 from msprobe.core.data_dump.scope import build_scope, ListScope
 from msprobe.core.data_dump.json_writer import DataWriter
 from msprobe.core.common.log import logger
-from msprobe.core.common.const import Const
+from msprobe.core.common.const import Const
 from msprobe.core.data_dump.data_processor.factory import DataProcessorFactory
 
 
@@ -14,14 +29,13 @@ def build_data_collector(config):
 class DataCollector:
     multi_output_apis = ["_sort_", "npu_flash_attention"]
     tasks_need_tensor_data = [Const.OVERFLOW_CHECK, Const.TENSOR, Const.FREE_BENCHMARK]
-    level_without_construct = [
+    level_without_construct = [Const.LEVEL_L1, Const.LEVEL_L2]
 
     def __init__(self, config):
        self.config = config
        self.data_writer = DataWriter()
        self.data_processor = DataProcessorFactory.create_processor(self.config, self.data_writer)
-        self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework)
-            if self.config.framework == Const.PT_FRAMEWORK else None
+        self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework)
        self.module_count = {}
        if self.config.task == Const.FREE_BENCHMARK:
            self.scope = build_scope(ListScope, self.config.scope, self.config.list)
@@ -59,16 +73,16 @@ class DataCollector:
     def write_json(self):
         self.data_writer.write_json()
 
-    def update_data(self,
+    def update_data(self, name, data_info):
+        msg = f"msprobe is collecting data on {name}."
         if self.config.task == Const.OVERFLOW_CHECK:
             if self.data_processor.has_overflow:
+                msg += " Overflow detected."
+                logger.warning(msg)
             self.data_writer.update_data(data_info)
-
-
-
-        else:
-            self.data_writer.update_data(data_info)
-        return msg
+            return
+        logger.debug(msg)
+        self.data_writer.update_data(data_info)
 
 
     def pre_forward_data_collect(self, name, module, pid, module_input_output):
@@ -78,7 +92,7 @@
             return
         logger.info(f"API {name} is inplace.")
         data_info = self.data_processor.analyze_pre_forward_inplace(name, module_input_output)
-        self.handle_data(name, data_info)
+        self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
 
     def forward_data_collect(self, name, module, pid, module_input_output):
         self.update_construct(name)
@@ -92,13 +106,7 @@
         if self.config.level == "L2":
             return
         self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
-
-            self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
-        else:
-            if self.data_processor.is_terminated:
-                self.handle_data(name, data_info, flush=True)
-                raise Exception(f"[{Const.TOOL_NAME}] exit")
-            self.handle_data(name, data_info)
+        self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
 
     def backward_data_collect(self, name, module, pid, module_input_output):
         self.update_construct(name)
@@ -106,13 +114,7 @@
             return
 
         data_info = self.data_processor.analyze_backward(name, module, module_input_output)
-
-            self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
-        else:
-            if self.data_processor.is_terminated:
-                self.handle_data(name, data_info, flush=True)
-                raise Exception(f"[{Const.TOOL_NAME}] exit")
-            self.handle_data(name, data_info)
+        self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
 
     def backward_input_data_collect(self, name, module, pid, module_input_output):
         self.update_construct(name)
@@ -131,18 +133,15 @@
             self.handle_data(name, data_info)
 
     def update_construct(self, name):
-        if self.config.
-            self.config.level not in DataCollector.level_without_construct:
+        if self.config.level not in DataCollector.level_without_construct:
             self.data_writer.update_construct({name: self.module_processor.api_parent_node})
             self.data_writer.update_construct(self.module_processor.module_node)
 
     def handle_data(self, name, data_info, flush=False):
         if data_info:
-
-            msg = self.update_data(data_info, msg)
-            logger.info(MsgConst.CLEAR_SYMBOL + msg, end='\r')
+            self.update_data(name, data_info)
         if not flush:
-            self.data_writer.
+            self.data_writer.flush_data_periodically()
         else:
             self.write_json()
 
```
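The `DataCollector` hunks collapse the per-task branches into a single `handle_data(name, data_info, flush=...)` path: write the JSON immediately when the processor signals termination, otherwise let the writer flush on its own cadence. A rough sketch of that flush policy follows; the `PeriodicWriter` class, its interval, and the file name are illustrative stand-ins, not the package's own classes.

```python
import json
import time


class PeriodicWriter:
    def __init__(self, path, interval_s=5.0):
        self.path = path
        self.interval_s = interval_s
        self.cache = {}
        self._last_flush = time.monotonic()

    def update_data(self, data_info):
        self.cache.update(data_info)

    def flush_data_periodically(self):
        # Cheap call on every update; only touches disk when the interval has elapsed.
        if time.monotonic() - self._last_flush >= self.interval_s:
            self.write_json()

    def write_json(self):
        with open(self.path, "w") as f:
            json.dump(self.cache, f)
        self._last_flush = time.monotonic()


def handle_data(writer, data_info, flush=False):
    if data_info:
        writer.update_data(data_info)
    if flush:
        writer.write_json()               # e.g. the processor reported it is terminating
    else:
        writer.flush_data_periodically()  # normal path: periodic flush


writer = PeriodicWriter("dump.json", interval_s=0.0)
handle_data(writer, {"Tensor.add.0.forward": {"Max": 1.0}})
handle_data(writer, {"Tensor.mul.0.forward": {"Max": 2.0}}, flush=True)
```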
msprobe/core/data_dump/data_processor/base.py
CHANGED

```diff
@@ -1,11 +1,27 @@
-
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import inspect
+import os
 from dataclasses import dataclass
 from typing import Tuple, Dict, Optional, Any
+
 import numpy as np
-from msprobe.core.common.log import logger
-from msprobe.core.common.utils import convert_tuple
 from msprobe.core.common.const import Const
+from msprobe.core.common.log import logger
+from msprobe.core.common.utils import convert_tuple, CompareException
 
 
 @dataclass
@@ -69,8 +85,11 @@ class TensorStatInfo:
 
 class BaseDataProcessor:
     _recursive_key_stack = []
-    special_type = (
-
+    special_type = (
+        np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
+        bool, int, float, str, slice,
+        type(Ellipsis)
+    )
 
     def __init__(self, config, data_writer):
         self.data_writer = data_writer
@@ -86,26 +105,27 @@ class BaseDataProcessor:
     @property
     def data_path(self):
         return self.data_writer.dump_tensor_data_dir
-
+
     @property
     def is_terminated(self):
         return False
 
     @staticmethod
     def analyze_api_call_stack(name):
+        try:
+            api_stack = inspect.stack()[5:]
+        except Exception as e:
+            logger.warning(f"The call stack of <{name}> failed to retrieve, {e}.")
+            api_stack = None
         stack_str = []
-
-
-
-
-                "File
-
-
-
-                " ".join(["\n", code[0].strip()])
-            ])
-        ])
-            stack_str.append(stack_line)
+        if api_stack:
+            for (_, path, line, func, code, _) in api_stack:
+                if not code:
+                    continue
+                stack_line = f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()}"
+                stack_str.append(stack_line)
+        else:
+            stack_str.append(Const.WITHOUT_CALL_STACK)
         stack_info_struct = {name: stack_str}
         return stack_info_struct
 
@@ -167,7 +187,10 @@ class BaseDataProcessor:
         return cls.special_type
 
     @classmethod
-    def recursive_apply_transform(cls, args, transform):
+    def recursive_apply_transform(cls, args, transform, depth=0):
+        if depth > Const.MAX_DEPTH:
+            logger.error(f"The maximum depth of recursive transform, {Const.MAX_DEPTH} is reached.")
+            raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
         if isinstance(args, cls.get_special_types()):
             arg_transform = transform(args, cls._recursive_key_stack)
             return arg_transform
@@ -175,14 +198,14 @@
             result_list = []
             for i, arg in enumerate(args):
                 cls._recursive_key_stack.append(str(i))
-                result_list.append(cls.recursive_apply_transform(arg, transform))
+                result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
                 cls._recursive_key_stack.pop()
             return type(args)(result_list)
         elif isinstance(args, dict):
             result_dict = {}
             for k, arg in args.items():
                 cls._recursive_key_stack.append(str(k))
-                result_dict[k] = cls.recursive_apply_transform(arg, transform)
+                result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
                 cls._recursive_key_stack.pop()
             return result_dict
         elif args is not None:
@@ -222,7 +245,7 @@ class BaseDataProcessor:
 
     def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
         pass
-
+
     def analyze_element(self, element):
         return self.recursive_apply_transform(element, self.analyze_single_element)
 
```
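The `BaseDataProcessor` hunks wrap the `inspect.stack()` lookup in a try/except and format each outer frame into a readable line. A standalone sketch of that capture-and-format step follows; the default `skip_frames=5` mirrors the hunk's `[5:]` slice, while the function name, the demo call chain, and the fallback text are illustrative assumptions.

```python
import inspect


def capture_call_stack(name, skip_frames=5):
    """Return {name: [formatted outer frames]} or a placeholder when unavailable."""
    try:
        api_stack = inspect.stack()[skip_frames:]
    except Exception as e:  # stack retrieval can occasionally fail
        print(f"The call stack of <{name}> failed to retrieve, {e}.")
        api_stack = None

    stack_str = []
    if api_stack:
        # FrameInfo unpacks as (frame, filename, lineno, function, code_context, index)
        for (_, path, line, func, code, _) in api_stack:
            if not code:
                continue
            stack_str.append(f"File {path}, line {line}, in {func}, \n {code[0].strip()}")
    else:
        stack_str.append("The call stack retrieval failed.")
    return {name: stack_str}


def level3(): return capture_call_stack("Tensor.add.0.forward", skip_frames=1)
def level2(): return level3()
def level1(): return level2()

print(level1())
```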