mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
msprobe/mindspore/service.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
2
3
|
#
|
|
3
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
5
|
# you may not use this file except in compliance with the License.
|
|
@@ -11,38 +12,33 @@
|
|
|
11
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
13
|
# See the License for the specific language governing permissions and
|
|
13
14
|
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
15
|
|
|
16
|
-
import os
|
|
17
16
|
import copy
|
|
18
17
|
import functools
|
|
18
|
+
import os
|
|
19
19
|
from collections import defaultdict
|
|
20
20
|
|
|
21
21
|
import mindspore as ms
|
|
22
|
-
from mindspore.common.tensor import Tensor
|
|
23
|
-
from mindspore import ops
|
|
24
22
|
from mindspore import nn
|
|
25
23
|
try:
|
|
26
24
|
from mindspore.common._pijit_context import PIJitCaptureContext
|
|
27
|
-
pijit_label = True
|
|
28
25
|
except ImportError:
|
|
29
26
|
pijit_label = False
|
|
27
|
+
else:
|
|
28
|
+
pijit_label = True
|
|
30
29
|
|
|
31
30
|
|
|
31
|
+
from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
|
|
32
|
+
from msprobe.core.common.file_utils import create_directory
|
|
33
|
+
from msprobe.core.common.utils import Const, print_tools_ends_info
|
|
32
34
|
from msprobe.core.data_dump.data_collector import build_data_collector
|
|
35
|
+
from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs
|
|
33
36
|
from msprobe.core.data_dump.scope import BaseScope
|
|
34
|
-
from msprobe.mindspore.
|
|
35
|
-
from msprobe.core.common.file_utils import create_directory
|
|
37
|
+
from msprobe.mindspore.cell_processor import CellProcessor
|
|
36
38
|
from msprobe.mindspore.common.log import logger
|
|
37
|
-
from msprobe.
|
|
38
|
-
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
39
|
+
from msprobe.mindspore.common.utils import get_rank_if_initialized
|
|
39
40
|
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
40
41
|
from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
|
|
41
|
-
from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
|
|
42
|
-
ModuleBackwardInputs, ModuleBackwardOutputs
|
|
43
|
-
from msprobe.core.common.exceptions import MsprobeException
|
|
44
|
-
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
45
|
-
from msprobe.mindspore.cell_processor import CellProcessor
|
|
46
42
|
from msprobe.mindspore.dump.jit_dump import JitDump
|
|
47
43
|
|
|
48
44
|
|
|
@@ -79,22 +75,24 @@ class Service:
|
|
|
79
75
|
)
|
|
80
76
|
|
|
81
77
|
def build_hook(self, target_type, name):
|
|
82
|
-
def forward_hook(api_or_cell_name, cell,
|
|
78
|
+
def forward_hook(api_or_cell_name, cell, input_data, output):
|
|
83
79
|
if not self.should_excute_hook():
|
|
80
|
+
if hasattr(cell, 'input_kwargs'):
|
|
81
|
+
del cell.input_kwargs
|
|
84
82
|
return None
|
|
85
83
|
|
|
86
84
|
if target_type == BaseScope.Module_Type_Module:
|
|
87
|
-
api_or_cell_name = cell
|
|
88
|
-
module_input_output = ModuleForwardInputsOutputs(args=
|
|
85
|
+
api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
|
|
86
|
+
module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
|
|
89
87
|
else:
|
|
90
|
-
module_input_output = ModuleForwardInputsOutputs(args=
|
|
88
|
+
module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs,
|
|
91
89
|
output=output)
|
|
92
90
|
|
|
93
91
|
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
94
92
|
self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
95
93
|
if self.data_collector.if_return_forward_new_output():
|
|
96
94
|
return self.data_collector.get_forward_new_output()
|
|
97
|
-
if
|
|
95
|
+
if hasattr(cell, 'input_kwargs'):
|
|
98
96
|
del cell.input_kwargs
|
|
99
97
|
return output
|
|
100
98
|
|
|
@@ -102,12 +100,19 @@ class Service:
|
|
|
102
100
|
if not self.should_excute_hook():
|
|
103
101
|
return
|
|
104
102
|
|
|
103
|
+
need_exchange = True
|
|
105
104
|
if target_type == BaseScope.Module_Type_Module:
|
|
106
|
-
|
|
105
|
+
if not hasattr(cell, 'has_pre_hook_called') or not cell.has_pre_hook_called:
|
|
106
|
+
need_exchange = False
|
|
107
|
+
api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
|
|
108
|
+
|
|
107
109
|
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
108
110
|
if self.data_collector:
|
|
109
111
|
# 框架最新接口变更,grad_input和grad_output的含义发生了变化,与torch含义保持一致,因此此处调换顺序传入
|
|
110
|
-
|
|
112
|
+
if need_exchange:
|
|
113
|
+
module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
|
|
114
|
+
else:
|
|
115
|
+
module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
|
|
111
116
|
self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
112
117
|
|
|
113
118
|
pid = os.getpid()
|
|
@@ -116,15 +121,14 @@ class Service:
|
|
|
116
121
|
forward_hook = functools.partial(forward_hook, forward_name_template)
|
|
117
122
|
backward_hook = functools.partial(backward_hook, backward_name_template)
|
|
118
123
|
|
|
119
|
-
def wrap_forward_hook(cell,
|
|
120
|
-
return forward_hook(cell,
|
|
124
|
+
def wrap_forward_hook(cell, input_data, output_data):
|
|
125
|
+
return forward_hook(cell, input_data, output_data)
|
|
121
126
|
|
|
122
127
|
def wrap_backward_hook(cell, grad_input, grad_output):
|
|
123
128
|
return backward_hook(cell, grad_input, grad_output)
|
|
124
129
|
|
|
125
130
|
return wrap_forward_hook, wrap_backward_hook
|
|
126
131
|
|
|
127
|
-
|
|
128
132
|
def update_primitive_counters(self, primitive_name):
|
|
129
133
|
if primitive_name not in self.primitive_counters:
|
|
130
134
|
self.primitive_counters[primitive_name] = 0
|
|
@@ -138,15 +142,16 @@ class Service:
|
|
|
138
142
|
primitive_set.add((pname, primitive))
|
|
139
143
|
|
|
140
144
|
for pname, primitive in primitive_set:
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
145
|
+
primitive_class_name = primitive.__class__.__name__
|
|
146
|
+
primitive_combined_name = pname + Const.SEP + primitive_class_name
|
|
147
|
+
new_primitive = type('NewPrimitive', (primitive.__class__,),
|
|
148
|
+
{'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
|
|
149
|
+
primitive_combined_name)})
|
|
150
|
+
primitive.__class__ = new_primitive
|
|
144
151
|
|
|
145
152
|
def step(self):
|
|
146
153
|
self.current_iter += 1
|
|
147
154
|
self.data_collector.update_iter(self.current_iter)
|
|
148
|
-
HOOKCell.cell_count = defaultdict(int)
|
|
149
|
-
CellProcessor.reset_cell_stats()
|
|
150
155
|
self.primitive_hook_service.primitive_counters.clear()
|
|
151
156
|
self.data_collector.data_writer.reset_cache()
|
|
152
157
|
JitDump.jit_count = defaultdict(int)
|
|
@@ -212,6 +217,7 @@ class Service:
|
|
|
212
217
|
return
|
|
213
218
|
self.primitive_switch = False
|
|
214
219
|
api_register.api_set_ori_func()
|
|
220
|
+
JitDump.jit_dump_switch = False
|
|
215
221
|
|
|
216
222
|
def stop(self):
|
|
217
223
|
if self.should_stop_service:
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.core.common.const import Const
|
|
2
17
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
3
18
|
from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory
|
msprobe/msprobe.py
CHANGED
|
@@ -45,10 +45,15 @@ def main():
|
|
|
45
45
|
multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut')
|
|
46
46
|
api_precision_compare_cmd_parser = subparsers.add_parser('api_precision_compare')
|
|
47
47
|
run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check')
|
|
48
|
+
graph_service_cmd_parser = subparsers.add_parser('graph')
|
|
48
49
|
_compare_parser(compare_cmd_parser)
|
|
49
|
-
is_torch_available=is_module_available("torch")
|
|
50
|
+
is_torch_available = is_module_available("torch")
|
|
50
51
|
is_mindspore_available = is_module_available("mindspore")
|
|
51
|
-
if
|
|
52
|
+
if len(sys.argv) < 4:
|
|
53
|
+
parser.print_help()
|
|
54
|
+
sys.exit(0)
|
|
55
|
+
framework_args = parser.parse_args(sys.argv[1:3])
|
|
56
|
+
if framework_args.framework == Const.PT_FRAMEWORK:
|
|
52
57
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
|
|
53
58
|
from msprobe.pytorch.parse_tool.cli import parse as cli_parse
|
|
54
59
|
from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut
|
|
@@ -56,20 +61,24 @@ def main():
|
|
|
56
61
|
_api_precision_compare_command
|
|
57
62
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
|
|
58
63
|
_run_overflow_check_command
|
|
64
|
+
from msprobe.visualization.graph_service import _pt_graph_service_parser, _pt_graph_service_command
|
|
59
65
|
|
|
60
66
|
_run_ut_parser(run_ut_cmd_parser)
|
|
61
67
|
_run_ut_parser(multi_run_ut_cmd_parser)
|
|
62
68
|
multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
|
|
63
|
-
|
|
69
|
+
help='Number of splits for parallel processing. Range: 1-64')
|
|
64
70
|
_api_precision_compare_parser(api_precision_compare_cmd_parser)
|
|
65
71
|
_run_overflow_check_parser(run_overflow_check_cmd_parser)
|
|
66
|
-
|
|
72
|
+
_pt_graph_service_parser(graph_service_cmd_parser)
|
|
73
|
+
elif framework_args.framework == Const.MS_FRAMEWORK:
|
|
67
74
|
from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument
|
|
75
|
+
from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command
|
|
68
76
|
add_api_accuracy_checker_argument(run_ut_cmd_parser)
|
|
77
|
+
from msprobe.mindspore.api_accuracy_checker.cmd_parser import multi_add_api_accuracy_checker_argument
|
|
78
|
+
multi_add_api_accuracy_checker_argument(multi_run_ut_cmd_parser)
|
|
79
|
+
|
|
80
|
+
_ms_graph_service_parser(graph_service_cmd_parser)
|
|
69
81
|
|
|
70
|
-
if len(sys.argv) == 1:
|
|
71
|
-
parser.print_help()
|
|
72
|
-
sys.exit(0)
|
|
73
82
|
args = parser.parse_args(sys.argv[1:])
|
|
74
83
|
if sys.argv[2] == Const.PT_FRAMEWORK:
|
|
75
84
|
if not is_torch_available:
|
|
@@ -86,6 +95,8 @@ def main():
|
|
|
86
95
|
_api_precision_compare_command(args)
|
|
87
96
|
elif sys.argv[3] == "run_overflow_check":
|
|
88
97
|
_run_overflow_check_command(args)
|
|
98
|
+
elif sys.argv[3] == "graph":
|
|
99
|
+
_pt_graph_service_command(args)
|
|
89
100
|
elif sys.argv[3] == "compare":
|
|
90
101
|
if args.cell_mapping is not None or args.api_mapping is not None:
|
|
91
102
|
logger.error("Argument -cm or -am is not supported in PyTorch framework")
|
|
@@ -100,6 +111,12 @@ def main():
|
|
|
100
111
|
elif sys.argv[3] == "run_ut":
|
|
101
112
|
from msprobe.mindspore.api_accuracy_checker.main import api_checker_main
|
|
102
113
|
api_checker_main(args)
|
|
114
|
+
elif sys.argv[3] == "multi_run_ut":
|
|
115
|
+
from msprobe.mindspore.api_accuracy_checker.main import mul_api_checker_main
|
|
116
|
+
mul_api_checker_main(args)
|
|
117
|
+
elif sys.argv[3] == "graph":
|
|
118
|
+
_ms_graph_service_command(args)
|
|
119
|
+
|
|
103
120
|
|
|
104
121
|
if __name__ == "__main__":
|
|
105
122
|
main()
|
msprobe/pytorch/__init__.py
CHANGED
|
@@ -16,8 +16,9 @@
|
|
|
16
16
|
# limitations under the License.
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
from .
|
|
20
|
-
from .common.utils import seed_all
|
|
19
|
+
from msprobe.pytorch.monitor.module_hook import TrainerMon
|
|
21
20
|
from .compare.distributed_compare import compare_distributed
|
|
22
21
|
from .compare.pt_compare import compare
|
|
22
|
+
from .common.utils import seed_all
|
|
23
|
+
from .debugger.precision_debugger import PrecisionDebugger
|
|
23
24
|
from .functional.module_dump import module_dump, module_dump_end
|
|
@@ -16,10 +16,18 @@
|
|
|
16
16
|
# limitations under the License.
|
|
17
17
|
|
|
18
18
|
import os
|
|
19
|
+
from collections import namedtuple
|
|
19
20
|
from msprobe.core.common.file_utils import load_yaml, check_file_or_directory_path
|
|
21
|
+
from msprobe.core.common.utils import is_int
|
|
20
22
|
from msprobe.pytorch.pt_config import RunUTConfig
|
|
21
23
|
|
|
22
24
|
|
|
25
|
+
RunUtConfig = namedtuple('RunUtConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
|
|
26
|
+
'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
|
|
27
|
+
'black_list', 'error_data_path', 'online_config'])
|
|
28
|
+
OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
|
|
29
|
+
|
|
30
|
+
|
|
23
31
|
class Config:
|
|
24
32
|
def __init__(self, yaml_file):
|
|
25
33
|
check_file_or_directory_path(yaml_file, False)
|
|
@@ -50,6 +58,8 @@ class Config:
|
|
|
50
58
|
raise ValueError(f"{key} must be one of {validators.keys()}")
|
|
51
59
|
if not isinstance(value, validators.get(key)):
|
|
52
60
|
raise ValueError(f"{key} must be {validators[key].__name__} type")
|
|
61
|
+
if key == 'precision' and not is_int(value):
|
|
62
|
+
raise ValueError("precision must be an integer")
|
|
53
63
|
if key == 'precision' and (value < 0 or value > 20):
|
|
54
64
|
raise ValueError("precision must be greater than or equal to 0 and less than 21")
|
|
55
65
|
if key == 'white_list':
|
|
@@ -68,3 +78,55 @@ class Config:
|
|
|
68
78
|
cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
69
79
|
yaml_path = os.path.join(cur_path, "config.yaml")
|
|
70
80
|
msCheckerConfig = Config(yaml_path)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class CheckerConfig:
|
|
84
|
+
def __init__(self, task_config=None):
|
|
85
|
+
self.white_list = msCheckerConfig.white_list
|
|
86
|
+
self.black_list = msCheckerConfig.black_list
|
|
87
|
+
self.error_data_path = msCheckerConfig.error_data_path
|
|
88
|
+
self.is_online = msCheckerConfig.is_online
|
|
89
|
+
self.nfs_path = msCheckerConfig.nfs_path
|
|
90
|
+
self.host = msCheckerConfig.host
|
|
91
|
+
self.port = msCheckerConfig.port
|
|
92
|
+
self.rank_list = msCheckerConfig.rank_list
|
|
93
|
+
self.tls_path = msCheckerConfig.tls_path
|
|
94
|
+
|
|
95
|
+
if task_config:
|
|
96
|
+
self.load_config(task_config)
|
|
97
|
+
|
|
98
|
+
def load_config(self, task_config):
|
|
99
|
+
self.white_list = task_config.white_list
|
|
100
|
+
self.black_list = task_config.black_list
|
|
101
|
+
self.error_data_path = task_config.error_data_path
|
|
102
|
+
self.is_online = task_config.is_online
|
|
103
|
+
self.nfs_path = task_config.nfs_path
|
|
104
|
+
self.host = task_config.host
|
|
105
|
+
self.port = task_config.port
|
|
106
|
+
self.rank_list = task_config.rank_list
|
|
107
|
+
self.tls_path = task_config.tls_path
|
|
108
|
+
|
|
109
|
+
def get_online_config(self):
|
|
110
|
+
return OnlineConfig(
|
|
111
|
+
is_online=self.is_online,
|
|
112
|
+
nfs_path=self.nfs_path,
|
|
113
|
+
host=self.host,
|
|
114
|
+
port=self.port,
|
|
115
|
+
rank_list=self.rank_list,
|
|
116
|
+
tls_path=self.tls_path
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
def get_run_ut_config(self, **config_params):
|
|
120
|
+
return RunUtConfig(
|
|
121
|
+
forward_content=config_params.get('forward_content'),
|
|
122
|
+
backward_content=config_params.get('backward_content'),
|
|
123
|
+
result_csv_path=config_params.get('result_csv_path'),
|
|
124
|
+
details_csv_path=config_params.get('details_csv_path'),
|
|
125
|
+
save_error_data=config_params.get('save_error_data'),
|
|
126
|
+
is_continue_run_ut=config_params.get('is_continue_run_ut'),
|
|
127
|
+
real_data_path=config_params.get('real_data_path'),
|
|
128
|
+
white_list=self.white_list,
|
|
129
|
+
black_list=self.black_list,
|
|
130
|
+
error_data_path=config_params.get('error_data_path'),
|
|
131
|
+
online_config=self.get_online_config()
|
|
132
|
+
)
|
|
@@ -34,7 +34,7 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECI
|
|
|
34
34
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
|
|
35
35
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path
|
|
36
36
|
from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments
|
|
37
|
-
from msprobe.core.common.file_utils import FileChecker, change_mode,
|
|
37
|
+
from msprobe.core.common.file_utils import FileChecker, change_mode, create_directory
|
|
38
38
|
from msprobe.pytorch.common.log import logger
|
|
39
39
|
from msprobe.core.common.utils import CompareException
|
|
40
40
|
from msprobe.core.common.const import Const, CompareConst, FileCheckConst
|
|
@@ -602,8 +602,7 @@ def _api_precision_compare(parser=None):
|
|
|
602
602
|
def _api_precision_compare_command(args):
|
|
603
603
|
npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail')
|
|
604
604
|
gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail')
|
|
605
|
-
out_path =
|
|
606
|
-
check_path_before_create(out_path)
|
|
605
|
+
out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
|
|
607
606
|
create_directory(out_path)
|
|
608
607
|
out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
|
|
609
608
|
out_path = out_path_checker.common_check()
|
|
@@ -621,7 +620,7 @@ def _api_precision_compare_parser(parser):
|
|
|
621
620
|
parser.add_argument("-gpu", "--gpu_csv_path", dest="gpu_csv_path", default="", type=str,
|
|
622
621
|
help="<Required> Accuracy_checking_details.csv generated on the GPU by using the "
|
|
623
622
|
"api_accuracy_checker tool.",
|
|
624
|
-
required=
|
|
623
|
+
required=True)
|
|
625
624
|
parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
|
|
626
625
|
help="<optional> The api precision compare task result out path.",
|
|
627
626
|
required=False)
|