mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/core/compare/utils.py
CHANGED
|
@@ -20,33 +20,45 @@ import zlib
|
|
|
20
20
|
from dataclasses import dataclass
|
|
21
21
|
|
|
22
22
|
import numpy as np
|
|
23
|
+
import pandas as pd
|
|
23
24
|
|
|
24
25
|
from msprobe.core.common.const import Const, CompareConst, FileCheckConst
|
|
25
26
|
from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value
|
|
26
27
|
from msprobe.core.common.file_utils import check_file_or_directory_path
|
|
27
28
|
|
|
29
|
+
json_file_mapping = {
|
|
30
|
+
Const.DUMP_JSON_FILE: "dump.json",
|
|
31
|
+
Const.DEBUG_JSON_FILE: "debug.json",
|
|
32
|
+
Const.STACK_JSON_FILE: "stack.json"
|
|
33
|
+
}
|
|
28
34
|
|
|
29
|
-
|
|
35
|
+
|
|
36
|
+
def extract_json(dirname, json_file_type):
|
|
30
37
|
json_path = ''
|
|
31
38
|
for filename in os.listdir(dirname):
|
|
32
|
-
target_file_name =
|
|
39
|
+
target_file_name = json_file_mapping.get(json_file_type)
|
|
40
|
+
if target_file_name is None:
|
|
41
|
+
logger.error(f'extract_json failed, invalid json_file_type: {json_file_type}.')
|
|
42
|
+
raise CompareException(CompareException.INVALID_KEY_ERROR)
|
|
33
43
|
if filename == target_file_name:
|
|
34
44
|
json_path = os.path.join(dirname, filename)
|
|
35
45
|
break
|
|
36
46
|
|
|
37
47
|
# Provide robustness on invalid directory inputs
|
|
38
48
|
if not json_path:
|
|
39
|
-
if
|
|
49
|
+
if json_file_type == Const.STACK_JSON_FILE:
|
|
40
50
|
logger.warning(f'stack.json is not found in dump dir {dirname}.')
|
|
41
|
-
|
|
51
|
+
elif json_file_type == Const.DUMP_JSON_FILE:
|
|
42
52
|
logger.error(f'dump.json is not found in dump dir {dirname}.')
|
|
43
|
-
|
|
53
|
+
elif json_file_type == Const.DEBUG_JSON_FILE:
|
|
54
|
+
logger.warning(f'debug.json is not found in dump dir {dirname}.')
|
|
55
|
+
|
|
44
56
|
return json_path
|
|
45
57
|
|
|
46
58
|
|
|
47
59
|
def set_stack_json_path(input_param):
|
|
48
60
|
npu_data_dir = os.path.dirname(input_param.get("npu_json_path"))
|
|
49
|
-
stack_path = extract_json(npu_data_dir,
|
|
61
|
+
stack_path = extract_json(npu_data_dir, json_file_type=Const.STACK_JSON_FILE)
|
|
50
62
|
input_param["stack_json_path"] = stack_path if stack_path else None
|
|
51
63
|
return bool(stack_path)
|
|
52
64
|
|
|
@@ -81,24 +93,9 @@ def check_and_return_dir_contents(dump_dir, prefix):
|
|
|
81
93
|
return contents
|
|
82
94
|
|
|
83
95
|
|
|
84
|
-
def rename_api(npu_name, process):
|
|
85
|
-
"""
|
|
86
|
-
原api: {api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号}
|
|
87
|
-
rename后: {api_type}.{api_name}.{input/output}.{参数序号}
|
|
88
|
-
"""
|
|
89
|
-
npu_split = npu_name.split(process)
|
|
90
|
-
try:
|
|
91
|
-
torch_func_index, in_out = npu_split[0], npu_split[1]
|
|
92
|
-
except IndexError as error:
|
|
93
|
-
logger.error(f'{npu_name} can not be split with {process}, please check!')
|
|
94
|
-
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
95
|
-
torch_func_split = torch_func_index.rsplit(Const.SEP, 2)
|
|
96
|
-
torch_func = str(torch_func_split[0]) + str(in_out)
|
|
97
|
-
return torch_func
|
|
98
|
-
|
|
99
|
-
|
|
100
96
|
def read_op(op_data, op_name):
|
|
101
|
-
|
|
97
|
+
split_name = op_name.split(Const.SEP)
|
|
98
|
+
if Const.DEBUG in split_name or Const.PARAMS_GRAD in split_name:
|
|
102
99
|
op_parsed_list = op_item_parse(op_data, op_name)
|
|
103
100
|
else:
|
|
104
101
|
op_parsed_list = []
|
|
@@ -191,35 +188,152 @@ def gen_op_item(op_data, op_name):
|
|
|
191
188
|
return op_item
|
|
192
189
|
|
|
193
190
|
|
|
194
|
-
|
|
191
|
+
@dataclass
|
|
192
|
+
class ApiItemInfo:
|
|
193
|
+
name: str
|
|
194
|
+
struct: tuple
|
|
195
|
+
stack_info: list
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def merge_tensor(tensor_list, dump_mode):
|
|
199
|
+
keys = [
|
|
200
|
+
CompareConst.OP_NAME,
|
|
201
|
+
CompareConst.INPUT_STRUCT,
|
|
202
|
+
CompareConst.KWARGS_STRUCT,
|
|
203
|
+
CompareConst.OUTPUT_STRUCT,
|
|
204
|
+
CompareConst.PARAMS_STRUCT,
|
|
205
|
+
CompareConst.PARAMS_GRAD_STRUCT,
|
|
206
|
+
CompareConst.DEBUG_STRUCT,
|
|
207
|
+
Const.SUMMARY,
|
|
208
|
+
Const.STACK_INFO
|
|
209
|
+
]
|
|
210
|
+
op_dict = {key: [] for key in keys}
|
|
211
|
+
|
|
212
|
+
if dump_mode == Const.ALL:
|
|
213
|
+
op_dict["data_name"] = []
|
|
214
|
+
|
|
215
|
+
for tensor in tensor_list:
|
|
216
|
+
# A dict(len=2) with 'full_op_name' and 'full_info' is added to the tensor only if self.stack_mode is True
|
|
217
|
+
if len(tensor) == 2:
|
|
218
|
+
op_dict[Const.STACK_INFO].append(tensor['full_info'])
|
|
219
|
+
break
|
|
220
|
+
|
|
221
|
+
op_dict[CompareConst.OP_NAME].append(tensor['full_op_name'])
|
|
222
|
+
|
|
223
|
+
_, state = get_name_and_state(tensor['full_op_name'])
|
|
224
|
+
struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
|
|
225
|
+
if not struct_key:
|
|
226
|
+
continue
|
|
227
|
+
if dump_mode == Const.MD5:
|
|
228
|
+
op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
|
|
229
|
+
else:
|
|
230
|
+
op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
|
|
231
|
+
op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]])
|
|
232
|
+
|
|
233
|
+
if dump_mode == Const.ALL:
|
|
234
|
+
op_dict["data_name"].append(tensor['data_name'])
|
|
235
|
+
|
|
236
|
+
if not op_dict[CompareConst.KWARGS_STRUCT]:
|
|
237
|
+
del op_dict[CompareConst.KWARGS_STRUCT]
|
|
238
|
+
return op_dict if op_dict[CompareConst.OP_NAME] else {}
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def print_compare_ends_info():
|
|
242
|
+
total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
|
|
243
|
+
logger.info('*' * total_len)
|
|
244
|
+
logger.info(f"*{CompareConst.COMPARE_ENDS_SUCCESSFULLY.center(total_len - 2)}*")
|
|
245
|
+
logger.info('*' * total_len)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def table_value_is_valid(value: str) -> bool:
|
|
249
|
+
if not isinstance(value, str):
|
|
250
|
+
return True
|
|
251
|
+
try:
|
|
252
|
+
# -1.00 or +1.00 should be considered as digit numbers
|
|
253
|
+
float(value)
|
|
254
|
+
except ValueError:
|
|
255
|
+
# otherwise, they will be considered as formular injections
|
|
256
|
+
return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
|
|
257
|
+
return True
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def get_name_and_state(name):
|
|
195
261
|
"""
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
}
|
|
209
|
-
}
|
|
210
|
-
Parameter:
|
|
211
|
-
data_dict: 字典格式的数据
|
|
212
|
-
full_op_name: 参数的全名字符串
|
|
213
|
-
item_list: 参数信息集合
|
|
262
|
+
Get api/module name and state
|
|
263
|
+
example:
|
|
264
|
+
name = 'conv2d.forward.1.input.0'
|
|
265
|
+
return: ('conv2d.forward.1.', 'input')
|
|
266
|
+
|
|
267
|
+
name = 'Functional.pad.0.backward.output.0'
|
|
268
|
+
return: ('Functional.pad.0.backward.', 'output')
|
|
269
|
+
|
|
270
|
+
name = 'x_tensor.0.debug.{index}'
|
|
271
|
+
return: ('x_tensor.0.', 'debug')
|
|
272
|
+
|
|
273
|
+
state type: input, output, kwargs, parameters, parameters_grad, debug
|
|
214
274
|
"""
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
275
|
+
if not isinstance(name, str):
|
|
276
|
+
logger.error(f'Invalid name: {name}, type should be string, please check.')
|
|
277
|
+
raise CompareException(CompareException.INVALID_API_NAME_ERROR)
|
|
278
|
+
|
|
279
|
+
if Const.DEBUG in name.split(Const.SEP):
|
|
280
|
+
return name.split(Const.DEBUG)[0], Const.DEBUG
|
|
281
|
+
if Const.PARAMS_GRAD in name.split(Const.SEP):
|
|
282
|
+
return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD
|
|
283
|
+
|
|
284
|
+
split = re.split(Const.REGEX_FORWARD_BACKWARD, name)
|
|
285
|
+
if len(split) < 3:
|
|
286
|
+
logger.error(f'Invalid name string: {name}, can not be split by forward/backward, please check.')
|
|
287
|
+
raise CompareException(CompareException.INVALID_API_NAME_ERROR)
|
|
288
|
+
api = f'{split[0]}.{split[1]}.'
|
|
289
|
+
state_str = split[2]
|
|
290
|
+
match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str)
|
|
291
|
+
if not match:
|
|
292
|
+
raise CompareException(f'Invalid name string: {name}')
|
|
293
|
+
if match.group(1):
|
|
294
|
+
api = f'{api}{match.group(1)}'
|
|
295
|
+
state = match.group(2)
|
|
296
|
+
return api, state
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def reorder_op_name_list(op_name_list):
|
|
300
|
+
if not op_name_list:
|
|
301
|
+
return op_name_list
|
|
302
|
+
|
|
303
|
+
parameters = []
|
|
304
|
+
output = []
|
|
305
|
+
parameters_grad = []
|
|
306
|
+
others = []
|
|
307
|
+
for x in op_name_list:
|
|
308
|
+
state = get_name_and_state(x)[1]
|
|
309
|
+
if state == Const.PARAMS:
|
|
310
|
+
parameters.append(x)
|
|
311
|
+
elif state == Const.OUTPUT:
|
|
312
|
+
output.append(x)
|
|
313
|
+
elif state == Const.PARAMS_GRAD:
|
|
314
|
+
parameters_grad.append(x)
|
|
315
|
+
else:
|
|
316
|
+
others.append(x)
|
|
317
|
+
# 合并others, parameters, 和output,确保parameters排在output前面
|
|
318
|
+
op_name_reorder = others + parameters + output + parameters_grad
|
|
319
|
+
return op_name_reorder
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def reorder_op_x_list(op_name_list, summary_list, data_name_list):
|
|
323
|
+
"""对op_name, summary, data_name重新排序,把parameters放到input后output前,data_name由于统计量比对时,为None,单独处理"""
|
|
324
|
+
if not op_name_list or not summary_list:
|
|
325
|
+
return op_name_list, summary_list, data_name_list
|
|
326
|
+
|
|
327
|
+
index_map = {name: index for index, name in enumerate(op_name_list)}
|
|
328
|
+
|
|
329
|
+
op_name_reorder = reorder_op_name_list(op_name_list)
|
|
330
|
+
summary_reorder = [summary_list[index_map.get(name)] for name in op_name_reorder]
|
|
331
|
+
if data_name_list:
|
|
332
|
+
data_name_reorder = [data_name_list[index_map.get(name)] for name in op_name_reorder]
|
|
333
|
+
else:
|
|
334
|
+
data_name_reorder = data_name_list
|
|
335
|
+
|
|
336
|
+
return op_name_reorder, summary_reorder, data_name_reorder
|
|
223
337
|
|
|
224
338
|
|
|
225
339
|
def process_summary_data(summary_data):
|
|
@@ -285,9 +399,9 @@ def result_item_init(n_info, b_info, dump_mode):
|
|
|
285
399
|
md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF
|
|
286
400
|
result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result])
|
|
287
401
|
elif dump_mode == Const.SUMMARY:
|
|
288
|
-
result_item.extend([" "] * 8)
|
|
402
|
+
result_item.extend([" "] * 8) # 8个统计量数据情况的比对指标
|
|
289
403
|
else:
|
|
290
|
-
result_item.extend([" "] *
|
|
404
|
+
result_item.extend([" "] * 6) # 6个真实数据情况的比对指标
|
|
291
405
|
else:
|
|
292
406
|
err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \
|
|
293
407
|
f"npu_info_struct is {n_info.struct}\n" \
|
|
@@ -321,8 +435,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
|
321
435
|
has_stack = npu_stack_info and bench_stack_info
|
|
322
436
|
|
|
323
437
|
if dump_mode == Const.ALL:
|
|
324
|
-
|
|
325
|
-
|
|
438
|
+
npu_data_name_list = n_dict.get("data_name", None)
|
|
439
|
+
bench_data_name_list = b_dict.get("data_name", None)
|
|
326
440
|
|
|
327
441
|
for index in range(min_len):
|
|
328
442
|
n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name")
|
|
@@ -353,7 +467,9 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
|
353
467
|
result_item.append(err_msg)
|
|
354
468
|
result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
|
|
355
469
|
if dump_mode == Const.ALL:
|
|
356
|
-
|
|
470
|
+
npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list")
|
|
471
|
+
bench_data_name = safe_get_value(bench_data_name_list, b_start + index, "bench_data_name_list")
|
|
472
|
+
result_item.append([npu_data_name, bench_data_name])
|
|
357
473
|
|
|
358
474
|
result.append(result_item)
|
|
359
475
|
|
|
@@ -371,7 +487,7 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
|
371
487
|
continue
|
|
372
488
|
result_item = [
|
|
373
489
|
n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
|
|
374
|
-
" ", " ", " ", " ", " "
|
|
490
|
+
" ", " ", " ", " ", " ", " "
|
|
375
491
|
]
|
|
376
492
|
summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
|
|
377
493
|
result_item.extend(summary_data)
|
|
@@ -388,7 +504,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
|
388
504
|
result_item.append(err_msg)
|
|
389
505
|
result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
|
|
390
506
|
if dump_mode == Const.ALL:
|
|
391
|
-
|
|
507
|
+
npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list")
|
|
508
|
+
result_item.append([npu_data_name, "-1"])
|
|
392
509
|
|
|
393
510
|
result.append(result_item)
|
|
394
511
|
|
|
@@ -404,197 +521,23 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
|
404
521
|
CompareConst.PARAMS_GRAD_STRUCT)
|
|
405
522
|
|
|
406
523
|
|
|
407
|
-
def
|
|
408
|
-
|
|
409
|
-
if npu_stack_info and index == 0:
|
|
410
|
-
result_item.extend(npu_stack_info)
|
|
411
|
-
else:
|
|
412
|
-
result_item.append(CompareConst.NONE)
|
|
413
|
-
|
|
524
|
+
def make_result_table(result, dump_mode, stack_mode):
|
|
525
|
+
header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode][:]
|
|
414
526
|
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A
|
|
418
|
-
|
|
419
|
-
struct_to_index_mapping = {
|
|
420
|
-
CompareConst.INPUT_STRUCT: 0,
|
|
421
|
-
CompareConst.OUTPUT_STRUCT: 0,
|
|
422
|
-
CompareConst.PARAMS_STRUCT: 0,
|
|
423
|
-
CompareConst.PARAMS_GRAD_STRUCT: 0
|
|
424
|
-
}
|
|
425
|
-
|
|
426
|
-
op_name_list = n_dict.get(CompareConst.OP_NAME)
|
|
427
|
-
summary_list = n_dict.get(Const.SUMMARY)
|
|
428
|
-
data_name_list = n_dict.get('data_name')
|
|
429
|
-
op_name_reorder, summary_reorder, _ = reorder_op_x_list(op_name_list,
|
|
430
|
-
summary_list,
|
|
431
|
-
data_name_list)
|
|
432
|
-
for index, n_name in enumerate(op_name_reorder):
|
|
433
|
-
_, state = get_name_and_state(n_name)
|
|
434
|
-
struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
|
|
435
|
-
if not struct_key:
|
|
436
|
-
continue
|
|
437
|
-
n_struct = safe_get_value(n_dict, struct_to_index_mapping.get(struct_key), "n_dict", key=struct_key)
|
|
438
|
-
struct_to_index_mapping[struct_key] += 1
|
|
439
|
-
|
|
440
|
-
try:
|
|
441
|
-
result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
|
|
442
|
-
except IndexError as e:
|
|
443
|
-
err_msg = "index out of bounds error occurs, please check!\n" \
|
|
444
|
-
f"op_name of n_dict is {n_dict['op_name']}\n" \
|
|
445
|
-
f"input_struct of n_dict is {n_dict[CompareConst.INPUT_STRUCT]}\n" \
|
|
446
|
-
f"output_struct of n_dict is {n_dict[CompareConst.OUTPUT_STRUCT]}"
|
|
447
|
-
logger.error(err_msg)
|
|
448
|
-
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
449
|
-
|
|
450
|
-
if dump_mode == Const.MD5:
|
|
451
|
-
result_item.extend([CompareConst.N_A] * 3)
|
|
452
|
-
append_stack_info(result_item, npu_stack_info, index)
|
|
453
|
-
result.append(result_item)
|
|
454
|
-
continue
|
|
455
|
-
if dump_mode == Const.SUMMARY:
|
|
456
|
-
result_item.extend([CompareConst.N_A] * 8)
|
|
527
|
+
if stack_mode:
|
|
528
|
+
header.append(CompareConst.STACK)
|
|
457
529
|
if dump_mode == Const.ALL:
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
npu_summary_data = safe_get_value(summary_reorder, index, "summary_reorder")
|
|
461
|
-
bench_summary_data = [CompareConst.N_A] * 4
|
|
462
|
-
result_item.extend(npu_summary_data)
|
|
463
|
-
result_item.extend(bench_summary_data)
|
|
464
|
-
err_msg = CompareConst.NO_BENCH
|
|
465
|
-
accuracy_check_res = CompareConst.N_A
|
|
466
|
-
result_item.append(accuracy_check_res)
|
|
467
|
-
result_item.append(err_msg)
|
|
468
|
-
append_stack_info(result_item, npu_stack_info, index)
|
|
469
|
-
if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
|
|
470
|
-
result_item.extend(["-1"])
|
|
471
|
-
result.append(result_item)
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
def merge_tensor(tensor_list, dump_mode):
|
|
475
|
-
op_dict = {}
|
|
476
|
-
op_dict["op_name"] = []
|
|
477
|
-
op_dict[CompareConst.INPUT_STRUCT] = []
|
|
478
|
-
op_dict[CompareConst.KWARGS_STRUCT] = []
|
|
479
|
-
op_dict[CompareConst.OUTPUT_STRUCT] = []
|
|
480
|
-
op_dict[CompareConst.PARAMS_STRUCT] = []
|
|
481
|
-
op_dict[CompareConst.PARAMS_GRAD_STRUCT] = []
|
|
482
|
-
op_dict[Const.SUMMARY] = []
|
|
483
|
-
op_dict["stack_info"] = []
|
|
484
|
-
|
|
485
|
-
if dump_mode == Const.ALL:
|
|
486
|
-
op_dict["data_name"] = []
|
|
487
|
-
|
|
488
|
-
for tensor in tensor_list:
|
|
489
|
-
# A dict(len=2) with 'full_op_name' and 'full_info' is added to the tensor only if self.stack_mode is True
|
|
490
|
-
if len(tensor) == 2:
|
|
491
|
-
op_dict['stack_info'].append(tensor['full_info'])
|
|
492
|
-
break
|
|
493
|
-
|
|
494
|
-
op_dict["op_name"].append(tensor['full_op_name'])
|
|
495
|
-
|
|
496
|
-
_, state = get_name_and_state(tensor['full_op_name'])
|
|
497
|
-
struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
|
|
498
|
-
if not struct_key:
|
|
499
|
-
continue
|
|
500
|
-
if dump_mode == Const.MD5:
|
|
501
|
-
op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
|
|
502
|
-
else:
|
|
503
|
-
op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
|
|
504
|
-
op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]])
|
|
505
|
-
|
|
530
|
+
header.append(CompareConst.DATA_NAME)
|
|
531
|
+
else:
|
|
506
532
|
if dump_mode == Const.ALL:
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
del op_dict[CompareConst.KWARGS_STRUCT]
|
|
511
|
-
return op_dict if op_dict["op_name"] else {}
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
def print_compare_ends_info():
|
|
515
|
-
total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
|
|
516
|
-
logger.info('*' * total_len)
|
|
517
|
-
logger.info(f"*{CompareConst.COMPARE_ENDS_SUCCESSFULLY.center(total_len - 2)}*")
|
|
518
|
-
logger.info('*' * total_len)
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
def table_value_is_valid(value: str) -> bool:
|
|
522
|
-
if not isinstance(value, str):
|
|
523
|
-
return True
|
|
524
|
-
try:
|
|
525
|
-
# -1.00 or +1.00 should be consdiered as digit numbers
|
|
526
|
-
float(value)
|
|
527
|
-
except ValueError:
|
|
528
|
-
# otherwise, they will be considered as formular injections
|
|
529
|
-
return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
|
|
530
|
-
return True
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
def get_name_and_state(name):
|
|
534
|
-
"""
|
|
535
|
-
Get api/module name and state
|
|
536
|
-
example:
|
|
537
|
-
name = 'conv2d.forward.1.input.0'
|
|
538
|
-
return: ('conv2d.forward.1.', 'input')
|
|
539
|
-
|
|
540
|
-
name = 'Functional.pad.0.backward.output.0'
|
|
541
|
-
return: ('Functional.pad.0.backward.', 'output')
|
|
542
|
-
|
|
543
|
-
state type: input, output, kwargs, parameters, parameters_grad
|
|
544
|
-
"""
|
|
545
|
-
if Const.PARAMS_GRAD in name.split(Const.SEP):
|
|
546
|
-
return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD
|
|
547
|
-
|
|
548
|
-
split = re.split(Const.REGEX_FORWARD_BACKWARD, name)
|
|
549
|
-
api = f'{split[0]}.{split[1]}.'
|
|
550
|
-
state_str = split[2]
|
|
551
|
-
match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str)
|
|
552
|
-
if not match:
|
|
553
|
-
raise CompareException(f'Invalid name string: {name}')
|
|
554
|
-
if match.group(1):
|
|
555
|
-
api = f'{api}{match.group(1)}'
|
|
556
|
-
state = match.group(2)
|
|
557
|
-
return api, state
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
def reorder_op_name_list(op_name_list):
|
|
561
|
-
if not op_name_list:
|
|
562
|
-
return op_name_list
|
|
563
|
-
|
|
564
|
-
parameters = []
|
|
565
|
-
output = []
|
|
566
|
-
parameters_grad = []
|
|
567
|
-
others = []
|
|
568
|
-
for x in op_name_list:
|
|
569
|
-
state = get_name_and_state(x)[1]
|
|
570
|
-
if state == Const.PARAMS:
|
|
571
|
-
parameters.append(x)
|
|
572
|
-
elif state == Const.OUTPUT:
|
|
573
|
-
output.append(x)
|
|
574
|
-
elif state == Const.PARAMS_GRAD:
|
|
575
|
-
parameters_grad.append(x)
|
|
533
|
+
for row in result:
|
|
534
|
+
del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列
|
|
535
|
+
header.append(CompareConst.DATA_NAME)
|
|
576
536
|
else:
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
return
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
def reorder_op_x_list(op_name_list, summary_list, data_name_list):
|
|
584
|
-
"""对op_name, summary, data_name重新排序,把parameters放到input后output前,data_name由于统计量比对时,为None,单独处理"""
|
|
585
|
-
if not op_name_list or not summary_list:
|
|
586
|
-
return op_name_list, summary_list, data_name_list
|
|
587
|
-
|
|
588
|
-
index_map = {name: index for index, name in enumerate(op_name_list)}
|
|
589
|
-
|
|
590
|
-
op_name_reorder = reorder_op_name_list(op_name_list)
|
|
591
|
-
summary_reorder = [summary_list[index_map.get(name)] for name in op_name_reorder]
|
|
592
|
-
if data_name_list:
|
|
593
|
-
data_name_reorder = [data_name_list[index_map.get(name)] for name in op_name_reorder]
|
|
594
|
-
else:
|
|
595
|
-
data_name_reorder = data_name_list
|
|
596
|
-
|
|
597
|
-
return op_name_reorder, summary_reorder, data_name_reorder
|
|
537
|
+
for row in result:
|
|
538
|
+
del row[-1] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,非真实数据时为倒数第1列
|
|
539
|
+
result_df = pd.DataFrame(result, columns=header, dtype='object')
|
|
540
|
+
return result_df
|
|
598
541
|
|
|
599
542
|
|
|
600
543
|
def _compare_parser(parser):
|
|
@@ -617,3 +560,34 @@ def _compare_parser(parser):
|
|
|
617
560
|
help="<optional> The data mapping file path.", required=False)
|
|
618
561
|
parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True,
|
|
619
562
|
help="<optional> The layer mapping file path.", required=False)
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
def compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, compare_func, **kwargs):
|
|
566
|
+
if kwargs.get('suffix'):
|
|
567
|
+
logger.error("Argument 'suffix' is not supported for compare_distributed.")
|
|
568
|
+
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
569
|
+
is_print_compare_log = kwargs.get('is_print_compare_log', True)
|
|
570
|
+
# get the ranks and match by order
|
|
571
|
+
npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
|
|
572
|
+
bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
|
|
573
|
+
if len(npu_ranks) != len(bench_ranks):
|
|
574
|
+
logger.error('The number of ranks in the two runs are different. '
|
|
575
|
+
'Unable to match the ranks. Please use another folder to compare '
|
|
576
|
+
'or use compare() api and manually match the ranks.')
|
|
577
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
578
|
+
for nr, br in zip(npu_ranks, bench_ranks):
|
|
579
|
+
npu_data_dir = os.path.join(npu_dump_dir, nr)
|
|
580
|
+
bench_data_dir = os.path.join(bench_dump_dir, br)
|
|
581
|
+
for file_type in [Const.DUMP_JSON_FILE, Const.DEBUG_JSON_FILE]:
|
|
582
|
+
npu_path = extract_json(npu_data_dir, file_type)
|
|
583
|
+
bench_path = extract_json(bench_data_dir, file_type)
|
|
584
|
+
if npu_path == "" or bench_path == "":
|
|
585
|
+
logger.debug(f'Did not find paired {file_type} in {npu_data_dir} and {bench_data_dir},'
|
|
586
|
+
' skip comparing.')
|
|
587
|
+
continue
|
|
588
|
+
dump_result_param = {
|
|
589
|
+
'npu_json_path': npu_path,
|
|
590
|
+
'bench_json_path': bench_path,
|
|
591
|
+
'is_print_compare_log': is_print_compare_log
|
|
592
|
+
}
|
|
593
|
+
compare_func(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}', **kwargs)
|
|
@@ -13,7 +13,5 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
rank_id: int = -1
|
|
19
|
-
is_running: bool = False
|
|
16
|
+
import msprobe.core.config_check.checkers
|
|
17
|
+
from msprobe.core.config_check.config_checker import ConfigChecker
|
msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py}
RENAMED
|
@@ -13,21 +13,13 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
|
|
16
|
+
__all__ = ['BaseChecker', 'apply_patches']
|
|
17
17
|
|
|
18
|
-
|
|
18
|
+
import msprobe.core.config_check.checkers.env_args_checker
|
|
19
|
+
import msprobe.core.config_check.checkers.pip_checker
|
|
20
|
+
import msprobe.core.config_check.checkers.dataset_checker
|
|
21
|
+
import msprobe.core.config_check.checkers.weights_checker
|
|
22
|
+
import msprobe.core.config_check.checkers.hyperparameter_checker
|
|
23
|
+
import msprobe.core.config_check.checkers.random_checker
|
|
19
24
|
|
|
20
|
-
|
|
21
|
-
def create_kernel_config_json(dump_path, cur_rank):
|
|
22
|
-
kernel_config_name = "kernel_config.json" if cur_rank == '' else f"kernel_config_{cur_rank}.json"
|
|
23
|
-
kernel_config_path = os.path.join(dump_path, kernel_config_name)
|
|
24
|
-
config_info = {
|
|
25
|
-
"dump": {
|
|
26
|
-
"dump_list": [],
|
|
27
|
-
"dump_path": dump_path,
|
|
28
|
-
"dump_mode": "all",
|
|
29
|
-
"dump_op_switch": "on"
|
|
30
|
-
}
|
|
31
|
-
}
|
|
32
|
-
save_json(kernel_config_path, config_info, indent=4)
|
|
33
|
-
return kernel_config_path
|
|
25
|
+
from msprobe.core.config_check.checkers.base_checker import BaseChecker
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.framework_adapter import FmkAdp
|
|
19
|
+
from msprobe.core.common.const import FileCheckConst
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class PackInput:
|
|
23
|
+
|
|
24
|
+
def __init__(self, output_zip_path, model, shell_path):
|
|
25
|
+
self.output_zip_path = output_zip_path
|
|
26
|
+
self.shell_path = shell_path
|
|
27
|
+
self.model = model[0] if isinstance(model, list) and len(model) > 0 else model
|
|
28
|
+
self.check_input_params()
|
|
29
|
+
|
|
30
|
+
def check_input_params(self):
|
|
31
|
+
if self.model and not FmkAdp.is_nn_module(self.model):
|
|
32
|
+
raise Exception(f"model is not torch.nn.Module/mindspore.nn.Cell or module list.")
|
|
33
|
+
if not isinstance(self.output_zip_path, str) or not self.output_zip_path.endswith(FileCheckConst.ZIP_SUFFIX):
|
|
34
|
+
raise Exception(f"output zip path must be a string and ends with '.zip'")
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class BaseChecker:
|
|
38
|
+
input_needed = None
|
|
39
|
+
target_name_in_zip = None
|
|
40
|
+
multi_rank = False
|
|
41
|
+
|
|
42
|
+
@staticmethod
|
|
43
|
+
def pack(pack_input):
|
|
44
|
+
pass
|
|
45
|
+
|
|
46
|
+
@staticmethod
|
|
47
|
+
def compare(bench_dir, cmp_dir, output_path, fmk):
|
|
48
|
+
pass
|
|
49
|
+
|
|
50
|
+
@staticmethod
|
|
51
|
+
def apply_patches(fmk):
|
|
52
|
+
pass
|
|
53
|
+
|
|
54
|
+
@classmethod
|
|
55
|
+
def compare_ex(cls, bench_dir, cmp_dir, output_path, fmk):
|
|
56
|
+
bench_filepath = os.path.join(bench_dir, cls.target_name_in_zip)
|
|
57
|
+
cmp_filepath = os.path.join(cmp_dir, cls.target_name_in_zip)
|
|
58
|
+
if not os.path.exists(bench_filepath) or not os.path.exists(cmp_filepath):
|
|
59
|
+
return None, None, None
|
|
60
|
+
return cls.compare(bench_dir, cmp_dir, output_path, fmk)
|