mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -1,20 +1,35 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import copy
|
|
2
|
-
import csv
|
|
3
17
|
import glob
|
|
4
18
|
import os
|
|
19
|
+
import re
|
|
5
20
|
|
|
6
21
|
import numpy as np
|
|
7
22
|
import pandas as pd
|
|
8
|
-
from msprobe.core.common.const import CompareConst, GraphMode, Const
|
|
9
|
-
from msprobe.core.common.file_utils import
|
|
23
|
+
from msprobe.core.common.const import CompareConst, GraphMode, Const
|
|
24
|
+
from msprobe.core.common.file_utils import load_npy, read_csv, save_excel
|
|
10
25
|
from msprobe.core.common.log import logger
|
|
11
26
|
from msprobe.core.common.utils import add_time_with_xlsx, CompareException
|
|
12
27
|
from msprobe.core.compare.multiprocessing_compute import _ms_graph_handle_multi_process, check_accuracy
|
|
13
|
-
from msprobe.core.compare.npy_compare import npy_data_check, statistics_data_check,
|
|
28
|
+
from msprobe.core.compare.npy_compare import npy_data_check, statistics_data_check, compare_ops_apply
|
|
14
29
|
from msprobe.mindspore.common.utils import convert_to_int, list_lowest_level_directories
|
|
15
30
|
|
|
16
31
|
|
|
17
|
-
class
|
|
32
|
+
class RowData:
|
|
18
33
|
def __init__(self, mode):
|
|
19
34
|
self.basic_data = copy.deepcopy(CompareConst.MS_GRAPH_BASE)
|
|
20
35
|
self.npy_data = copy.deepcopy(CompareConst.MS_GRAPH_NPY)
|
|
@@ -28,17 +43,34 @@ class row_data:
|
|
|
28
43
|
return self.data
|
|
29
44
|
|
|
30
45
|
|
|
46
|
+
def get_name_dict(name: str) -> dict:
|
|
47
|
+
compare_pattern = re.compile(r'^([^.]+)\.([^.]+)\.([^.]+)\.([^.]+)\.(\d+(?:\.\d+)*)\.'
|
|
48
|
+
r'((?:in|out)put(?:\.\d+)*)\.([^.]+)\.([^.]+)\.npy$')
|
|
49
|
+
match = compare_pattern.match(name)
|
|
50
|
+
if match:
|
|
51
|
+
return {'op_type': match.group(1),
|
|
52
|
+
'op_name': match.group(2),
|
|
53
|
+
'task_id': match.group(3),
|
|
54
|
+
'stream_id': match.group(4),
|
|
55
|
+
'timestamp': match.group(5).split(Const.SEP)[0],
|
|
56
|
+
'input_output_index': match.group(6),
|
|
57
|
+
'slot': match.group(7),
|
|
58
|
+
'format': match.group(8)}
|
|
59
|
+
return {}
|
|
60
|
+
|
|
61
|
+
|
|
31
62
|
def npy_data_read(data_path, npy_file_list, mapping_dict):
|
|
32
63
|
data_list = []
|
|
64
|
+
compare_key_elements = ['op_name', 'task_id', 'input_output_index', 'slot']
|
|
33
65
|
for data in npy_file_list:
|
|
34
66
|
if data in mapping_dict:
|
|
35
|
-
|
|
67
|
+
name_dict = get_name_dict(mapping_dict[data])
|
|
36
68
|
else:
|
|
37
|
-
|
|
38
|
-
if
|
|
69
|
+
name_dict = get_name_dict(data)
|
|
70
|
+
if not name_dict:
|
|
39
71
|
continue
|
|
40
|
-
compare_key =
|
|
41
|
-
timestamp = convert_to_int(
|
|
72
|
+
compare_key = Const.SEP.join([name_dict.get(element) for element in compare_key_elements])
|
|
73
|
+
timestamp = convert_to_int(name_dict.get('timestamp'))
|
|
42
74
|
|
|
43
75
|
data_list.append([os.path.join(data_path, data), compare_key, timestamp])
|
|
44
76
|
return data_list
|
|
@@ -48,18 +80,17 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
|
|
|
48
80
|
data_list = []
|
|
49
81
|
statistic_data_list = []
|
|
50
82
|
header_index = {
|
|
51
|
-
'Data Type': None, 'Shape': None, 'Max Value': None,
|
|
52
|
-
'Min Value': None,'Avg Value': None, 'L2Norm Value': None
|
|
83
|
+
'Data Type': None, 'Shape': None, 'Max Value': None,
|
|
84
|
+
'Min Value': None, 'Avg Value': None, 'L2Norm Value': None
|
|
53
85
|
}
|
|
54
86
|
for statistic_file in statistic_file_list:
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
for
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
statistic_data_list.extend([row for row in csv_reader])
|
|
87
|
+
content = read_csv(statistic_file, as_pd=False)
|
|
88
|
+
header = content[0]
|
|
89
|
+
for key in header_index.keys():
|
|
90
|
+
for index, value in enumerate(header):
|
|
91
|
+
if key == value:
|
|
92
|
+
header_index[key] = index
|
|
93
|
+
statistic_data_list.extend(content[1:])
|
|
63
94
|
|
|
64
95
|
for key in header_index.keys():
|
|
65
96
|
if header_index[key] is None:
|
|
@@ -97,11 +128,9 @@ def generate_data_name(data_path):
|
|
|
97
128
|
mapping_dict = {}
|
|
98
129
|
if mapping_exist:
|
|
99
130
|
for mapping_file in mapping_file_list:
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
for row in csv_reader:
|
|
104
|
-
mapping_dict[row[0]] = row[1]
|
|
131
|
+
content = read_csv(mapping_file, False)
|
|
132
|
+
for row in content[1:]:
|
|
133
|
+
mapping_dict[row[0]] = row[1]
|
|
105
134
|
|
|
106
135
|
if npy_exist:
|
|
107
136
|
data_list = npy_data_read(data_path, npy_file_list, mapping_dict)
|
|
@@ -115,10 +144,16 @@ def generate_data_name(data_path):
|
|
|
115
144
|
mode = GraphMode.STATISTIC_MODE
|
|
116
145
|
else:
|
|
117
146
|
mode = GraphMode.ERROR_MODE
|
|
118
|
-
logger.error(
|
|
147
|
+
logger.error("Error mode.")
|
|
119
148
|
return mode, data_list
|
|
120
149
|
|
|
121
150
|
|
|
151
|
+
def transform_special_string_into_float(data_frame):
|
|
152
|
+
data_frame[data_frame == "null"] = '0'
|
|
153
|
+
data_frame[data_frame == "False"] = '0'
|
|
154
|
+
data_frame[data_frame == "True"] = '1'
|
|
155
|
+
|
|
156
|
+
|
|
122
157
|
class GraphMSComparator:
|
|
123
158
|
def __init__(self, input_param, output_path):
|
|
124
159
|
self.output_path = output_path
|
|
@@ -136,7 +171,7 @@ class GraphMSComparator:
|
|
|
136
171
|
def compare_ops(compare_result_db, mode):
|
|
137
172
|
|
|
138
173
|
def npy_mode_compute(row):
|
|
139
|
-
result_dict =
|
|
174
|
+
result_dict = RowData(GraphMode.NPY_MODE)()
|
|
140
175
|
|
|
141
176
|
def process_npy_file(file_path, name_prefix, result):
|
|
142
177
|
if os.path.exists(file_path):
|
|
@@ -158,7 +193,6 @@ class GraphMSComparator:
|
|
|
158
193
|
result_dict[CompareConst.ERROR_MESSAGE] = error_message
|
|
159
194
|
|
|
160
195
|
if not error_flag:
|
|
161
|
-
n_value, b_value = reshape_value(n_value, b_value)
|
|
162
196
|
result_list, err_msg = compare_ops_apply(n_value, b_value, False, "")
|
|
163
197
|
result_dict[CompareConst.COSINE] = result_list[0]
|
|
164
198
|
result_dict[CompareConst.MAX_ABS_ERR] = result_list[1]
|
|
@@ -171,7 +205,7 @@ class GraphMSComparator:
|
|
|
171
205
|
return pd.Series(result_dict)
|
|
172
206
|
|
|
173
207
|
def statistic_mode_compute(row):
|
|
174
|
-
result_dict =
|
|
208
|
+
result_dict = RowData('STATISTIC')()
|
|
175
209
|
|
|
176
210
|
def update_result_dict(result, rows, prefix):
|
|
177
211
|
result[f'{prefix} Name'] = rows[f'{prefix} Name']
|
|
@@ -198,24 +232,30 @@ class GraphMSComparator:
|
|
|
198
232
|
result_dict[CompareConst.NPU_NORM] - result_dict[CompareConst.BENCH_NORM])
|
|
199
233
|
result_dict[CompareConst.MAX_RELATIVE_ERR] = result_dict[CompareConst.MAX_DIFF] / result_dict[
|
|
200
234
|
CompareConst.BENCH_MAX] if result_dict[CompareConst.BENCH_MAX] > 0 else 0
|
|
201
|
-
|
|
235
|
+
if not np.isnan(result_dict[CompareConst.MAX_RELATIVE_ERR]):
|
|
236
|
+
result_dict[CompareConst.MAX_RELATIVE_ERR] = str(
|
|
237
|
+
result_dict[CompareConst.MAX_RELATIVE_ERR] * 100) + "%"
|
|
202
238
|
result_dict[CompareConst.MIN_RELATIVE_ERR] = result_dict[CompareConst.MIN_DIFF] / result_dict[
|
|
203
239
|
CompareConst.BENCH_MIN] if result_dict[CompareConst.BENCH_MIN] > 0 else 0
|
|
204
|
-
|
|
240
|
+
if not np.isnan(result_dict[CompareConst.MIN_RELATIVE_ERR]):
|
|
241
|
+
result_dict[CompareConst.MIN_RELATIVE_ERR] = \
|
|
242
|
+
str(result_dict[CompareConst.MIN_RELATIVE_ERR] * 100) + "%"
|
|
205
243
|
result_dict[CompareConst.MEAN_RELATIVE_ERR] = result_dict[CompareConst.MEAN_DIFF] / result_dict[
|
|
206
244
|
CompareConst.BENCH_MEAN] if result_dict[CompareConst.BENCH_MEAN] > 0 else 0
|
|
207
|
-
result_dict[CompareConst.MEAN_RELATIVE_ERR]
|
|
208
|
-
result_dict[CompareConst.MEAN_RELATIVE_ERR]
|
|
245
|
+
if not np.isnan(result_dict[CompareConst.MEAN_RELATIVE_ERR]):
|
|
246
|
+
result_dict[CompareConst.MEAN_RELATIVE_ERR] = str(
|
|
247
|
+
result_dict[CompareConst.MEAN_RELATIVE_ERR] * 100) + "%"
|
|
209
248
|
result_dict[CompareConst.NORM_RELATIVE_ERR] = result_dict[CompareConst.NORM_DIFF] / result_dict[
|
|
210
249
|
CompareConst.BENCH_NORM] if result_dict[CompareConst.BENCH_NORM] > 0 else 0
|
|
211
|
-
result_dict[CompareConst.NORM_RELATIVE_ERR]
|
|
212
|
-
result_dict[CompareConst.NORM_RELATIVE_ERR]
|
|
250
|
+
if not np.isnan(result_dict[CompareConst.NORM_RELATIVE_ERR]):
|
|
251
|
+
result_dict[CompareConst.NORM_RELATIVE_ERR] = str(
|
|
252
|
+
result_dict[CompareConst.NORM_RELATIVE_ERR] * 100) + "%"
|
|
213
253
|
magnitude_diff = result_dict[CompareConst.MAX_DIFF] / (
|
|
214
254
|
max(result_dict[CompareConst.NPU_MAX], result_dict[CompareConst.BENCH_MAX]) + 1e-10)
|
|
215
|
-
if
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
255
|
+
if np.isnan(result_dict[CompareConst.NPU_MAX]) and np.isnan(result_dict[CompareConst.BENCH_MAX]):
|
|
256
|
+
magnitude_diff = 0
|
|
257
|
+
result_dict[CompareConst.ACCURACY] = CompareConst.YES if \
|
|
258
|
+
magnitude_diff <= CompareConst.MAGNITUDE else CompareConst.NO
|
|
219
259
|
|
|
220
260
|
return pd.Series(result_dict)
|
|
221
261
|
|
|
@@ -238,24 +278,23 @@ class GraphMSComparator:
|
|
|
238
278
|
is_empty = True
|
|
239
279
|
if is_empty or not mode:
|
|
240
280
|
continue
|
|
241
|
-
compare_result_df = self.
|
|
281
|
+
compare_result_df = self.do_multi_process(compare_result_df, mode)
|
|
242
282
|
compare_result_name = add_time_with_xlsx(f"compare_result_{str(rank_id)}_{str(step_id)}")
|
|
243
283
|
compare_result_path = os.path.join(os.path.realpath(self.output_path), f"{compare_result_name}")
|
|
244
|
-
check_path_before_create(compare_result_path)
|
|
245
284
|
self.to_excel(compare_result_df, compare_result_path)
|
|
246
285
|
logger.info(f"Compare rank: {rank_id} step: {step_id} finish. Compare result: {compare_result_path}.")
|
|
247
|
-
|
|
286
|
+
|
|
248
287
|
def to_excel(self, compare_result_df: pd.DataFrame, compare_result_path: str, slice_num=0, need_slice=False) -> int:
|
|
249
288
|
size = len(compare_result_df)
|
|
250
289
|
# sheet size cannot be larger than 1048576
|
|
251
290
|
if size < CompareConst.MAX_EXCEL_LENGTH:
|
|
252
|
-
compare_result_path = compare_result_path.replace('.xlsx', f'_slice_{slice_num}.xlsx') if
|
|
253
|
-
|
|
254
|
-
|
|
291
|
+
compare_result_path = compare_result_path.replace('.xlsx', f'_slice_{slice_num}.xlsx') if \
|
|
292
|
+
need_slice else compare_result_path
|
|
293
|
+
save_excel(compare_result_path, compare_result_df)
|
|
255
294
|
return slice_num + 1
|
|
256
295
|
else:
|
|
257
|
-
slice_num = self.to_excel(compare_result_df.iloc[0: size//2], compare_result_path, slice_num, True)
|
|
258
|
-
return self.to_excel(compare_result_df.iloc[size//2:], compare_result_path, slice_num, True)
|
|
296
|
+
slice_num = self.to_excel(compare_result_df.iloc[0: size // 2], compare_result_path, slice_num, True)
|
|
297
|
+
return self.to_excel(compare_result_df.iloc[size // 2:], compare_result_path, slice_num, True)
|
|
259
298
|
|
|
260
299
|
def compare_process(self, rank_id, step_id):
|
|
261
300
|
# generate data_path
|
|
@@ -300,13 +339,17 @@ class GraphMSComparator:
|
|
|
300
339
|
CompareConst.BENCH_NORM])
|
|
301
340
|
|
|
302
341
|
npu_float_type = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
|
|
303
|
-
|
|
342
|
+
npu_float_data_df = npu_data_df[npu_float_type].astype(str)
|
|
343
|
+
transform_special_string_into_float(npu_float_data_df)
|
|
344
|
+
npu_data_df[npu_float_type] = npu_float_data_df.astype(float)
|
|
304
345
|
|
|
305
346
|
bench_float_type = [
|
|
306
|
-
CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
|
|
307
|
-
CompareConst.BENCH_MEAN,CompareConst.BENCH_NORM
|
|
347
|
+
CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
|
|
348
|
+
CompareConst.BENCH_MEAN, CompareConst.BENCH_NORM
|
|
308
349
|
]
|
|
309
|
-
|
|
350
|
+
bench_float_data_df = bench_data_df[bench_float_type].astype(str)
|
|
351
|
+
transform_special_string_into_float(bench_float_data_df)
|
|
352
|
+
bench_data_df[bench_float_type] = bench_float_data_df.astype(float)
|
|
310
353
|
|
|
311
354
|
npu_data_df['Local Index'] = npu_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()
|
|
312
355
|
bench_data_df['Local Index'] = bench_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()
|
|
@@ -355,7 +398,7 @@ class GraphMSComparator:
|
|
|
355
398
|
rank_step_path_dict[rank_step_key] = [dir_path]
|
|
356
399
|
return dict(sorted(rank_step_path_dict.items()))
|
|
357
400
|
|
|
358
|
-
def
|
|
401
|
+
def do_multi_process(self, result_df, mode):
|
|
359
402
|
try:
|
|
360
403
|
result_df = _ms_graph_handle_multi_process(self.compare_ops, result_df, mode)
|
|
361
404
|
except ValueError as e:
|
|
@@ -33,12 +33,13 @@ class DebuggerConfig:
|
|
|
33
33
|
self.level_ori = common_config.level
|
|
34
34
|
self.list = [] if not task_config.list else task_config.list
|
|
35
35
|
self.scope = [] if not task_config.scope else task_config.scope
|
|
36
|
-
self.data_mode = [] if not task_config.data_mode else task_config.data_mode
|
|
36
|
+
self.data_mode = [Const.ALL] if not task_config.data_mode else task_config.data_mode
|
|
37
37
|
self.file_format = task_config.file_format
|
|
38
38
|
self.overflow_nums = 1 if not task_config.overflow_nums else task_config.overflow_nums
|
|
39
39
|
self.check_mode = task_config.check_mode
|
|
40
40
|
self.framework = Const.MS_FRAMEWORK
|
|
41
41
|
self.summary_mode = task_config.summary_mode
|
|
42
|
+
self.async_dump = common_config.async_dump if common_config.async_dump else False
|
|
42
43
|
self.check()
|
|
43
44
|
create_directory(self.dump_path)
|
|
44
45
|
|
|
@@ -52,6 +53,9 @@ class DebuggerConfig:
|
|
|
52
53
|
self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE:
|
|
53
54
|
raise ValueError("pert_mode must be improve_precision or empty when handler_type is fix, "
|
|
54
55
|
f"but got {self.pert_type}.")
|
|
56
|
+
if self.stage == Const.BACKWARD and self.handler_type == FreeBenchmarkConst.FIX:
|
|
57
|
+
raise ValueError("handler_type must be check or empty when fuzz_stage is backward, "
|
|
58
|
+
f"but got {self.handler_type}.")
|
|
55
59
|
self.dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
|
|
56
60
|
|
|
57
61
|
def check(self):
|
|
@@ -66,4 +70,6 @@ class DebuggerConfig:
|
|
|
66
70
|
self.file_format = "npy"
|
|
67
71
|
if not self.check_mode:
|
|
68
72
|
self.check_mode = "all"
|
|
73
|
+
if not isinstance(self.async_dump, bool):
|
|
74
|
+
raise Exception("The parameters async_dump should be bool.")
|
|
69
75
|
return True
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
5
|
# you may not use this file except in compliance with the License.
|
|
6
6
|
# You may obtain a copy of the License at
|
|
7
7
|
#
|
|
@@ -14,25 +14,42 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
+
from collections import defaultdict, namedtuple
|
|
17
18
|
|
|
18
19
|
import mindspore as ms
|
|
19
20
|
from mindspore._c_expression import MSContext
|
|
20
21
|
|
|
21
|
-
from msprobe.core.common.const import Const, MsgConst
|
|
22
|
+
from msprobe.core.common.const import Const, FileCheckConst, MsgConst
|
|
23
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
24
|
+
from msprobe.core.common.file_utils import FileChecker
|
|
25
|
+
from msprobe.core.common.utils import get_real_step_or_rank
|
|
26
|
+
from msprobe.mindspore.cell_processor import CellProcessor
|
|
22
27
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
28
|
+
from msprobe.mindspore.common.utils import set_register_backward_hook_functions
|
|
23
29
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
30
|
+
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
31
|
+
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
24
32
|
from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor
|
|
25
33
|
from msprobe.mindspore.ms_config import parse_json_config
|
|
26
34
|
from msprobe.mindspore.runtime import Runtime
|
|
27
35
|
from msprobe.mindspore.service import Service
|
|
28
36
|
from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
|
|
29
37
|
|
|
38
|
+
try:
|
|
39
|
+
from msprobe.lib import _msprobe_c
|
|
40
|
+
except ImportError:
|
|
41
|
+
_msprobe_c = None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Immutable bundle of the user-supplied constructor arguments that are
# validated together before the JSON configuration is parsed.
ConfigParameters = namedtuple(
    "ConfigParameters",
    ("config_path", "task", "dump_path", "level"),
)
|
|
45
|
+
|
|
30
46
|
|
|
31
47
|
class PrecisionDebugger:
|
|
32
48
|
_instance = None
|
|
33
49
|
task_not_need_service = [Const.GRAD_PROBE]
|
|
34
50
|
|
|
35
|
-
def __new__(cls, config_path=None,
|
|
51
|
+
def __new__(cls, config_path=None, task=None, dump_path=None,
|
|
52
|
+
level=None, step=None, opt=None):
|
|
36
53
|
if not cls._instance:
|
|
37
54
|
cls._instance = super().__new__(cls)
|
|
38
55
|
cls._instance.initialized = False
|
|
@@ -41,22 +58,65 @@ class PrecisionDebugger:
|
|
|
41
58
|
cls.first_start = False
|
|
42
59
|
return cls._instance
|
|
43
60
|
|
|
44
|
-
def __init__(self, config_path=None
|
|
61
|
+
def __init__(self, config_path=None, task=None, dump_path=None,
             level=None, step=None, opt=None):
    """Configure the singleton debugger the first time it is constructed.

    Python forwards the constructor arguments to both ``__new__`` and
    ``__init__``, so this signature mirrors ``__new__`` (including ``opt``);
    otherwise ``PrecisionDebugger(opt=...)`` would fail with ``TypeError``.

    Args:
        config_path: Path of the JSON configuration file. When falsy, the
            ``config.json`` located two directories above this module is used.
        task: Overrides the task read from the configuration when truthy.
        dump_path: Overrides the configured dump path when truthy.
        level: Overrides the configured level when truthy.
        step: When not ``None``, it is normalised with
            ``get_real_step_or_rank`` and overrides the configured step.
        opt: Accepted only for signature parity with ``__new__``; it is not
            used by this method.

    Raises:
        MsprobeException: Propagated from ``check_input_params`` when an
            argument is invalid.
    """
    # Singleton: every construction after the first one is a no-op.
    if self.initialized:
        return
    self.initialized = True

    set_register_backward_hook_functions()

    if not config_path:
        config_path = os.path.join(os.path.dirname(__file__), "../../config.json")

    # Validate the raw arguments before the configuration file is parsed.
    config_params = ConfigParameters(config_path, task, dump_path, level)
    self.check_input_params(config_params)

    common_config, task_config = parse_json_config(config_path)
    common_config.task = task if task else common_config.task
    self.task = common_config.task
    if self.task == Const.GRAD_PROBE:
        # The gradient probe task only needs the monitor; the remaining
        # setup below is skipped for it.
        self.gm = GradientMonitor(common_config, task_config)
        return

    # Explicit constructor arguments take precedence over the file values.
    common_config.step = get_real_step_or_rank(
        step, Const.STEP) if step is not None else common_config.step
    common_config.level = level if level else common_config.level
    common_config.dump_path = dump_path if dump_path else common_config.dump_path
    self.config = DebuggerConfig(common_config, task_config)

    # _msprobe_c is None when the optional compiled extension is unavailable.
    if _msprobe_c:
        _msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path)

    self.config.execution_mode = self._get_execution_mode()
    if self._need_service():
        self.service = Service(self.config)

    Runtime.step_count = 0
    Runtime.is_running = False
|
|
59
96
|
|
|
97
|
+
@staticmethod
def check_input_params(args):
    """Validate the arguments passed to the debugger constructor.

    Each field is checked only when it is not ``None``.

    Args:
        args: Object exposing ``config_path``, ``task``, ``dump_path`` and
            ``level`` attributes (for example a ``ConfigParameters`` tuple).

    Raises:
        MsprobeException: With ``INVALID_PARAM_ERROR`` when ``config_path``
            or ``dump_path`` is not a string, when ``task`` is not in
            ``Const.TASK_LIST``, or when ``level`` is not in
            ``Const.LEVEL_LIST``. Whatever ``FileChecker.common_check``
            raises for ``config_path`` is propagated unchanged.
    """
    if args.config_path is not None:
        if not isinstance(args.config_path, str):
            raise MsprobeException(
                MsprobeException.INVALID_PARAM_ERROR, "config_path must be a string")
        # The type is right; delegate the path checks to FileChecker.
        file_checker = FileChecker(
            file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
        file_checker.common_check()

    if args.task is not None and args.task not in Const.TASK_LIST:
        raise MsprobeException(
            MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")

    if args.dump_path is not None and not isinstance(args.dump_path, str):
        raise MsprobeException(
            MsprobeException.INVALID_PARAM_ERROR, "dump_path must be a string")

    if args.level is not None and args.level not in Const.LEVEL_LIST:
        raise MsprobeException(
            MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
|
|
119
|
+
|
|
60
120
|
@staticmethod
|
|
61
121
|
def _get_execution_mode():
|
|
62
122
|
jit_level = ms.context.get_jit_config().get(MsConst.JIT_LEVEL)
|
|
@@ -75,11 +135,23 @@ class PrecisionDebugger:
|
|
|
75
135
|
else:
|
|
76
136
|
return MsConst.PYNATIVE_MODE
|
|
77
137
|
|
|
138
|
+
@staticmethod
def _is_graph_dump(config):
    """Report whether ``config`` is treated as a graph dump.

    Only kernel-level configurations can qualify. Among those, the result is
    ``True`` when ``config.list`` is empty, holds more than one entry, or
    its single entry contains ``'-'`` or ``'/'``.

    Args:
        config: Object exposing ``level`` and ``list`` attributes.

    Returns:
        bool: ``True`` for a graph dump, ``False`` otherwise.
    """
    if config.level != MsConst.KERNEL:
        return False

    targets = config.list
    if not targets or len(targets) > 1:
        return True

    # Exactly one entry: the decision depends on the characters it contains.
    only_target = targets[0]
    return '-' in only_target or '/' in only_target
|
|
147
|
+
|
|
78
148
|
@classmethod
|
|
79
149
|
def start(cls, model=None):
|
|
80
150
|
instance = cls._instance
|
|
81
151
|
if not instance:
|
|
82
152
|
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
153
|
+
if _msprobe_c:
|
|
154
|
+
_msprobe_c._PrecisionDebugger().start()
|
|
83
155
|
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
84
156
|
return
|
|
85
157
|
|
|
@@ -90,6 +162,7 @@ class PrecisionDebugger:
|
|
|
90
162
|
instance.service.start(model)
|
|
91
163
|
else:
|
|
92
164
|
if not instance.first_start:
|
|
165
|
+
api_register.api_set_ori_func()
|
|
93
166
|
handler = TaskHandlerFactory.create(instance.config)
|
|
94
167
|
handler.handle()
|
|
95
168
|
|
|
@@ -99,18 +172,15 @@ class PrecisionDebugger:
|
|
|
99
172
|
@classmethod
def forward_backward_dump_end(cls):
    """Finish a forward/backward dump by stopping the debugger.

    This delegates to ``cls.stop()`` instead of calling ``stop`` on
    ``cls._instance``. ``stop`` itself checks whether the singleton exists,
    so calling this before the debugger is created raises the
    "instance not created" exception rather than an ``AttributeError`` on
    ``None``.

    Raises:
        Exception: Propagated from ``stop`` when no instance has been created.
    """
    cls.stop()
|
|
108
176
|
|
|
109
177
|
@classmethod
|
|
110
178
|
def stop(cls):
|
|
111
179
|
instance = cls._instance
|
|
112
180
|
if not instance:
|
|
113
181
|
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
182
|
+
if _msprobe_c:
|
|
183
|
+
_msprobe_c._PrecisionDebugger().stop()
|
|
114
184
|
if instance.task == Const.GRAD_PROBE:
|
|
115
185
|
instance.gm.stop()
|
|
116
186
|
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
@@ -124,10 +194,15 @@ class PrecisionDebugger:
|
|
|
124
194
|
instance = cls._instance
|
|
125
195
|
if not instance:
|
|
126
196
|
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
197
|
+
if _msprobe_c:
|
|
198
|
+
_msprobe_c._PrecisionDebugger().step()
|
|
127
199
|
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
128
200
|
return
|
|
129
201
|
if instance.service:
|
|
130
202
|
instance.service.step()
|
|
203
|
+
HOOKCell.cell_count = defaultdict(int)
|
|
204
|
+
CellProcessor.reset_cell_stats()
|
|
205
|
+
|
|
131
206
|
Runtime.step_count += 1
|
|
132
207
|
|
|
133
208
|
@classmethod
|
|
@@ -147,4 +222,4 @@ class PrecisionDebugger:
|
|
|
147
222
|
if instance.config.execution_mode != MsConst.PYNATIVE_MODE:
|
|
148
223
|
return False
|
|
149
224
|
else:
|
|
150
|
-
return instance.config.task != Const.FREE_BENCHMARK and instance.config
|
|
225
|
+
return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
5
|
# you may not use this file except in compliance with the License.
|
|
6
6
|
# You may obtain a copy of the License at
|
|
7
7
|
#
|
|
@@ -40,6 +40,8 @@ class DumpToolFactory:
|
|
|
40
40
|
|
|
41
41
|
@staticmethod
|
|
42
42
|
def create(config: DebuggerConfig):
|
|
43
|
+
if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST:
|
|
44
|
+
raise Exception("data_mode must be one of all, input, output.")
|
|
43
45
|
tool = DumpToolFactory.tools.get(config.level)
|
|
44
46
|
if not tool:
|
|
45
47
|
raise Exception("Valid level is needed.")
|