mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -1,12 +1,27 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import copy
|
|
2
|
-
import csv
|
|
3
17
|
import glob
|
|
4
18
|
import os
|
|
19
|
+
import re
|
|
5
20
|
|
|
6
21
|
import numpy as np
|
|
7
22
|
import pandas as pd
|
|
8
|
-
from msprobe.core.common.const import CompareConst, GraphMode, Const
|
|
9
|
-
from msprobe.core.common.file_utils import
|
|
23
|
+
from msprobe.core.common.const import CompareConst, GraphMode, Const
|
|
24
|
+
from msprobe.core.common.file_utils import load_npy, read_csv, save_excel
|
|
10
25
|
from msprobe.core.common.log import logger
|
|
11
26
|
from msprobe.core.common.utils import add_time_with_xlsx, CompareException
|
|
12
27
|
from msprobe.core.compare.multiprocessing_compute import _ms_graph_handle_multi_process, check_accuracy
|
|
@@ -14,7 +29,7 @@ from msprobe.core.compare.npy_compare import npy_data_check, statistics_data_che
|
|
|
14
29
|
from msprobe.mindspore.common.utils import convert_to_int, list_lowest_level_directories
|
|
15
30
|
|
|
16
31
|
|
|
17
|
-
class
|
|
32
|
+
class RowData:
|
|
18
33
|
def __init__(self, mode):
|
|
19
34
|
self.basic_data = copy.deepcopy(CompareConst.MS_GRAPH_BASE)
|
|
20
35
|
self.npy_data = copy.deepcopy(CompareConst.MS_GRAPH_NPY)
|
|
@@ -28,17 +43,34 @@ class row_data:
|
|
|
28
43
|
return self.data
|
|
29
44
|
|
|
30
45
|
|
|
46
|
+
def get_name_dict(name: str) -> dict:
|
|
47
|
+
compare_pattern = re.compile(r'^([^.]+)\.([^.]+)\.([^.]+)\.([^.]+)\.(\d+(?:\.\d+)*)\.'
|
|
48
|
+
r'((?:in|out)put(?:\.\d+)*)\.([^.]+)\.([^.]+)\.npy$')
|
|
49
|
+
match = compare_pattern.match(name)
|
|
50
|
+
if match:
|
|
51
|
+
return {'op_type': match.group(1),
|
|
52
|
+
'op_name': match.group(2),
|
|
53
|
+
'task_id': match.group(3),
|
|
54
|
+
'stream_id': match.group(4),
|
|
55
|
+
'timestamp': match.group(5).split(Const.SEP)[0],
|
|
56
|
+
'input_output_index': match.group(6),
|
|
57
|
+
'slot': match.group(7),
|
|
58
|
+
'format': match.group(8)}
|
|
59
|
+
return {}
|
|
60
|
+
|
|
61
|
+
|
|
31
62
|
def npy_data_read(data_path, npy_file_list, mapping_dict):
|
|
32
63
|
data_list = []
|
|
64
|
+
compare_key_elements = ['op_name', 'task_id', 'input_output_index', 'slot']
|
|
33
65
|
for data in npy_file_list:
|
|
34
66
|
if data in mapping_dict:
|
|
35
|
-
|
|
67
|
+
name_dict = get_name_dict(mapping_dict[data])
|
|
36
68
|
else:
|
|
37
|
-
|
|
38
|
-
if
|
|
69
|
+
name_dict = get_name_dict(data)
|
|
70
|
+
if not name_dict:
|
|
39
71
|
continue
|
|
40
|
-
compare_key =
|
|
41
|
-
timestamp = convert_to_int(
|
|
72
|
+
compare_key = Const.SEP.join([name_dict.get(element) for element in compare_key_elements])
|
|
73
|
+
timestamp = convert_to_int(name_dict.get('timestamp'))
|
|
42
74
|
|
|
43
75
|
data_list.append([os.path.join(data_path, data), compare_key, timestamp])
|
|
44
76
|
return data_list
|
|
@@ -47,17 +79,18 @@ def npy_data_read(data_path, npy_file_list, mapping_dict):
|
|
|
47
79
|
def statistic_data_read(statistic_file_list, statistic_file_path):
|
|
48
80
|
data_list = []
|
|
49
81
|
statistic_data_list = []
|
|
50
|
-
header_index = {
|
|
51
|
-
|
|
82
|
+
header_index = {
|
|
83
|
+
'Data Type': None, 'Shape': None, 'Max Value': None,
|
|
84
|
+
'Min Value': None, 'Avg Value': None, 'L2Norm Value': None
|
|
85
|
+
}
|
|
52
86
|
for statistic_file in statistic_file_list:
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
for
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
statistic_data_list.extend([row for row in csv_reader])
|
|
87
|
+
content = read_csv(statistic_file, as_pd=False)
|
|
88
|
+
header = content[0]
|
|
89
|
+
for key in header_index.keys():
|
|
90
|
+
for index, value in enumerate(header):
|
|
91
|
+
if key == value:
|
|
92
|
+
header_index[key] = index
|
|
93
|
+
statistic_data_list.extend(content[1:])
|
|
61
94
|
|
|
62
95
|
for key in header_index.keys():
|
|
63
96
|
if header_index[key] is None:
|
|
@@ -65,8 +98,9 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
|
|
|
65
98
|
|
|
66
99
|
for data in statistic_data_list:
|
|
67
100
|
compare_key = f"{data[1]}.{data[2]}.{data[3]}.{data[5]}"
|
|
101
|
+
op_name = f"{compare_key} {statistic_file_path}"
|
|
68
102
|
timestamp = int(data[4])
|
|
69
|
-
result_data = [
|
|
103
|
+
result_data = [op_name, compare_key, timestamp]
|
|
70
104
|
for key in header_index.keys():
|
|
71
105
|
if header_index[key] is None:
|
|
72
106
|
result_data.append(np.nan)
|
|
@@ -94,11 +128,9 @@ def generate_data_name(data_path):
|
|
|
94
128
|
mapping_dict = {}
|
|
95
129
|
if mapping_exist:
|
|
96
130
|
for mapping_file in mapping_file_list:
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
for row in csv_reader:
|
|
101
|
-
mapping_dict[row[0]] = row[1]
|
|
131
|
+
content = read_csv(mapping_file, False)
|
|
132
|
+
for row in content[1:]:
|
|
133
|
+
mapping_dict[row[0]] = row[1]
|
|
102
134
|
|
|
103
135
|
if npy_exist:
|
|
104
136
|
data_list = npy_data_read(data_path, npy_file_list, mapping_dict)
|
|
@@ -133,7 +165,7 @@ class GraphMSComparator:
|
|
|
133
165
|
def compare_ops(compare_result_db, mode):
|
|
134
166
|
|
|
135
167
|
def npy_mode_compute(row):
|
|
136
|
-
result_dict =
|
|
168
|
+
result_dict = RowData(GraphMode.NPY_MODE)()
|
|
137
169
|
|
|
138
170
|
def process_npy_file(file_path, name_prefix, result):
|
|
139
171
|
if os.path.exists(file_path):
|
|
@@ -168,7 +200,7 @@ class GraphMSComparator:
|
|
|
168
200
|
return pd.Series(result_dict)
|
|
169
201
|
|
|
170
202
|
def statistic_mode_compute(row):
|
|
171
|
-
result_dict =
|
|
203
|
+
result_dict = RowData('STATISTIC')()
|
|
172
204
|
|
|
173
205
|
def update_result_dict(result, rows, prefix):
|
|
174
206
|
result[f'{prefix} Name'] = rows[f'{prefix} Name']
|
|
@@ -195,24 +227,30 @@ class GraphMSComparator:
|
|
|
195
227
|
result_dict[CompareConst.NPU_NORM] - result_dict[CompareConst.BENCH_NORM])
|
|
196
228
|
result_dict[CompareConst.MAX_RELATIVE_ERR] = result_dict[CompareConst.MAX_DIFF] / result_dict[
|
|
197
229
|
CompareConst.BENCH_MAX] if result_dict[CompareConst.BENCH_MAX] > 0 else 0
|
|
198
|
-
|
|
230
|
+
if not np.isnan(result_dict[CompareConst.MAX_RELATIVE_ERR]):
|
|
231
|
+
result_dict[CompareConst.MAX_RELATIVE_ERR] = str(
|
|
232
|
+
result_dict[CompareConst.MAX_RELATIVE_ERR] * 100) + "%"
|
|
199
233
|
result_dict[CompareConst.MIN_RELATIVE_ERR] = result_dict[CompareConst.MIN_DIFF] / result_dict[
|
|
200
234
|
CompareConst.BENCH_MIN] if result_dict[CompareConst.BENCH_MIN] > 0 else 0
|
|
201
|
-
|
|
235
|
+
if not np.isnan(result_dict[CompareConst.MIN_RELATIVE_ERR]):
|
|
236
|
+
result_dict[CompareConst.MIN_RELATIVE_ERR] = \
|
|
237
|
+
str(result_dict[CompareConst.MIN_RELATIVE_ERR] * 100) + "%"
|
|
202
238
|
result_dict[CompareConst.MEAN_RELATIVE_ERR] = result_dict[CompareConst.MEAN_DIFF] / result_dict[
|
|
203
239
|
CompareConst.BENCH_MEAN] if result_dict[CompareConst.BENCH_MEAN] > 0 else 0
|
|
204
|
-
result_dict[CompareConst.MEAN_RELATIVE_ERR]
|
|
205
|
-
result_dict[CompareConst.MEAN_RELATIVE_ERR]
|
|
240
|
+
if not np.isnan(result_dict[CompareConst.MEAN_RELATIVE_ERR]):
|
|
241
|
+
result_dict[CompareConst.MEAN_RELATIVE_ERR] = str(
|
|
242
|
+
result_dict[CompareConst.MEAN_RELATIVE_ERR] * 100) + "%"
|
|
206
243
|
result_dict[CompareConst.NORM_RELATIVE_ERR] = result_dict[CompareConst.NORM_DIFF] / result_dict[
|
|
207
244
|
CompareConst.BENCH_NORM] if result_dict[CompareConst.BENCH_NORM] > 0 else 0
|
|
208
|
-
result_dict[CompareConst.NORM_RELATIVE_ERR]
|
|
209
|
-
result_dict[CompareConst.NORM_RELATIVE_ERR]
|
|
245
|
+
if not np.isnan(result_dict[CompareConst.NORM_RELATIVE_ERR]):
|
|
246
|
+
result_dict[CompareConst.NORM_RELATIVE_ERR] = str(
|
|
247
|
+
result_dict[CompareConst.NORM_RELATIVE_ERR] * 100) + "%"
|
|
210
248
|
magnitude_diff = result_dict[CompareConst.MAX_DIFF] / (
|
|
211
249
|
max(result_dict[CompareConst.NPU_MAX], result_dict[CompareConst.BENCH_MAX]) + 1e-10)
|
|
212
|
-
if
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
250
|
+
if np.isnan(result_dict[CompareConst.NPU_MAX]) and np.isnan(result_dict[CompareConst.BENCH_MAX]):
|
|
251
|
+
magnitude_diff = 0
|
|
252
|
+
result_dict[CompareConst.ACCURACY] = CompareConst.YES if \
|
|
253
|
+
magnitude_diff <= CompareConst.MAGNITUDE else CompareConst.NO
|
|
216
254
|
|
|
217
255
|
return pd.Series(result_dict)
|
|
218
256
|
|
|
@@ -235,14 +273,24 @@ class GraphMSComparator:
|
|
|
235
273
|
is_empty = True
|
|
236
274
|
if is_empty or not mode:
|
|
237
275
|
continue
|
|
238
|
-
compare_result_df = self.
|
|
276
|
+
compare_result_df = self.do_multi_process(compare_result_df, mode)
|
|
239
277
|
compare_result_name = add_time_with_xlsx(f"compare_result_{str(rank_id)}_{str(step_id)}")
|
|
240
278
|
compare_result_path = os.path.join(os.path.realpath(self.output_path), f"{compare_result_name}")
|
|
241
|
-
|
|
242
|
-
compare_result_df.to_excel(compare_result_path, index=False)
|
|
243
|
-
change_mode(compare_result_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
279
|
+
self.to_excel(compare_result_df, compare_result_path)
|
|
244
280
|
logger.info(f"Compare rank: {rank_id} step: {step_id} finish. Compare result: {compare_result_path}.")
|
|
245
281
|
|
|
282
|
+
def to_excel(self, compare_result_df: pd.DataFrame, compare_result_path: str, slice_num=0, need_slice=False) -> int:
|
|
283
|
+
size = len(compare_result_df)
|
|
284
|
+
# sheet size cannot be larger than 1048576
|
|
285
|
+
if size < CompareConst.MAX_EXCEL_LENGTH:
|
|
286
|
+
compare_result_path = compare_result_path.replace('.xlsx', f'_slice_{slice_num}.xlsx') if \
|
|
287
|
+
need_slice else compare_result_path
|
|
288
|
+
save_excel(compare_result_path, compare_result_df)
|
|
289
|
+
return slice_num + 1
|
|
290
|
+
else:
|
|
291
|
+
slice_num = self.to_excel(compare_result_df.iloc[0: size // 2], compare_result_path, slice_num, True)
|
|
292
|
+
return self.to_excel(compare_result_df.iloc[size // 2:], compare_result_path, slice_num, True)
|
|
293
|
+
|
|
246
294
|
def compare_process(self, rank_id, step_id):
|
|
247
295
|
# generate data_path
|
|
248
296
|
npu_data_path_list = self.npu_rank_step_dict.get((rank_id, step_id))
|
|
@@ -251,8 +299,8 @@ class GraphMSComparator:
|
|
|
251
299
|
return [], ''
|
|
252
300
|
|
|
253
301
|
# generate file name
|
|
254
|
-
npu_mode =
|
|
255
|
-
bench_mode =
|
|
302
|
+
npu_mode = GraphMode.ERROR_MODE
|
|
303
|
+
bench_mode = GraphMode.ERROR_MODE
|
|
256
304
|
npu_data_list = []
|
|
257
305
|
bench_data_list = []
|
|
258
306
|
for npu_data_path in npu_data_path_list:
|
|
@@ -262,7 +310,7 @@ class GraphMSComparator:
|
|
|
262
310
|
bench_mode, data_list = generate_data_name(bench_data_path)
|
|
263
311
|
bench_data_list.extend(data_list)
|
|
264
312
|
|
|
265
|
-
if npu_mode ==
|
|
313
|
+
if npu_mode == GraphMode.ERROR_MODE or bench_mode == GraphMode.ERROR_MODE:
|
|
266
314
|
logger.warning(f"Data_path {npu_data_path} or {bench_data_path} is not exist.")
|
|
267
315
|
return [], ''
|
|
268
316
|
if npu_mode != bench_mode:
|
|
@@ -286,11 +334,13 @@ class GraphMSComparator:
|
|
|
286
334
|
CompareConst.BENCH_NORM])
|
|
287
335
|
|
|
288
336
|
npu_float_type = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
|
|
289
|
-
npu_data_df[npu_float_type] = npu_data_df[npu_float_type].astype(
|
|
337
|
+
npu_data_df[npu_float_type] = npu_data_df[npu_float_type].astype(float)
|
|
290
338
|
|
|
291
|
-
bench_float_type = [
|
|
292
|
-
|
|
293
|
-
|
|
339
|
+
bench_float_type = [
|
|
340
|
+
CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
|
|
341
|
+
CompareConst.BENCH_MEAN, CompareConst.BENCH_NORM
|
|
342
|
+
]
|
|
343
|
+
bench_data_df[bench_float_type] = bench_data_df[bench_float_type].astype(float)
|
|
294
344
|
|
|
295
345
|
npu_data_df['Local Index'] = npu_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()
|
|
296
346
|
bench_data_df['Local Index'] = bench_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()
|
|
@@ -339,7 +389,7 @@ class GraphMSComparator:
|
|
|
339
389
|
rank_step_path_dict[rank_step_key] = [dir_path]
|
|
340
390
|
return dict(sorted(rank_step_path_dict.items()))
|
|
341
391
|
|
|
342
|
-
def
|
|
392
|
+
def do_multi_process(self, result_df, mode):
|
|
343
393
|
try:
|
|
344
394
|
result_df = _ms_graph_handle_multi_process(self.compare_ops, result_df, mode)
|
|
345
395
|
except ValueError as e:
|
|
@@ -1,9 +1,24 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
|
|
3
18
|
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.file_utils import create_directory
|
|
4
20
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
5
21
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
6
|
-
from msprobe.core.common.file_utils import create_directory
|
|
7
22
|
|
|
8
23
|
|
|
9
24
|
class DebuggerConfig:
|
|
@@ -18,7 +33,7 @@ class DebuggerConfig:
|
|
|
18
33
|
self.level_ori = common_config.level
|
|
19
34
|
self.list = [] if not task_config.list else task_config.list
|
|
20
35
|
self.scope = [] if not task_config.scope else task_config.scope
|
|
21
|
-
self.data_mode = [] if not task_config.data_mode else task_config.data_mode
|
|
36
|
+
self.data_mode = [Const.ALL] if not task_config.data_mode else task_config.data_mode
|
|
22
37
|
self.file_format = task_config.file_format
|
|
23
38
|
self.overflow_nums = 1 if not task_config.overflow_nums else task_config.overflow_nums
|
|
24
39
|
self.check_mode = task_config.check_mode
|
|
@@ -37,6 +52,9 @@ class DebuggerConfig:
|
|
|
37
52
|
self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE:
|
|
38
53
|
raise ValueError("pert_mode must be improve_precision or empty when handler_type is fix, "
|
|
39
54
|
f"but got {self.pert_type}.")
|
|
55
|
+
if self.stage == Const.BACKWARD and self.handler_type == FreeBenchmarkConst.FIX:
|
|
56
|
+
raise ValueError("handler_type must be check or empty when fuzz_stage is backward, "
|
|
57
|
+
f"but got {self.handler_type}.")
|
|
40
58
|
self.dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
|
|
41
59
|
|
|
42
60
|
def check(self):
|
|
@@ -51,16 +69,4 @@ class DebuggerConfig:
|
|
|
51
69
|
self.file_format = "npy"
|
|
52
70
|
if not self.check_mode:
|
|
53
71
|
self.check_mode = "all"
|
|
54
|
-
self._check_rank()
|
|
55
|
-
self._check_step()
|
|
56
72
|
return True
|
|
57
|
-
|
|
58
|
-
def _check_rank(self):
|
|
59
|
-
for rank_id in self.rank:
|
|
60
|
-
if not isinstance(rank_id, int) or rank_id < 0:
|
|
61
|
-
raise ValueError(f"rank {self.rank} must be a positive integer.")
|
|
62
|
-
|
|
63
|
-
def _check_step(self):
|
|
64
|
-
for s in self.step:
|
|
65
|
-
if not isinstance(s, int) or s < 0:
|
|
66
|
-
raise ValueError(f"step element {s} must be a positive integer.")
|
|
@@ -1,17 +1,34 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
17
|
+
from collections import defaultdict
|
|
2
18
|
|
|
3
19
|
import mindspore as ms
|
|
4
20
|
from mindspore._c_expression import MSContext
|
|
5
21
|
|
|
6
|
-
from msprobe.
|
|
7
|
-
from msprobe.mindspore.
|
|
8
|
-
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
9
|
-
from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
|
|
10
|
-
from msprobe.core.common.const import Const
|
|
22
|
+
from msprobe.core.common.const import Const, MsgConst
|
|
23
|
+
from msprobe.mindspore.cell_processor import CellProcessor
|
|
11
24
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
12
|
-
from msprobe.mindspore.
|
|
13
|
-
|
|
25
|
+
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
26
|
+
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
14
27
|
from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor
|
|
28
|
+
from msprobe.mindspore.ms_config import parse_json_config
|
|
29
|
+
from msprobe.mindspore.runtime import Runtime
|
|
30
|
+
from msprobe.mindspore.service import Service
|
|
31
|
+
from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
|
|
15
32
|
|
|
16
33
|
|
|
17
34
|
class PrecisionDebugger:
|
|
@@ -65,11 +82,11 @@ class PrecisionDebugger:
|
|
|
65
82
|
def start(cls, model=None):
|
|
66
83
|
instance = cls._instance
|
|
67
84
|
if not instance:
|
|
68
|
-
raise Exception(
|
|
85
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
69
86
|
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
70
87
|
return
|
|
71
88
|
|
|
72
|
-
instance.config.execution_mode =
|
|
89
|
+
instance.config.execution_mode = cls._get_execution_mode()
|
|
73
90
|
if cls._need_service():
|
|
74
91
|
if not instance.service:
|
|
75
92
|
instance.service = Service(instance.config)
|
|
@@ -82,11 +99,21 @@ class PrecisionDebugger:
|
|
|
82
99
|
instance.first_start = True
|
|
83
100
|
Runtime.is_running = True
|
|
84
101
|
|
|
102
|
+
@classmethod
|
|
103
|
+
def forward_backward_dump_end(cls):
|
|
104
|
+
instance = cls._instance
|
|
105
|
+
if not instance:
|
|
106
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
107
|
+
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
108
|
+
return
|
|
109
|
+
if instance.service:
|
|
110
|
+
instance.service.forward_backward_dump_end()
|
|
111
|
+
|
|
85
112
|
@classmethod
|
|
86
113
|
def stop(cls):
|
|
87
114
|
instance = cls._instance
|
|
88
115
|
if not instance:
|
|
89
|
-
raise Exception(
|
|
116
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
90
117
|
if instance.task == Const.GRAD_PROBE:
|
|
91
118
|
instance.gm.stop()
|
|
92
119
|
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
@@ -99,18 +126,21 @@ class PrecisionDebugger:
|
|
|
99
126
|
def step(cls):
|
|
100
127
|
instance = cls._instance
|
|
101
128
|
if not instance:
|
|
102
|
-
raise Exception(
|
|
129
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
103
130
|
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
104
131
|
return
|
|
105
132
|
if instance.service:
|
|
106
133
|
instance.service.step()
|
|
134
|
+
HOOKCell.cell_count = defaultdict(int)
|
|
135
|
+
CellProcessor.reset_cell_stats()
|
|
136
|
+
|
|
107
137
|
Runtime.step_count += 1
|
|
108
138
|
|
|
109
139
|
@classmethod
|
|
110
140
|
def monitor(cls, opt):
|
|
111
141
|
instance = cls._instance
|
|
112
142
|
if not instance:
|
|
113
|
-
raise Exception(
|
|
143
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
114
144
|
if instance.task != Const.GRAD_PROBE:
|
|
115
145
|
return
|
|
116
146
|
instance.gm.monitor(opt)
|
|
@@ -119,7 +149,7 @@ class PrecisionDebugger:
|
|
|
119
149
|
def _need_service(cls):
|
|
120
150
|
instance = cls._instance
|
|
121
151
|
if not instance:
|
|
122
|
-
raise Exception(
|
|
152
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
123
153
|
if instance.config.execution_mode != MsConst.PYNATIVE_MODE:
|
|
124
154
|
return False
|
|
125
155
|
else:
|
|
@@ -1,7 +1,22 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.mindspore.common.const import Const
|
|
2
17
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
3
|
-
from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump
|
|
4
18
|
from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
|
|
19
|
+
from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump
|
|
5
20
|
|
|
6
21
|
|
|
7
22
|
class DumpToolFactory:
|
|
@@ -25,6 +40,8 @@ class DumpToolFactory:
|
|
|
25
40
|
|
|
26
41
|
@staticmethod
|
|
27
42
|
def create(config: DebuggerConfig):
|
|
43
|
+
if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST:
|
|
44
|
+
raise Exception("data_mode must be one of all, input, output.")
|
|
28
45
|
tool = DumpToolFactory.tools.get(config.level)
|
|
29
46
|
if not tool:
|
|
30
47
|
raise Exception("Valid level is needed.")
|
|
@@ -16,13 +16,20 @@
|
|
|
16
16
|
from mindspore import Tensor, ops, mint
|
|
17
17
|
from mindspore.mint.nn import functional
|
|
18
18
|
from mindspore.common._stub_tensor import StubTensor
|
|
19
|
+
from mindspore.communication import comm_func
|
|
19
20
|
|
|
20
21
|
from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP,
|
|
21
|
-
HOOKMintOP, HOOKMintNNFunctionalOP,
|
|
22
|
+
HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP,
|
|
22
23
|
get_wrap_api_list, setup_hooks)
|
|
23
24
|
from msprobe.core.common.utils import Const
|
|
24
25
|
|
|
25
26
|
|
|
27
|
+
def stub_method(method):
|
|
28
|
+
def wrapped_method(*args, **kwargs):
|
|
29
|
+
return method(*args, **kwargs)
|
|
30
|
+
return wrapped_method
|
|
31
|
+
|
|
32
|
+
|
|
26
33
|
class ApiRegistry:
|
|
27
34
|
def __init__(self):
|
|
28
35
|
self.tensor_ori_attr = {}
|
|
@@ -30,6 +37,7 @@ class ApiRegistry:
|
|
|
30
37
|
self.functional_ori_attr = {}
|
|
31
38
|
self.mint_ops_ori_attr = {}
|
|
32
39
|
self.mint_func_ops_ori_attr = {}
|
|
40
|
+
self.distributed_ori_attr = {}
|
|
33
41
|
self.norm_inner_ops_ori_attr = {}
|
|
34
42
|
|
|
35
43
|
self.tensor_hook_attr = {}
|
|
@@ -37,6 +45,7 @@ class ApiRegistry:
|
|
|
37
45
|
self.functional_hook_attr = {}
|
|
38
46
|
self.mint_ops_hook_attr = {}
|
|
39
47
|
self.mint_func_ops_hook_attr = {}
|
|
48
|
+
self.distibuted_hook_attr = {}
|
|
40
49
|
self.norm_inner_ops_hook_attr = {}
|
|
41
50
|
|
|
42
51
|
self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
|
|
@@ -47,9 +56,13 @@ class ApiRegistry:
|
|
|
47
56
|
if Const.SEP in api:
|
|
48
57
|
sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
|
|
49
58
|
sub_module = getattr(ori_api_group, sub_module_name)
|
|
50
|
-
|
|
59
|
+
ori_api_func = getattr(sub_module, sub_op)
|
|
51
60
|
else:
|
|
52
|
-
|
|
61
|
+
ori_api_func = getattr(ori_api_group, api)
|
|
62
|
+
if ori_api_group == StubTensor:
|
|
63
|
+
api_ori_attr[api] = stub_method(ori_api_func)
|
|
64
|
+
continue
|
|
65
|
+
api_ori_attr[api] = ori_api_func
|
|
53
66
|
|
|
54
67
|
@staticmethod
|
|
55
68
|
def set_api_attr(api_group, attr_dict):
|
|
@@ -74,6 +87,7 @@ class ApiRegistry:
|
|
|
74
87
|
self.set_api_attr(ops, self.functional_hook_attr)
|
|
75
88
|
self.set_api_attr(mint, self.mint_ops_hook_attr)
|
|
76
89
|
self.set_api_attr(functional, self.mint_func_ops_hook_attr)
|
|
90
|
+
self.set_api_attr(comm_func, self.distibuted_hook_attr)
|
|
77
91
|
|
|
78
92
|
def api_set_ori_func(self):
|
|
79
93
|
self.set_api_attr(Tensor, self.tensor_ori_attr)
|
|
@@ -81,6 +95,7 @@ class ApiRegistry:
|
|
|
81
95
|
self.set_api_attr(ops, self.functional_ori_attr)
|
|
82
96
|
self.set_api_attr(mint, self.mint_ops_ori_attr)
|
|
83
97
|
self.set_api_attr(functional, self.mint_func_ops_ori_attr)
|
|
98
|
+
self.set_api_attr(comm_func, self.distributed_ori_attr)
|
|
84
99
|
|
|
85
100
|
def initialize_hook(self, hook):
|
|
86
101
|
wrap_api_name = get_wrap_api_list()
|
|
@@ -89,6 +104,7 @@ class ApiRegistry:
|
|
|
89
104
|
self.store_ori_attr(ops, wrap_api_name.ops_api_names, self.functional_ori_attr)
|
|
90
105
|
self.store_ori_attr(mint, wrap_api_name.mint_api_names, self.mint_ops_ori_attr)
|
|
91
106
|
self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr)
|
|
107
|
+
self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr)
|
|
92
108
|
self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
|
|
93
109
|
setup_hooks(hook)
|
|
94
110
|
for attr_name in dir(HOOKTensor):
|
|
@@ -113,6 +129,10 @@ class ApiRegistry:
|
|
|
113
129
|
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
|
|
114
130
|
api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
|
|
115
131
|
self.mint_func_ops_hook_attr[api_name] = getattr(HOOKMintNNFunctionalOP, attr_name)
|
|
132
|
+
for attr_name in dir(HOOKDistributedOP):
|
|
133
|
+
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
|
|
134
|
+
api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
|
|
135
|
+
self.distibuted_hook_attr[api_name] = getattr(HOOKDistributedOP, attr_name)
|
|
116
136
|
|
|
117
137
|
|
|
118
138
|
api_register = ApiRegistry()
|