mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
msprobe/core/common/utils.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
|
-
|
|
2
|
-
#
|
|
3
|
-
|
|
4
|
-
#
|
|
5
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
5
|
# you may not use this file except in compliance with the License.
|
|
7
6
|
# You may obtain a copy of the License at
|
|
8
7
|
#
|
|
@@ -13,14 +12,17 @@
|
|
|
13
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
13
|
# See the License for the specific language governing permissions and
|
|
15
14
|
# limitations under the License.
|
|
16
|
-
|
|
15
|
+
|
|
17
16
|
import collections
|
|
18
17
|
import os
|
|
19
18
|
import re
|
|
20
19
|
import subprocess
|
|
21
20
|
import time
|
|
22
|
-
import
|
|
21
|
+
from collections import defaultdict
|
|
23
22
|
from datetime import datetime, timezone
|
|
23
|
+
from functools import wraps
|
|
24
|
+
|
|
25
|
+
import numpy as np
|
|
24
26
|
|
|
25
27
|
from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json)
|
|
26
28
|
from msprobe.core.common.const import Const, CompareConst
|
|
@@ -68,6 +70,11 @@ class MsprobeBaseException(Exception):
|
|
|
68
70
|
FUNCTION_CALL_ERROR = 28
|
|
69
71
|
FORWARD_DATA_COLLECTION_ERROR = 29
|
|
70
72
|
BACKWARD_DATA_COLLECTION_ERROR = 30
|
|
73
|
+
INVALID_KEY_ERROR = 31
|
|
74
|
+
MISSING_HEADER_ERROR = 32
|
|
75
|
+
MERGE_COMPARE_RESULT_ERROR = 33
|
|
76
|
+
NAMES_STRUCTS_MATCH_ERROR = 34
|
|
77
|
+
INVALID_STATE_ERROR = 35
|
|
71
78
|
|
|
72
79
|
def __init__(self, code, error_info: str = ""):
|
|
73
80
|
super(MsprobeBaseException, self).__init__()
|
|
@@ -99,7 +106,14 @@ class DumpException(MsprobeBaseException):
|
|
|
99
106
|
return f"Dump Error Code {self.code}: {self.error_info}"
|
|
100
107
|
|
|
101
108
|
|
|
102
|
-
def
|
|
109
|
+
def is_json_file(file_path):
|
|
110
|
+
if isinstance(file_path, str) and file_path.lower().endswith('.json'):
|
|
111
|
+
return True
|
|
112
|
+
else:
|
|
113
|
+
return False
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def check_compare_param(input_param, output_path, dump_mode, stack_mode):
|
|
103
117
|
if not isinstance(input_param, dict):
|
|
104
118
|
logger.error(f"Invalid input parameter 'input_param', the expected type dict but got {type(input_param)}.")
|
|
105
119
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
@@ -107,18 +121,31 @@ def check_compare_param(input_param, output_path, summary_compare=False, md5_com
|
|
|
107
121
|
logger.error(f"Invalid input parameter 'output_path', the expected type str but got {type(output_path)}.")
|
|
108
122
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
109
123
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
124
|
+
def check_json_path(json_path_str):
|
|
125
|
+
json_path = input_param.get(json_path_str)
|
|
126
|
+
check_file_or_directory_path(json_path, False)
|
|
127
|
+
json_type_check = is_json_file(json_path)
|
|
128
|
+
if not json_type_check:
|
|
129
|
+
logger.error(f"Invalid {json_path_str}: {json_path}, please check!")
|
|
130
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
131
|
+
|
|
132
|
+
check_json_path("npu_json_path")
|
|
133
|
+
check_json_path("bench_json_path")
|
|
134
|
+
if stack_mode:
|
|
135
|
+
check_json_path("stack_json_path")
|
|
136
|
+
|
|
137
|
+
if dump_mode == Const.ALL:
|
|
114
138
|
check_file_or_directory_path(input_param.get("npu_dump_data_dir"), True)
|
|
115
139
|
check_file_or_directory_path(input_param.get("bench_dump_data_dir"), True)
|
|
116
140
|
check_file_or_directory_path(output_path, True)
|
|
117
141
|
|
|
118
142
|
with FileOpen(input_param.get("npu_json_path"), "r") as npu_json, \
|
|
119
|
-
FileOpen(input_param.get("bench_json_path"), "r") as bench_json
|
|
120
|
-
|
|
121
|
-
|
|
143
|
+
FileOpen(input_param.get("bench_json_path"), "r") as bench_json:
|
|
144
|
+
_check_json(npu_json, input_param.get("npu_json_path"))
|
|
145
|
+
_check_json(bench_json, input_param.get("bench_json_path"))
|
|
146
|
+
if stack_mode:
|
|
147
|
+
with FileOpen(input_param.get("stack_json_path"), "r") as stack_json:
|
|
148
|
+
_check_json(stack_json, input_param.get("stack_json_path"))
|
|
122
149
|
|
|
123
150
|
|
|
124
151
|
def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True):
|
|
@@ -179,7 +206,7 @@ def execute_command(cmd):
|
|
|
179
206
|
line = process.stdout.readline()
|
|
180
207
|
line = line.strip()
|
|
181
208
|
if line:
|
|
182
|
-
|
|
209
|
+
logger.info(line)
|
|
183
210
|
if process.returncode != 0:
|
|
184
211
|
logger.error('Failed to execute command:%s' % " ".join(cmd))
|
|
185
212
|
raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
@@ -212,25 +239,29 @@ def md5_find(data):
|
|
|
212
239
|
for data_detail in data[key_op][api_info]:
|
|
213
240
|
if data_detail and 'md5' in data_detail:
|
|
214
241
|
return True
|
|
215
|
-
elif 'md5' in data[key_op][api_info]:
|
|
242
|
+
elif data[key_op][api_info] and 'md5' in data[key_op][api_info]:
|
|
216
243
|
return True
|
|
217
244
|
return False
|
|
218
245
|
|
|
219
246
|
|
|
220
|
-
def
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
247
|
+
def detect_framework_by_dump_json(file_path):
|
|
248
|
+
pattern_ms = r'"type":\s*"mindspore'
|
|
249
|
+
pattern_pt = r'"type":\s*"torch'
|
|
250
|
+
with FileOpen(file_path, 'r') as file:
|
|
251
|
+
for line in file:
|
|
252
|
+
if re.search(pattern_ms, line):
|
|
253
|
+
return Const.MS_FRAMEWORK
|
|
254
|
+
if re.search(pattern_pt, line):
|
|
255
|
+
return Const.PT_FRAMEWORK
|
|
256
|
+
logger.error(f"{file_path} must be based on the MindSpore or PyTorch framework.")
|
|
257
|
+
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
228
258
|
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
259
|
+
|
|
260
|
+
def get_stack_construct_by_dump_json_path(dump_json_path):
|
|
261
|
+
if not dump_json_path:
|
|
262
|
+
logger.error("The path is empty. Please enter a valid path.")
|
|
232
263
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
233
|
-
directory = os.path.dirname(
|
|
264
|
+
directory = os.path.dirname(dump_json_path)
|
|
234
265
|
check_file_or_directory_path(directory, True)
|
|
235
266
|
stack_json = os.path.join(directory, "stack.json")
|
|
236
267
|
construct_json = os.path.join(directory, "construct.json")
|
|
@@ -240,41 +271,57 @@ def struct_json_get(input_param, framework):
|
|
|
240
271
|
return stack, construct
|
|
241
272
|
|
|
242
273
|
|
|
243
|
-
def
|
|
274
|
+
def set_dump_path(input_param):
|
|
244
275
|
npu_path = input_param.get("npu_json_path", None)
|
|
245
276
|
bench_path = input_param.get("bench_json_path", None)
|
|
246
|
-
|
|
247
|
-
|
|
277
|
+
npu_path_valid = npu_path is not None and npu_path.endswith("dump.json")
|
|
278
|
+
bench_path_valid = bench_path is not None and bench_path.endswith("dump.json")
|
|
279
|
+
if not npu_path_valid or not bench_path_valid:
|
|
280
|
+
logger.error(f"Please check the json path is valid. npu_path: {npu_path}, bench_path: {bench_path}")
|
|
248
281
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
282
|
+
input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
|
|
283
|
+
input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def get_dump_mode(input_param):
|
|
287
|
+
npu_path = input_param.get("npu_json_path", None)
|
|
288
|
+
bench_path = input_param.get("bench_json_path", None)
|
|
289
|
+
npu_json_data = load_json(npu_path)
|
|
290
|
+
bench_json_data = load_json(bench_path)
|
|
291
|
+
|
|
292
|
+
npu_task = npu_json_data.get('task', None)
|
|
293
|
+
bench_task = bench_json_data.get('task', None)
|
|
294
|
+
|
|
295
|
+
if not npu_task or not bench_task:
|
|
296
|
+
logger.error(f"Please check the dump task is correct, npu's task is {npu_task}, bench's task is {bench_task}.")
|
|
297
|
+
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
298
|
+
|
|
299
|
+
if npu_task != bench_task:
|
|
254
300
|
logger.error(f"Please check the dump task is consistent.")
|
|
255
301
|
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
302
|
+
|
|
303
|
+
if npu_task == Const.TENSOR:
|
|
304
|
+
return Const.ALL
|
|
305
|
+
|
|
306
|
+
if npu_task == Const.STATISTICS:
|
|
307
|
+
npu_md5_compare = md5_find(npu_json_data['data'])
|
|
308
|
+
bench_md5_compare = md5_find(bench_json_data['data'])
|
|
309
|
+
if npu_md5_compare == bench_md5_compare:
|
|
310
|
+
return Const.MD5 if npu_md5_compare else Const.SUMMARY
|
|
263
311
|
else:
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
268
|
-
input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
|
|
269
|
-
input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
|
|
270
|
-
return summary_compare, md5_compare
|
|
312
|
+
logger.error(f"Please check the dump task is consistent, "
|
|
313
|
+
f"dump mode of npu and bench should both be statistics or md5.")
|
|
314
|
+
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
271
315
|
|
|
316
|
+
logger.error(f"Compare applies only to task is tensor or statistics")
|
|
317
|
+
raise CompareException(CompareException.INVALID_TASK_ERROR)
|
|
272
318
|
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
319
|
+
|
|
320
|
+
def get_header_index(header_name, dump_mode):
|
|
321
|
+
header = CompareConst.HEAD_OF_COMPARE_MODE.get(dump_mode)
|
|
322
|
+
if not header:
|
|
323
|
+
logger.error(f"{dump_mode} not in {CompareConst.HEAD_OF_COMPARE_MODE}")
|
|
324
|
+
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
278
325
|
if header_name not in header:
|
|
279
326
|
logger.error(f"{header_name} not in data name")
|
|
280
327
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
@@ -282,7 +329,7 @@ def get_header_index(header_name, summary_compare=False):
|
|
|
282
329
|
|
|
283
330
|
|
|
284
331
|
def convert_tuple(data):
|
|
285
|
-
return data if isinstance(data, tuple) else (data,
|
|
332
|
+
return data if isinstance(data, tuple) else (data,)
|
|
286
333
|
|
|
287
334
|
|
|
288
335
|
def check_op_str_pattern_valid(string, op_name=None, stack=False):
|
|
@@ -302,6 +349,10 @@ def is_invalid_pattern(string):
|
|
|
302
349
|
return re.search(pattern, string)
|
|
303
350
|
|
|
304
351
|
|
|
352
|
+
def is_int(x):
|
|
353
|
+
return isinstance(x, int) and not isinstance(x, bool)
|
|
354
|
+
|
|
355
|
+
|
|
305
356
|
def print_tools_ends_info():
|
|
306
357
|
total_len = len(Const.TOOL_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
|
|
307
358
|
logger.info('*' * total_len)
|
|
@@ -315,51 +366,61 @@ def get_step_or_rank_from_string(step_or_rank, obj):
|
|
|
315
366
|
try:
|
|
316
367
|
borderlines = int(splited[0]), int(splited[1])
|
|
317
368
|
except (ValueError, IndexError) as e:
|
|
318
|
-
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
369
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
319
370
|
"The hyphen(-) must start and end with decimal numbers.") from e
|
|
320
371
|
else:
|
|
321
|
-
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
322
|
-
f'The string parameter for {obj} only supports formats like "3-5".
|
|
323
|
-
|
|
372
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
373
|
+
f'The string parameter for {obj} only supports formats like "3-5". '
|
|
374
|
+
f'Now string parameter for {obj} is "{step_or_rank}".')
|
|
375
|
+
if all(Const.STEP_RANK_MINIMUM_VALUE <= b <= Const.STEP_RANK_MAXIMUM_VALUE for b in borderlines):
|
|
324
376
|
if borderlines[0] <= borderlines[1]:
|
|
325
377
|
continual_step_or_rank = list(range(borderlines[0], borderlines[1] + 1))
|
|
326
378
|
else:
|
|
327
|
-
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
328
|
-
|
|
379
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
380
|
+
f'For the hyphen(-) in {obj}, the left boundary ({borderlines[0]}) cannot be '
|
|
381
|
+
f'greater than the right boundary ({borderlines[1]}).')
|
|
329
382
|
else:
|
|
330
|
-
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
331
|
-
f"The boundaries must fall within the range of
|
|
383
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
384
|
+
f"The boundaries must fall within the range of "
|
|
385
|
+
f"[{Const.STEP_RANK_MINIMUM_VALUE}, {Const.STEP_RANK_MAXIMUM_VALUE}].")
|
|
332
386
|
return continual_step_or_rank
|
|
333
387
|
|
|
334
388
|
|
|
335
389
|
def get_real_step_or_rank(step_or_rank_input, obj):
    """Expand a user-supplied step/rank list into a sorted list of unique values.

    Args:
        step_or_rank_input: ``None`` or a list whose elements are integers or
            strings. String elements are expanded by
            ``get_step_or_rank_from_string``.
        obj: Which option is being parsed; must be ``Const.STEP`` or
            ``Const.RANK``.

    Returns:
        An empty list when the input is ``None``; otherwise the de-duplicated
        values in ascending order.

    Raises:
        MsprobeException: ``INVALID_PARAM_ERROR`` when ``obj`` is unsupported,
            the input is not a list, the list is longer than
            ``Const.STEP_RANK_MAXIMUM_VALUE``, an element is neither an integer
            nor a string, or an integer element is outside
            ``[Const.STEP_RANK_MINIMUM_VALUE, Const.STEP_RANK_MAXIMUM_VALUE]``.
    """
    if obj not in [Const.STEP, Const.RANK]:
        raise MsprobeException(
            MsprobeException.INVALID_PARAM_ERROR,
            f"Only support parsing {[Const.STEP, Const.RANK]}, the current parsing object is {obj}.",
        )
    if step_or_rank_input is None:
        return []
    if not isinstance(step_or_rank_input, list):
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"{obj} is invalid, it should be a list")
    if len(step_or_rank_input) > Const.STEP_RANK_MAXIMUM_VALUE:
        raise MsprobeException(
            MsprobeException.INVALID_PARAM_ERROR,
            f"{obj} is invalid, its length cannot exceed {Const.STEP_RANK_MAXIMUM_VALUE}",
        )

    # A set removes duplicates as values are gathered; ordering is applied on return.
    collected = set()
    for element in step_or_rank_input:
        if is_int(element):
            if not Const.STEP_RANK_MINIMUM_VALUE <= element <= Const.STEP_RANK_MAXIMUM_VALUE:
                raise MsprobeException(
                    MsprobeException.INVALID_PARAM_ERROR,
                    f"Each element of {obj} must be between {Const.STEP_RANK_MINIMUM_VALUE} and "
                    f"{Const.STEP_RANK_MAXIMUM_VALUE}, currently it is {element}."
                )
            collected.add(element)
        elif isinstance(element, str):
            # String elements are delegated to the range parser.
            collected.update(get_step_or_rank_from_string(element, obj))
        else:
            raise MsprobeException(
                MsprobeException.INVALID_PARAM_ERROR,
                f"{obj} element {element} must be an integer or string.",
            )
    return sorted(collected)
