mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
msprobe/mindspore/service.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
2
3
|
#
|
|
3
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
5
|
# you may not use this file except in compliance with the License.
|
|
@@ -11,37 +12,33 @@
|
|
|
11
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
13
|
# See the License for the specific language governing permissions and
|
|
13
14
|
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
15
|
|
|
16
|
-
import os
|
|
17
16
|
import copy
|
|
18
17
|
import functools
|
|
18
|
+
import os
|
|
19
19
|
from collections import defaultdict
|
|
20
20
|
|
|
21
21
|
import mindspore as ms
|
|
22
|
-
from mindspore.common.tensor import Tensor
|
|
23
|
-
from mindspore import ops
|
|
24
22
|
from mindspore import nn
|
|
25
23
|
try:
|
|
26
24
|
from mindspore.common._pijit_context import PIJitCaptureContext
|
|
27
|
-
pijit_label = True
|
|
28
25
|
except ImportError:
|
|
29
26
|
pijit_label = False
|
|
27
|
+
else:
|
|
28
|
+
pijit_label = True
|
|
30
29
|
|
|
31
30
|
|
|
31
|
+
from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
|
|
32
|
+
from msprobe.core.common.file_utils import create_directory
|
|
33
|
+
from msprobe.core.common.utils import Const, print_tools_ends_info
|
|
32
34
|
from msprobe.core.data_dump.data_collector import build_data_collector
|
|
35
|
+
from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs
|
|
33
36
|
from msprobe.core.data_dump.scope import BaseScope
|
|
34
|
-
from msprobe.mindspore.
|
|
35
|
-
from msprobe.core.common.file_utils import create_directory
|
|
37
|
+
from msprobe.mindspore.cell_processor import CellProcessor
|
|
36
38
|
from msprobe.mindspore.common.log import logger
|
|
37
|
-
from msprobe.
|
|
38
|
-
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
39
|
+
from msprobe.mindspore.common.utils import get_rank_if_initialized
|
|
39
40
|
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
40
|
-
from msprobe.
|
|
41
|
-
ModuleBackwardInputs, ModuleBackwardOutputs
|
|
42
|
-
from msprobe.core.common.exceptions import MsprobeException
|
|
43
|
-
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
44
|
-
from msprobe.mindspore.cell_processor import CellProcessor
|
|
41
|
+
from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
|
|
45
42
|
from msprobe.mindspore.dump.jit_dump import JitDump
|
|
46
43
|
|
|
47
44
|
|
|
@@ -52,11 +49,12 @@ class Service:
|
|
|
52
49
|
self.config.level = self.config.level_ori
|
|
53
50
|
self.data_collector = build_data_collector(self.config)
|
|
54
51
|
self.cell_processor = CellProcessor(self.data_collector.scope)
|
|
52
|
+
self.primitive_hook_service = PrimitiveHookService(self)
|
|
55
53
|
self.switch = False
|
|
54
|
+
self.primitive_switch = False
|
|
56
55
|
self.current_iter = 0
|
|
57
56
|
self.first_start = True
|
|
58
57
|
self.current_rank = None
|
|
59
|
-
self.primitive_counters = {}
|
|
60
58
|
self.dump_iter_dir = None
|
|
61
59
|
self.start_call = False
|
|
62
60
|
self.check_level_valid()
|
|
@@ -71,28 +69,30 @@ class Service:
|
|
|
71
69
|
)
|
|
72
70
|
|
|
73
71
|
def check_level_valid(self):
|
|
74
|
-
if self.config.level ==
|
|
72
|
+
if self.config.level == Const.LEVEL_L2:
|
|
75
73
|
raise MsprobeException(
|
|
76
74
|
MsprobeException.INVALID_PARAM_ERROR, "L2 level dump function is currently not supported."
|
|
77
75
|
)
|
|
78
76
|
|
|
79
77
|
def build_hook(self, target_type, name):
|
|
80
|
-
def forward_hook(api_or_cell_name, cell,
|
|
78
|
+
def forward_hook(api_or_cell_name, cell, input_data, output):
|
|
81
79
|
if not self.should_excute_hook():
|
|
80
|
+
if hasattr(cell, 'input_kwargs'):
|
|
81
|
+
del cell.input_kwargs
|
|
82
82
|
return None
|
|
83
83
|
|
|
84
84
|
if target_type == BaseScope.Module_Type_Module:
|
|
85
|
-
api_or_cell_name = cell
|
|
86
|
-
module_input_output = ModuleForwardInputsOutputs(args=
|
|
85
|
+
api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
|
|
86
|
+
module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
|
|
87
87
|
else:
|
|
88
|
-
module_input_output = ModuleForwardInputsOutputs(args=
|
|
88
|
+
module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs,
|
|
89
89
|
output=output)
|
|
90
90
|
|
|
91
91
|
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
92
92
|
self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
93
93
|
if self.data_collector.if_return_forward_new_output():
|
|
94
94
|
return self.data_collector.get_forward_new_output()
|
|
95
|
-
if
|
|
95
|
+
if hasattr(cell, 'input_kwargs'):
|
|
96
96
|
del cell.input_kwargs
|
|
97
97
|
return output
|
|
98
98
|
|
|
@@ -100,12 +100,19 @@ class Service:
|
|
|
100
100
|
if not self.should_excute_hook():
|
|
101
101
|
return
|
|
102
102
|
|
|
103
|
+
need_exchange = True
|
|
103
104
|
if target_type == BaseScope.Module_Type_Module:
|
|
104
|
-
|
|
105
|
+
if not hasattr(cell, 'has_pre_hook_called') or not cell.has_pre_hook_called:
|
|
106
|
+
need_exchange = False
|
|
107
|
+
api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
|
|
108
|
+
|
|
105
109
|
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
106
110
|
if self.data_collector:
|
|
107
111
|
# 框架最新接口变更,grad_input和grad_output的含义发生了变化,与torch含义保持一致,因此此处调换顺序传入
|
|
108
|
-
|
|
112
|
+
if need_exchange:
|
|
113
|
+
module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
|
|
114
|
+
else:
|
|
115
|
+
module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
|
|
109
116
|
self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
110
117
|
|
|
111
118
|
pid = os.getpid()
|
|
@@ -114,145 +121,40 @@ class Service:
|
|
|
114
121
|
forward_hook = functools.partial(forward_hook, forward_name_template)
|
|
115
122
|
backward_hook = functools.partial(backward_hook, backward_name_template)
|
|
116
123
|
|
|
117
|
-
def wrap_forward_hook(cell,
|
|
118
|
-
return forward_hook(cell,
|
|
124
|
+
def wrap_forward_hook(cell, input_data, output_data):
|
|
125
|
+
return forward_hook(cell, input_data, output_data)
|
|
119
126
|
|
|
120
127
|
def wrap_backward_hook(cell, grad_input, grad_output):
|
|
121
128
|
return backward_hook(cell, grad_input, grad_output)
|
|
122
129
|
|
|
123
130
|
return wrap_forward_hook, wrap_backward_hook
|
|
124
131
|
|
|
125
|
-
def wrap_primitive(self, origin_func, primitive_name):
|
|
126
|
-
service_instance = self
|
|
127
|
-
|
|
128
|
-
def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
|
|
129
|
-
def backward_hook(grad):
|
|
130
|
-
captured_grads.append(grad)
|
|
131
|
-
backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
|
|
132
|
-
try:
|
|
133
|
-
if len(captured_grads) == num_tensors and hook_type == Const.INPUT:
|
|
134
|
-
service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
|
|
135
|
-
new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
|
|
136
|
-
service_instance.data_collector.backward_output_data_collect(
|
|
137
|
-
backward_primitive_name, service_instance, os.getpid(), new_module_input_output
|
|
138
|
-
)
|
|
139
|
-
captured_grads.clear()
|
|
140
|
-
elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
|
|
141
|
-
service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
|
|
142
|
-
new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
|
|
143
|
-
service_instance.data_collector.backward_input_data_collect(
|
|
144
|
-
backward_primitive_name, service_instance, os.getpid(), new_module_input_output
|
|
145
|
-
)
|
|
146
|
-
captured_grads.clear()
|
|
147
|
-
|
|
148
|
-
except Exception as exception:
|
|
149
|
-
raise Exception(f"This is a primitive op {hook_type}_backward dump error: {exception},"
|
|
150
|
-
f" updated_primitive_name: {updated_primitive_name}") from exception
|
|
151
|
-
|
|
152
|
-
return backward_hook
|
|
153
|
-
|
|
154
|
-
def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
|
|
155
|
-
hooked_inputs = []
|
|
156
|
-
num_tensors = sum(isinstance(arg, Tensor) for arg in args)
|
|
157
|
-
input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name,
|
|
158
|
-
Const.INPUT)
|
|
159
|
-
for _, arg in enumerate(args):
|
|
160
|
-
if isinstance(arg, Tensor):
|
|
161
|
-
arg_hooked = ops.HookBackward(input_backward_hook)(arg)
|
|
162
|
-
hooked_inputs.append(arg_hooked)
|
|
163
|
-
else:
|
|
164
|
-
hooked_inputs.append(arg)
|
|
165
|
-
return hooked_inputs
|
|
166
|
-
|
|
167
|
-
def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
|
|
168
|
-
if isinstance(out, tuple):
|
|
169
|
-
num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out)
|
|
170
|
-
else:
|
|
171
|
-
num_output_tensors = 1
|
|
172
|
-
output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors,
|
|
173
|
-
updated_primitive_name, Const.OUTPUT)
|
|
174
|
-
|
|
175
|
-
if isinstance(out, Tensor):
|
|
176
|
-
return ops.HookBackward(output_backward_hook)(out)
|
|
177
|
-
elif isinstance(out, tuple):
|
|
178
|
-
hooked_outputs = []
|
|
179
|
-
for tensor in out:
|
|
180
|
-
if isinstance(tensor, Tensor):
|
|
181
|
-
hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
|
|
182
|
-
else:
|
|
183
|
-
hooked_outputs.append(tensor)
|
|
184
|
-
return tuple(hooked_outputs)
|
|
185
|
-
return out
|
|
186
|
-
|
|
187
|
-
def wrapped_primitive_call(instance_self, *args, **kwargs):
|
|
188
|
-
service_instance.update_primitive_counters(primitive_name)
|
|
189
|
-
current_count = service_instance.primitive_counters.get(primitive_name, 0)
|
|
190
|
-
updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
|
|
191
|
-
|
|
192
|
-
if not service_instance.switch:
|
|
193
|
-
return origin_func(*args, **kwargs)
|
|
194
|
-
|
|
195
|
-
captured_grads_input, captured_grads_output = [], []
|
|
196
|
-
|
|
197
|
-
try:
|
|
198
|
-
hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
|
|
199
|
-
except Exception as exception:
|
|
200
|
-
raise Exception("This is a primitive op dump error during input hooking: {},"
|
|
201
|
-
" primitive_name: {}".format(exception, primitive_name)) from exception
|
|
202
|
-
|
|
203
|
-
try:
|
|
204
|
-
out = origin_func(*hooked_inputs, **kwargs)
|
|
205
|
-
except Exception as exception:
|
|
206
|
-
raise Exception("This is a primitive op dump error during function call: {},"
|
|
207
|
-
" primitive_name: {}".format(exception, primitive_name)) from exception
|
|
208
|
-
|
|
209
|
-
forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
|
|
210
|
-
service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
|
|
211
|
-
if service_instance.data_collector:
|
|
212
|
-
module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
|
|
213
|
-
try:
|
|
214
|
-
service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
|
|
215
|
-
os.getpid(), module_input_output)
|
|
216
|
-
except Exception as exception:
|
|
217
|
-
raise Exception("This is a primitive op dump error during forward data collection: {},"
|
|
218
|
-
" primitive_name: {}".format(exception, primitive_name)) from exception
|
|
219
|
-
|
|
220
|
-
if service_instance.data_collector.if_return_forward_new_output():
|
|
221
|
-
out = service_instance.data_collector.get_forward_new_output()
|
|
222
|
-
|
|
223
|
-
try:
|
|
224
|
-
out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
|
|
225
|
-
except Exception as exception:
|
|
226
|
-
raise Exception("This is a primitive op dump error during output hooking: {},"
|
|
227
|
-
" primitive_name: {}".format(exception, primitive_name)) from exception
|
|
228
|
-
|
|
229
|
-
return out
|
|
230
|
-
|
|
231
|
-
return wrapped_primitive_call
|
|
232
|
-
|
|
233
132
|
def update_primitive_counters(self, primitive_name):
|
|
234
133
|
if primitive_name not in self.primitive_counters:
|
|
235
134
|
self.primitive_counters[primitive_name] = 0
|
|
236
135
|
else:
|
|
237
136
|
self.primitive_counters[primitive_name] += 1
|
|
238
137
|
|
|
239
|
-
def
|
|
138
|
+
def register_primitive_hooks(self):
|
|
240
139
|
primitive_set = set()
|
|
241
140
|
for _, cell in self.model.cells_and_names():
|
|
242
141
|
for pname, primitive in cell._primitives.items():
|
|
243
142
|
primitive_set.add((pname, primitive))
|
|
244
143
|
|
|
245
144
|
for pname, primitive in primitive_set:
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
145
|
+
primitive_class_name = primitive.__class__.__name__
|
|
146
|
+
primitive_combined_name = pname + Const.SEP + primitive_class_name
|
|
147
|
+
new_primitive = type('NewPrimitive', (primitive.__class__,),
|
|
148
|
+
{'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
|
|
149
|
+
primitive_combined_name)})
|
|
150
|
+
primitive.__class__ = new_primitive
|
|
249
151
|
|
|
250
152
|
def step(self):
|
|
251
153
|
self.current_iter += 1
|
|
252
154
|
self.data_collector.update_iter(self.current_iter)
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
155
|
+
self.primitive_hook_service.primitive_counters.clear()
|
|
156
|
+
self.data_collector.data_writer.reset_cache()
|
|
157
|
+
JitDump.jit_count = defaultdict(int)
|
|
256
158
|
|
|
257
159
|
def start(self, model=None):
|
|
258
160
|
self.start_call = True
|
|
@@ -262,9 +164,8 @@ class Service:
|
|
|
262
164
|
api_register.api_set_ori_func()
|
|
263
165
|
self.should_stop_service = True
|
|
264
166
|
self.switch = False
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
logger.info("************************************************")
|
|
167
|
+
self.primitive_switch = False
|
|
168
|
+
print_tools_ends_info()
|
|
268
169
|
return
|
|
269
170
|
if self.config.step and self.current_iter not in self.config.step:
|
|
270
171
|
return
|
|
@@ -281,7 +182,7 @@ class Service:
|
|
|
281
182
|
if self.config.rank and self.current_rank not in self.config.rank:
|
|
282
183
|
return
|
|
283
184
|
self.register_hook_new()
|
|
284
|
-
if self.config.level
|
|
185
|
+
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
285
186
|
JitDump.set_config(self.config)
|
|
286
187
|
JitDump.set_data_collector(self.data_collector)
|
|
287
188
|
ms.common.api._MindsporeFunctionExecutor = JitDump
|
|
@@ -291,10 +192,32 @@ class Service:
|
|
|
291
192
|
PIJitCaptureContext.__exit__ = self.empty
|
|
292
193
|
self.first_start = False
|
|
293
194
|
|
|
195
|
+
api_register.api_set_hook_func()
|
|
294
196
|
self.switch = True
|
|
197
|
+
self.primitive_switch = True
|
|
295
198
|
logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
|
|
296
199
|
self.create_dirs()
|
|
297
200
|
logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
|
|
201
|
+
JitDump.jit_dump_switch = True
|
|
202
|
+
|
|
203
|
+
def forward_backward_dump_end(self):
    """Stop forward/backward data collection while leaving ``switch`` as is.

    Turns off ``primitive_switch``, restores the original API functions via
    ``api_register.api_set_ori_func()`` and disables ``JitDump``'s dump
    switch. Does nothing when the service has already been stopped, or when
    the current iteration / rank is excluded by ``config.step`` /
    ``config.rank``.

    Raises:
        Exception: if ``start()`` has not been called in the current scope,
            or if the dump switch is off on an iteration and rank that are
            selected for dumping.
    """
    if self.should_stop_service:
        return
    logger.info(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() is set successfully. ")
    if not self.start_call:
        logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
        raise Exception("debugger.start() is not set in the current scope.")
    # Apply the step/rank filters before validating the switch. start()
    # returns early on a filtered step or rank without turning the switch
    # on, so checking the switch first would raise on every iteration or
    # rank that is merely excluded from dumping.
    if self.config.step and self.current_iter not in self.config.step:
        return
    if self.config.rank and self.current_rank not in self.config.rank:
        return
    if not self.switch:
        # Log and raise the same text so the exception matches the log line.
        message = ("debugger.forward_backward_dump_end() should be called between "
                   "debugger.start() and debugger.stop() ")
        logger.error(f"{Const.TOOL_NAME}: {message}")
        raise Exception(message)
    self.primitive_switch = False
    api_register.api_set_ori_func()
    JitDump.jit_dump_switch = False
|
|
298
221
|
|
|
299
222
|
def stop(self):
|
|
300
223
|
if self.should_stop_service:
|
|
@@ -309,8 +232,10 @@ class Service:
|
|
|
309
232
|
if self.config.rank and self.current_rank not in self.config.rank:
|
|
310
233
|
return
|
|
311
234
|
self.switch = False
|
|
235
|
+
self.primitive_switch = False
|
|
312
236
|
self.start_call = False
|
|
313
237
|
self.data_collector.write_json()
|
|
238
|
+
JitDump.jit_dump_switch = False
|
|
314
239
|
|
|
315
240
|
def need_end_service(self):
|
|
316
241
|
if self.config.step and self.current_iter > max(self.config.step):
|
|
@@ -349,16 +274,16 @@ class Service:
|
|
|
349
274
|
|
|
350
275
|
def register_hook_new(self):
|
|
351
276
|
logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
|
|
352
|
-
if self.config.level
|
|
277
|
+
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
353
278
|
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
|
|
354
279
|
api_register.api_set_hook_func()
|
|
355
|
-
if self.model:
|
|
356
|
-
self.
|
|
280
|
+
if self.model and self.config.task in Const.DUMP_DATA_COLLECTION_LIST:
|
|
281
|
+
self.register_primitive_hooks()
|
|
357
282
|
|
|
358
|
-
if self.config.level
|
|
283
|
+
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]:
|
|
359
284
|
if not self.model:
|
|
360
285
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
361
|
-
"The current level is
|
|
286
|
+
f"The current level is {self.config.level}, the model cannot be None")
|
|
362
287
|
for name, cell in self.model.cells_and_names():
|
|
363
288
|
if cell == self.model:
|
|
364
289
|
continue
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.core.common.const import Const
|
|
2
17
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
3
18
|
from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory
|
msprobe/msprobe.py
CHANGED
|
@@ -45,10 +45,15 @@ def main():
|
|
|
45
45
|
multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut')
|
|
46
46
|
api_precision_compare_cmd_parser = subparsers.add_parser('api_precision_compare')
|
|
47
47
|
run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check')
|
|
48
|
+
graph_service_cmd_parser = subparsers.add_parser('graph')
|
|
48
49
|
_compare_parser(compare_cmd_parser)
|
|
49
|
-
is_torch_available=is_module_available("torch")
|
|
50
|
+
is_torch_available = is_module_available("torch")
|
|
50
51
|
is_mindspore_available = is_module_available("mindspore")
|
|
51
|
-
if
|
|
52
|
+
if len(sys.argv) < 4:
|
|
53
|
+
parser.print_help()
|
|
54
|
+
sys.exit(0)
|
|
55
|
+
framework_args = parser.parse_args(sys.argv[1:3])
|
|
56
|
+
if framework_args.framework == Const.PT_FRAMEWORK:
|
|
52
57
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
|
|
53
58
|
from msprobe.pytorch.parse_tool.cli import parse as cli_parse
|
|
54
59
|
from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut
|
|
@@ -56,20 +61,24 @@ def main():
|
|
|
56
61
|
_api_precision_compare_command
|
|
57
62
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
|
|
58
63
|
_run_overflow_check_command
|
|
64
|
+
from msprobe.visualization.graph_service import _pt_graph_service_parser, _pt_graph_service_command
|
|
59
65
|
|
|
60
66
|
_run_ut_parser(run_ut_cmd_parser)
|
|
61
67
|
_run_ut_parser(multi_run_ut_cmd_parser)
|
|
62
68
|
multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
|
|
63
|
-
|
|
69
|
+
help='Number of splits for parallel processing. Range: 1-64')
|
|
64
70
|
_api_precision_compare_parser(api_precision_compare_cmd_parser)
|
|
65
71
|
_run_overflow_check_parser(run_overflow_check_cmd_parser)
|
|
66
|
-
|
|
72
|
+
_pt_graph_service_parser(graph_service_cmd_parser)
|
|
73
|
+
elif framework_args.framework == Const.MS_FRAMEWORK:
|
|
67
74
|
from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument
|
|
75
|
+
from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command
|
|
68
76
|
add_api_accuracy_checker_argument(run_ut_cmd_parser)
|
|
77
|
+
from msprobe.mindspore.api_accuracy_checker.cmd_parser import multi_add_api_accuracy_checker_argument
|
|
78
|
+
multi_add_api_accuracy_checker_argument(multi_run_ut_cmd_parser)
|
|
79
|
+
|
|
80
|
+
_ms_graph_service_parser(graph_service_cmd_parser)
|
|
69
81
|
|
|
70
|
-
if len(sys.argv) == 1:
|
|
71
|
-
parser.print_help()
|
|
72
|
-
sys.exit(0)
|
|
73
82
|
args = parser.parse_args(sys.argv[1:])
|
|
74
83
|
if sys.argv[2] == Const.PT_FRAMEWORK:
|
|
75
84
|
if not is_torch_available:
|
|
@@ -86,6 +95,8 @@ def main():
|
|
|
86
95
|
_api_precision_compare_command(args)
|
|
87
96
|
elif sys.argv[3] == "run_overflow_check":
|
|
88
97
|
_run_overflow_check_command(args)
|
|
98
|
+
elif sys.argv[3] == "graph":
|
|
99
|
+
_pt_graph_service_command(args)
|
|
89
100
|
elif sys.argv[3] == "compare":
|
|
90
101
|
if args.cell_mapping is not None or args.api_mapping is not None:
|
|
91
102
|
logger.error("Argument -cm or -am is not supported in PyTorch framework")
|
|
@@ -100,6 +111,12 @@ def main():
|
|
|
100
111
|
elif sys.argv[3] == "run_ut":
|
|
101
112
|
from msprobe.mindspore.api_accuracy_checker.main import api_checker_main
|
|
102
113
|
api_checker_main(args)
|
|
114
|
+
elif sys.argv[3] == "multi_run_ut":
|
|
115
|
+
from msprobe.mindspore.api_accuracy_checker.main import mul_api_checker_main
|
|
116
|
+
mul_api_checker_main(args)
|
|
117
|
+
elif sys.argv[3] == "graph":
|
|
118
|
+
_ms_graph_service_command(args)
|
|
119
|
+
|
|
103
120
|
|
|
104
121
|
if __name__ == "__main__":
|
|
105
122
|
main()
|
msprobe/pytorch/__init__.py
CHANGED
|
@@ -1,4 +1,24 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
from msprobe.pytorch.monitor.module_hook import TrainerMon
|
|
3
20
|
from .compare.distributed_compare import compare_distributed
|
|
4
|
-
from .compare.pt_compare import compare
|
|
21
|
+
from .compare.pt_compare import compare
|
|
22
|
+
from .common.utils import seed_all
|
|
23
|
+
from .debugger.precision_debugger import PrecisionDebugger
|
|
24
|
+
from .functional.module_dump import module_dump, module_dump_end
|
|
@@ -1,8 +1,33 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
1
18
|
import os
|
|
19
|
+
from collections import namedtuple
|
|
2
20
|
from msprobe.core.common.file_utils import load_yaml, check_file_or_directory_path
|
|
21
|
+
from msprobe.core.common.utils import is_int
|
|
3
22
|
from msprobe.pytorch.pt_config import RunUTConfig
|
|
4
23
|
|
|
5
24
|
|
|
25
|
+
# Immutable bundle of everything one run_ut invocation needs. Field order is
# part of the type and must not change.
RunUtConfig = namedtuple(
    'RunUtConfig',
    (
        'forward_content',
        'backward_content',
        'result_csv_path',
        'details_csv_path',
        'save_error_data',
        'is_continue_run_ut',
        'real_data_path',
        'white_list',
        'black_list',
        'error_data_path',
        'online_config',
    ),
)
# Online-mode settings carried inside RunUtConfig.online_config.
OnlineConfig = namedtuple(
    'OnlineConfig',
    ('is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'),
)
|
|
29
|
+
|
|
30
|
+
|
|
6
31
|
class Config:
|
|
7
32
|
def __init__(self, yaml_file):
|
|
8
33
|
check_file_or_directory_path(yaml_file, False)
|
|
@@ -33,8 +58,10 @@ class Config:
|
|
|
33
58
|
raise ValueError(f"{key} must be one of {validators.keys()}")
|
|
34
59
|
if not isinstance(value, validators.get(key)):
|
|
35
60
|
raise ValueError(f"{key} must be {validators[key].__name__} type")
|
|
36
|
-
if key == 'precision' and value
|
|
37
|
-
raise ValueError("precision must be
|
|
61
|
+
if key == 'precision' and not is_int(value):
|
|
62
|
+
raise ValueError("precision must be an integer")
|
|
63
|
+
if key == 'precision' and (value < 0 or value > 20):
|
|
64
|
+
raise ValueError("precision must be greater than or equal to 0 and less than 21")
|
|
38
65
|
if key == 'white_list':
|
|
39
66
|
RunUTConfig.check_filter_list_config(key, value)
|
|
40
67
|
if key == 'black_list':
|
|
@@ -51,3 +78,55 @@ class Config:
|
|
|
51
78
|
cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
52
79
|
yaml_path = os.path.join(cur_path, "config.yaml")
|
|
53
80
|
msCheckerConfig = Config(yaml_path)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class CheckerConfig:
    """Checker settings seeded from ``msCheckerConfig`` and optionally
    overridden by a task-level configuration object."""

    # Attributes copied verbatim from a configuration source, in the order
    # they are assigned.
    _SHARED_FIELDS = (
        'white_list',
        'black_list',
        'error_data_path',
        'is_online',
        'nfs_path',
        'host',
        'port',
        'rank_list',
        'tls_path',
    )

    def __init__(self, task_config=None):
        """Load defaults from ``msCheckerConfig``; apply ``task_config`` if truthy."""
        self._copy_fields_from(msCheckerConfig)
        if task_config:
            self.load_config(task_config)

    def _copy_fields_from(self, source):
        """Copy every shared attribute from ``source`` onto this instance."""
        for field_name in self._SHARED_FIELDS:
            setattr(self, field_name, getattr(source, field_name))

    def load_config(self, task_config):
        """Overwrite all shared attributes with the values held by ``task_config``."""
        self._copy_fields_from(task_config)

    def get_online_config(self):
        """Return the online-mode attributes packed into an ``OnlineConfig``."""
        online_names = ('is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path')
        return OnlineConfig(**{name: getattr(self, name) for name in online_names})

    def get_run_ut_config(self, **config_params):
        """Build a ``RunUtConfig`` from keyword arguments plus instance state.

        Keys missing from ``config_params`` become ``None``. ``white_list``
        and ``black_list`` come from this instance; ``error_data_path`` is
        read from ``config_params``, not from ``self.error_data_path``.
        """
        caller_supplied = (
            'forward_content',
            'backward_content',
            'result_csv_path',
            'details_csv_path',
            'save_error_data',
            'is_continue_run_ut',
            'real_data_path',
            'error_data_path',
        )
        fields = {key: config_params.get(key) for key in caller_supplied}
        fields['white_list'] = self.white_list
        fields['black_list'] = self.black_list
        fields['online_config'] = self.get_online_config()
        return RunUtConfig(**fields)
|