|
|
359
420
|
|
|
360
421
|
|
|
361
|
-
def check_seed_all(seed, mode):
|
|
362
|
-
if
|
|
422
|
+
def check_seed_all(seed, mode, rm_dropout):
|
|
423
|
+
if is_int(seed):
|
|
363
424
|
if seed < 0 or seed > Const.MAX_SEED_VALUE:
|
|
364
425
|
logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
|
|
365
426
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
@@ -369,3 +430,69 @@ def check_seed_all(seed, mode):
|
|
|
369
430
|
if not isinstance(mode, bool):
|
|
370
431
|
logger.error("seed_all mode must be bool.")
|
|
371
432
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
433
|
+
if not isinstance(rm_dropout, bool):
|
|
434
|
+
logger.error("The rm_dropout parameter must be bool.")
|
|
435
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
def safe_get_value(container, index, container_name, key=None):
    """Fetch ``container[index]`` and convert lookup failures into tool errors.

    Args:
        container: A dict, list, tuple or ``numpy.ndarray``.
        index: Index applied to the sequence, or to the value stored under
            ``key`` when ``container`` is a dict.
        container_name: Human-readable name used in the log messages.
        key: Dictionary key to look up first; only used for dict containers.

    Returns:
        The element found at ``index``.

    Raises:
        MsprobeBaseException: ``INVALID_OBJECT_TYPE_ERROR`` for an unsupported
            container type or when the lookup raises ``TypeError`` (for
            example when ``key`` is absent and ``dict.get`` yields ``None``);
            ``INDEX_OUT_OF_BOUNDS_ERROR`` when the lookup raises ``IndexError``.
        Other exceptions raised by the lookup are not intercepted.
    """
    try:
        # Dictionary: pick the entry stored under `key`, then index into it.
        if isinstance(container, dict):
            entry = container.get(key)
            return entry[index]
        # List, tuple or numpy array: index directly.
        if isinstance(container, (list, tuple, np.ndarray)):
            return container[index]
        unsupported_msg = f"Unsupported container type for '{container_name}': {type(container)}"
        logger.error(unsupported_msg)
        raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR)
    except IndexError as e:
        bounds_msg = (
            "index out of bounds error occurs, please check!\n"
            f"{container_name} is {container}\n"
            f"index is {index}"
        )
        logger.error(bounds_msg)
        raise MsprobeBaseException(MsprobeBaseException.INDEX_OUT_OF_BOUNDS_ERROR) from e
    except TypeError as e:
        type_msg = (
            "wrong type, please check!\n"
            f"{container_name} is {container}\n"
            f"index is {index}\n"
            f"key is {key}"
        )
        logger.error(type_msg)
        raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR) from e
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
# Records the current recursion depth of each decorated utility function.
|
|
466
|
+
recursion_depth = defaultdict(int)
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
# Decorates a function: when its recursive call depth exceeds the limit, an exception is raised and the function info is logged.
|
|
470
|
+
def recursion_depth_decorator(func_info):
    """Build a decorator that limits how deeply the wrapped function may nest.

    The current depth of each wrapped function is kept in the module-level
    ``recursion_depth`` mapping, keyed by ``id(func)``. Every call increments
    the entry and every exit decrements it again.

    Args:
        func_info: Description of the function, embedded in the error message.

    Returns:
        A decorator. The wrapper it produces returns whatever the wrapped
        function returns.

    Raises:
        MsprobeException: ``RECURSION_LIMIT_ERROR`` is handed to
            ``logger.error_log_with_exp`` once the depth exceeds
            ``Const.MAX_DEPTH`` (that helper is presumed to log and raise it;
            confirm against the logger implementation).
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            func_id = id(func)
            recursion_depth[func_id] += 1
            # The limit check lives inside the try block so the increment above
            # is always undone by `finally`, including when the limit error is
            # raised. Otherwise each overflow would leave the counter one too
            # high and permanently shrink the effective limit.
            try:
                if recursion_depth[func_id] > Const.MAX_DEPTH:
                    msg = f"call {func_info} exceeds the recursion limit."
                    logger.error_log_with_exp(
                        msg,
                        MsprobeException(
                            MsprobeException.RECURSION_LIMIT_ERROR, msg
                        ),
                    )
                return func(*args, **kwargs)
            finally:
                recursion_depth[func_id] -= 1

        return wrapper

    return decorator
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def check_str_param(param):
    """Reject a string parameter that does not match the allowed pattern.

    Args:
        param: The string to validate against ``Const.REGEX_PREFIX_PATTERN``
            using ``re.match`` (anchored at the start of the string).

    Raises:
        MsprobeBaseException: ``INVALID_CHAR_ERROR`` when the pattern does not
            match; the offending value is logged first.
    """
    if re.match(Const.REGEX_PREFIX_PATTERN, param):
        return
    message = 'The parameter {} contains special characters.'.format(param)
    logger.error(message)
    raise MsprobeBaseException(MsprobeBaseException.INVALID_CHAR_ERROR)
|
msprobe/core/common_config.py
CHANGED
|
@@ -1,7 +1,21 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.core.common.const import Const, FileCheckConst
|
|
2
17
|
from msprobe.core.common.log import logger
|
|
3
18
|
from msprobe.core.common.exceptions import MsprobeException
|
|
4
|
-
from msprobe.core.common.file_utils import FileChecker
|
|
5
19
|
from msprobe.core.common.utils import get_real_step_or_rank
|
|
6
20
|
|
|
7
21
|
|
|
@@ -12,8 +26,8 @@ class CommonConfig:
|
|
|
12
26
|
self.rank = get_real_step_or_rank(json_config.get('rank'), Const.RANK)
|
|
13
27
|
self.step = get_real_step_or_rank(json_config.get('step'), Const.STEP)
|
|
14
28
|
self.level = json_config.get('level')
|
|
15
|
-
self.acl_config = json_config.get('acl_config')
|
|
16
29
|
self.enable_dataloader = json_config.get('enable_dataloader', False)
|
|
30
|
+
self.async_dump = json_config.get("async_dump", False)
|
|
17
31
|
self._check_config()
|
|
18
32
|
|
|
19
33
|
def _check_config(self):
|
|
@@ -29,16 +43,11 @@ class CommonConfig:
|
|
|
29
43
|
if not isinstance(self.enable_dataloader, bool):
|
|
30
44
|
logger.error_log_with_exp("enable_dataloader is invalid, it should be a boolean",
|
|
31
45
|
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
32
|
-
if self.
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
def _check_acl_config(self):
|
|
36
|
-
if not isinstance(self.acl_config, str):
|
|
37
|
-
logger.error_log_with_exp("acl_config is invalid, it should be a string",
|
|
46
|
+
if not isinstance(self.async_dump, bool):
|
|
47
|
+
logger.error_log_with_exp("async_dump is invalid, it should be a boolean",
|
|
38
48
|
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
file_checker.common_check()
|
|
49
|
+
elif self.async_dump:
|
|
50
|
+
logger.warning("async_dump is True, it may cause OOM when dumping large tensor.")
|
|
42
51
|
|
|
43
52
|
|
|
44
53
|
class BaseConfig:
|
|
@@ -46,7 +55,6 @@ class BaseConfig:
|
|
|
46
55
|
self.scope = json_config.get('scope')
|
|
47
56
|
self.list = json_config.get('list')
|
|
48
57
|
self.data_mode = json_config.get('data_mode')
|
|
49
|
-
self.backward_input = json_config.get("backward_input")
|
|
50
58
|
self.file_format = json_config.get("file_format")
|
|
51
59
|
self.summary_mode = json_config.get("summary_mode")
|
|
52
60
|
self.overflow_nums = json_config.get("overflow_nums")
|
|
@@ -74,5 +82,32 @@ class BaseConfig:
|
|
|
74
82
|
def check_config(self):
    """Run the common validations: ``scope``, ``list``, then ``data_mode``."""
    # Both options go through the same string-list validator, in this order.
    for option_name in ("scope", "list"):
        self._check_str_list_config(getattr(self, option_name), option_name)
    self._check_data_mode()
|
|
86
|
+
|
|
87
|
+
def _check_data_mode(self):
    """Validate the optional ``data_mode`` option.

    Nothing is checked when ``data_mode`` is ``None``. Otherwise it must be a
    list of strings, each contained in ``Const.DUMP_DATA_MODE_LIST``, with
    fewer elements than that list, and ``Const.ALL`` may only appear on its
    own.

    Each violation is reported through ``logger.error_log_with_exp`` together
    with an ``MsprobeException(INVALID_PARAM_ERROR)`` (that helper is presumed
    to log the message and raise the exception; confirm against the logger
    implementation).
    """
    if self.data_mode is None:
        return

    if not isinstance(self.data_mode, list):
        logger.error_log_with_exp("data_mode is invalid, it should be a list[str]",
                                  MsprobeException(MsprobeException.INVALID_PARAM_ERROR))

    if Const.ALL in self.data_mode and len(self.data_mode) != 1:
        logger.error_log_with_exp(
            "'all' cannot be combined with other options in data_mode.",
            MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
        )

    if len(self.data_mode) >= len(Const.DUMP_DATA_MODE_LIST):
        # The message previously misspelled the option name as "data_made".
        logger.error_log_with_exp(
            f"The number of elements in the data_mode cannot exceed {len(Const.DUMP_DATA_MODE_LIST) - 1}.",
            MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
        )

    for mode in self.data_mode:
        if not isinstance(mode, str):
            logger.error_log_with_exp("data_mode is invalid, it should be a list[str]",
                                      MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
        if mode not in Const.DUMP_DATA_MODE_LIST:
            logger.error_log_with_exp(
                f"The element '{mode}' of data_mode {self.data_mode} is not in {Const.DUMP_DATA_MODE_LIST}.",
                MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
            )
|