mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -46,6 +46,13 @@ class KernelGraphOverflowCheck:
|
|
|
46
46
|
self.dump_json["common_dump_settings"]["op_debug_mode"] = 2
|
|
47
47
|
|
|
48
48
|
def handle(self):
|
|
49
|
+
try:
|
|
50
|
+
from msprobe.lib import _msprobe_c
|
|
51
|
+
return
|
|
52
|
+
except ImportError:
|
|
53
|
+
# 如果没有_msprobe_c走MindSpore老流程
|
|
54
|
+
logger.info("Module _msprobe_c has not been installed, use interface in mindspore instead.")
|
|
55
|
+
|
|
49
56
|
if os.getenv("GRAPH_OP_RUN") == "1":
|
|
50
57
|
raise Exception("Must run in graph mode, not kbk mode")
|
|
51
58
|
json_path = self.dump_json["common_dump_settings"]["path"]
|
msprobe/mindspore/service.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -20,6 +20,8 @@ from collections import defaultdict
|
|
|
20
20
|
|
|
21
21
|
import mindspore as ms
|
|
22
22
|
from mindspore import nn
|
|
23
|
+
from mindspore.common.api import _no_grad
|
|
24
|
+
from mindspore.ops.primitive import Primitive
|
|
23
25
|
try:
|
|
24
26
|
from mindspore.common._pijit_context import PIJitCaptureContext
|
|
25
27
|
except ImportError:
|
|
@@ -27,19 +29,25 @@ except ImportError:
|
|
|
27
29
|
else:
|
|
28
30
|
pijit_label = True
|
|
29
31
|
|
|
30
|
-
|
|
31
32
|
from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
|
|
32
33
|
from msprobe.core.common.file_utils import create_directory
|
|
33
34
|
from msprobe.core.common.utils import Const, print_tools_ends_info
|
|
34
35
|
from msprobe.core.data_dump.data_collector import build_data_collector
|
|
35
|
-
from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs
|
|
36
|
+
from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,
|
|
37
|
+
ModuleBackwardInputs)
|
|
36
38
|
from msprobe.core.data_dump.scope import BaseScope
|
|
37
39
|
from msprobe.mindspore.cell_processor import CellProcessor
|
|
38
40
|
from msprobe.mindspore.common.log import logger
|
|
39
|
-
from msprobe.mindspore.common.utils import get_rank_if_initialized
|
|
41
|
+
from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs,
|
|
42
|
+
is_mindtorch, register_backward_hook_functions)
|
|
40
43
|
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
41
44
|
from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
|
|
42
45
|
from msprobe.mindspore.dump.jit_dump import JitDump
|
|
46
|
+
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
47
|
+
from msprobe.mindspore.dump.kernel_dump.kernel_config import create_kernel_config_json
|
|
48
|
+
|
|
49
|
+
if is_mindtorch():
|
|
50
|
+
import torch
|
|
43
51
|
|
|
44
52
|
|
|
45
53
|
class Service:
|
|
@@ -51,54 +59,144 @@ class Service:
|
|
|
51
59
|
self.cell_processor = CellProcessor(self.data_collector.scope)
|
|
52
60
|
self.primitive_hook_service = PrimitiveHookService(self)
|
|
53
61
|
self.switch = False
|
|
62
|
+
self.inner_switch = False
|
|
54
63
|
self.primitive_switch = False
|
|
55
64
|
self.current_iter = 0
|
|
56
65
|
self.first_start = True
|
|
57
66
|
self.current_rank = None
|
|
58
67
|
self.dump_iter_dir = None
|
|
59
68
|
self.start_call = False
|
|
60
|
-
self.check_level_valid()
|
|
61
69
|
self.should_stop_service = False
|
|
70
|
+
self.params_grad_info = {}
|
|
71
|
+
# 提前注册,确保注册尽可能多的API hook
|
|
72
|
+
self.register_api_hook()
|
|
62
73
|
|
|
63
74
|
@staticmethod
|
|
64
|
-
def check_model_valid(
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
)
|
|
75
|
+
def check_model_valid(models):
|
|
76
|
+
target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell")
|
|
77
|
+
if models is None or isinstance(models, target_module_type[0]):
|
|
78
|
+
return models
|
|
79
|
+
error_model = None
|
|
80
|
+
if isinstance(models, (list, tuple)):
|
|
81
|
+
for model in models:
|
|
82
|
+
if not isinstance(model, target_module_type[0]):
|
|
83
|
+
error_model = model
|
|
84
|
+
break
|
|
85
|
+
else:
|
|
86
|
+
error_model = models
|
|
70
87
|
|
|
71
|
-
|
|
72
|
-
|
|
88
|
+
if error_model is not None:
|
|
89
|
+
error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] "
|
|
90
|
+
f"type, currently there is a {type(error_model)} type.")
|
|
73
91
|
raise MsprobeException(
|
|
74
|
-
MsprobeException.INVALID_PARAM_ERROR,
|
|
75
|
-
|
|
92
|
+
MsprobeException.INVALID_PARAM_ERROR, error_info)
|
|
93
|
+
return models
|
|
94
|
+
|
|
95
|
+
@staticmethod
|
|
96
|
+
def prepare_module_input_output(target_type, cell, input_data, output):
|
|
97
|
+
if target_type == BaseScope.Module_Type_Module:
|
|
98
|
+
module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
|
|
99
|
+
else:
|
|
100
|
+
module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs, output=output)
|
|
101
|
+
return module_input_output
|
|
76
102
|
|
|
77
103
|
def build_hook(self, target_type, name):
|
|
78
|
-
def
|
|
79
|
-
if not self.
|
|
80
|
-
|
|
81
|
-
del cell.input_kwargs
|
|
104
|
+
def pre_hook(api_or_cell_name, cell, input_data):
|
|
105
|
+
if not self.should_execute_hook(target_type, cell, True):
|
|
106
|
+
clean_input_kwargs(cell)
|
|
82
107
|
return None
|
|
83
108
|
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
109
|
+
with _no_grad():
|
|
110
|
+
self.inner_switch = True
|
|
111
|
+
if target_type == BaseScope.Module_Type_Module:
|
|
112
|
+
api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
|
|
113
|
+
else:
|
|
114
|
+
cell.forward_data_collected = True
|
|
115
|
+
HOOKCell.add_cell_count(name)
|
|
116
|
+
module_input_output = self.prepare_module_input_output(target_type, cell, input_data, None)
|
|
117
|
+
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
118
|
+
self.data_collector.forward_input_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
119
|
+
self.inner_switch = False
|
|
120
|
+
return input_data
|
|
121
|
+
|
|
122
|
+
def grad_hook(cell, ori_name, param_name):
|
|
123
|
+
def hook_fn(grad):
|
|
124
|
+
if not self.should_execute_hook(target_type, cell, False):
|
|
125
|
+
return None
|
|
126
|
+
self.inner_switch = True
|
|
127
|
+
self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
|
|
128
|
+
self.inner_switch = False
|
|
129
|
+
return None
|
|
90
130
|
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
131
|
+
return hook_fn
|
|
132
|
+
|
|
133
|
+
def register_param_hook(ori_name, cell, params_dict):
|
|
134
|
+
'''
|
|
135
|
+
注册参数hook
|
|
136
|
+
'''
|
|
137
|
+
# data_mode为forward时,不注册参数hook
|
|
138
|
+
if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
|
|
139
|
+
for param_name, param in params_dict.items():
|
|
140
|
+
if param.requires_grad:
|
|
141
|
+
param.register_hook(grad_hook(cell, ori_name, param_name))
|
|
142
|
+
|
|
143
|
+
def init_params_grad_info(cell, params_dict):
|
|
144
|
+
'''
|
|
145
|
+
初始化参数梯度信息, 在前向hook结束后, 将参数梯度信息写入cache_data中用于占位
|
|
146
|
+
'''
|
|
147
|
+
if not params_dict:
|
|
148
|
+
return
|
|
149
|
+
if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
|
|
150
|
+
grad_name = cell.params_grad_name if hasattr(cell, 'params_grad_name') else None
|
|
151
|
+
# 判断是否已经在cache_data中进行了占位, 若没有则先写入cache_data中
|
|
152
|
+
if not self.params_grad_info.get(grad_name):
|
|
153
|
+
data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
|
|
154
|
+
# 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位
|
|
155
|
+
if data_info.get(grad_name):
|
|
156
|
+
# 将grad_name的data_info先写入cache_data中, 梯度计算后再更新
|
|
157
|
+
self.data_collector.handle_data(grad_name, data_info,
|
|
158
|
+
flush=self.data_collector.data_processor.is_terminated)
|
|
159
|
+
# 记录当前模块的参数梯度信息已占位
|
|
160
|
+
self.params_grad_info[grad_name] = True
|
|
161
|
+
|
|
162
|
+
def forward_hook(api_or_cell_name, cell, input_data, output):
|
|
163
|
+
if not self.should_execute_hook(target_type, cell, True):
|
|
164
|
+
clean_input_kwargs(cell)
|
|
165
|
+
return None
|
|
166
|
+
with _no_grad():
|
|
167
|
+
self.inner_switch = True
|
|
168
|
+
module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output)
|
|
169
|
+
if target_type == BaseScope.Module_Type_Module:
|
|
170
|
+
api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
|
|
171
|
+
params_dict = {key.split(Const.SEP)[-1]: value for key, value in cell.parameters_dict(
|
|
172
|
+
recurse=False).items()}
|
|
173
|
+
setattr(module_input_output, Const.PARAMS, params_dict)
|
|
174
|
+
# 判断是否需要注册参数hook
|
|
175
|
+
if not hasattr(cell, 'params_grad_name') and params_dict:
|
|
176
|
+
ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0]
|
|
177
|
+
grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
|
|
178
|
+
# 首次执行前向hook时,添加params_grad_name属性,并注册参数hook
|
|
179
|
+
setattr(cell, 'params_grad_name', grad_name)
|
|
180
|
+
register_param_hook(ori_name, cell, params_dict)
|
|
181
|
+
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
182
|
+
self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
183
|
+
init_params_grad_info(cell, params_dict)
|
|
184
|
+
else:
|
|
185
|
+
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
186
|
+
self.data_collector.forward_output_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
187
|
+
|
|
188
|
+
if self.data_collector.if_return_forward_new_output():
|
|
189
|
+
forward_new_output = self.data_collector.get_forward_new_output()
|
|
190
|
+
self.inner_switch = False
|
|
191
|
+
return forward_new_output
|
|
192
|
+
clean_input_kwargs(cell)
|
|
193
|
+
self.inner_switch = False
|
|
194
|
+
return output
|
|
98
195
|
|
|
99
196
|
def backward_hook(api_or_cell_name, cell, grad_input, grad_output):
|
|
100
|
-
if not self.
|
|
197
|
+
if not self.should_execute_hook(target_type, cell, False):
|
|
101
198
|
return
|
|
199
|
+
self.inner_switch = True
|
|
102
200
|
|
|
103
201
|
need_exchange = True
|
|
104
202
|
if target_type == BaseScope.Module_Type_Module:
|
|
@@ -114,12 +212,32 @@ class Service:
|
|
|
114
212
|
else:
|
|
115
213
|
module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
|
|
116
214
|
self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
215
|
+
self.inner_switch = False
|
|
216
|
+
|
|
217
|
+
def pre_backward_hook(api_or_cell_name, cell, grad_input):
|
|
218
|
+
if not self.should_execute_hook(target_type, cell, False):
|
|
219
|
+
return
|
|
220
|
+
self.inner_switch = True
|
|
221
|
+
module_input = ModuleBackwardInputs(grad_input=grad_input)
|
|
222
|
+
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
223
|
+
self.data_collector.backward_input_data_collect(api_or_cell_name, cell, pid, module_input)
|
|
224
|
+
|
|
225
|
+
self.inner_switch = False
|
|
117
226
|
|
|
118
227
|
pid = os.getpid()
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
228
|
+
if target_type == BaseScope.Module_Type_Module:
|
|
229
|
+
full_forward_name = name + Const.FORWARD
|
|
230
|
+
full_backward_name = name + Const.BACKWARD
|
|
231
|
+
else:
|
|
232
|
+
full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD
|
|
233
|
+
full_backward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.BACKWARD
|
|
234
|
+
pre_forward_hook = functools.partial(pre_hook, full_forward_name)
|
|
235
|
+
forward_hook = functools.partial(forward_hook, full_forward_name)
|
|
236
|
+
backward_hook = functools.partial(backward_hook, full_backward_name)
|
|
237
|
+
pre_backward_hook = functools.partial(pre_backward_hook, full_backward_name)
|
|
238
|
+
|
|
239
|
+
def wrap_pre_forward_hook(cell, input_data):
|
|
240
|
+
return pre_forward_hook(cell, input_data)
|
|
123
241
|
|
|
124
242
|
def wrap_forward_hook(cell, input_data, output_data):
|
|
125
243
|
return forward_hook(cell, input_data, output_data)
|
|
@@ -127,7 +245,10 @@ class Service:
|
|
|
127
245
|
def wrap_backward_hook(cell, grad_input, grad_output):
|
|
128
246
|
return backward_hook(cell, grad_input, grad_output)
|
|
129
247
|
|
|
130
|
-
|
|
248
|
+
def wrap_pre_backward_hook(cell, grad_input):
|
|
249
|
+
return pre_backward_hook(cell, grad_input)
|
|
250
|
+
|
|
251
|
+
return wrap_pre_forward_hook, wrap_forward_hook, wrap_backward_hook, wrap_pre_backward_hook
|
|
131
252
|
|
|
132
253
|
def update_primitive_counters(self, primitive_name):
|
|
133
254
|
if primitive_name not in self.primitive_counters:
|
|
@@ -135,33 +256,20 @@ class Service:
|
|
|
135
256
|
else:
|
|
136
257
|
self.primitive_counters[primitive_name] += 1
|
|
137
258
|
|
|
138
|
-
def register_primitive_hooks(self):
|
|
139
|
-
primitive_set = set()
|
|
140
|
-
for _, cell in self.model.cells_and_names():
|
|
141
|
-
for pname, primitive in cell._primitives.items():
|
|
142
|
-
primitive_set.add((pname, primitive))
|
|
143
|
-
|
|
144
|
-
for pname, primitive in primitive_set:
|
|
145
|
-
primitive_class_name = primitive.__class__.__name__
|
|
146
|
-
primitive_combined_name = pname + Const.SEP + primitive_class_name
|
|
147
|
-
new_primitive = type('NewPrimitive', (primitive.__class__,),
|
|
148
|
-
{'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
|
|
149
|
-
primitive_combined_name)})
|
|
150
|
-
primitive.__class__ = new_primitive
|
|
151
|
-
|
|
152
259
|
def step(self):
|
|
260
|
+
if self.config.async_dump:
|
|
261
|
+
self.data_collector.fill_stack_tensor_data()
|
|
262
|
+
self.data_collector.data_processor.dump_async_data()
|
|
263
|
+
self.data_collector.write_json()
|
|
153
264
|
self.current_iter += 1
|
|
154
265
|
self.data_collector.update_iter(self.current_iter)
|
|
155
|
-
self.
|
|
156
|
-
self.data_collector.data_writer.reset_cache()
|
|
157
|
-
JitDump.jit_count = defaultdict(int)
|
|
266
|
+
self.reset_status()
|
|
158
267
|
|
|
159
268
|
def start(self, model=None):
|
|
160
269
|
self.start_call = True
|
|
161
270
|
if self.should_stop_service:
|
|
162
271
|
return
|
|
163
272
|
if self.need_end_service():
|
|
164
|
-
api_register.api_set_ori_func()
|
|
165
273
|
self.should_stop_service = True
|
|
166
274
|
self.switch = False
|
|
167
275
|
self.primitive_switch = False
|
|
@@ -181,7 +289,8 @@ class Service:
|
|
|
181
289
|
|
|
182
290
|
if self.config.rank and self.current_rank not in self.config.rank:
|
|
183
291
|
return
|
|
184
|
-
self.
|
|
292
|
+
self.register_primitive_hook()
|
|
293
|
+
self.register_cell_hook()
|
|
185
294
|
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
186
295
|
JitDump.set_config(self.config)
|
|
187
296
|
JitDump.set_data_collector(self.data_collector)
|
|
@@ -200,25 +309,6 @@ class Service:
|
|
|
200
309
|
logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
|
|
201
310
|
JitDump.jit_dump_switch = True
|
|
202
311
|
|
|
203
|
-
def forward_backward_dump_end(self):
|
|
204
|
-
if self.should_stop_service:
|
|
205
|
-
return
|
|
206
|
-
logger.info(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() is set successfully. ")
|
|
207
|
-
if not self.start_call:
|
|
208
|
-
logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
|
|
209
|
-
raise Exception("debugger.start() is not set in the current scope.")
|
|
210
|
-
if not self.switch:
|
|
211
|
-
logger.error(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() should be called between "
|
|
212
|
-
"debugger.start() and debugger.stop() ")
|
|
213
|
-
raise Exception("debugger.stop() is already called. ")
|
|
214
|
-
if self.config.step and self.current_iter not in self.config.step:
|
|
215
|
-
return
|
|
216
|
-
if self.config.rank and self.current_rank not in self.config.rank:
|
|
217
|
-
return
|
|
218
|
-
self.primitive_switch = False
|
|
219
|
-
api_register.api_set_ori_func()
|
|
220
|
-
JitDump.jit_dump_switch = False
|
|
221
|
-
|
|
222
312
|
def stop(self):
|
|
223
313
|
if self.should_stop_service:
|
|
224
314
|
return
|
|
@@ -234,6 +324,9 @@ class Service:
|
|
|
234
324
|
self.switch = False
|
|
235
325
|
self.primitive_switch = False
|
|
236
326
|
self.start_call = False
|
|
327
|
+
if self.config.async_dump:
|
|
328
|
+
self.data_collector.fill_stack_tensor_data()
|
|
329
|
+
self.data_collector.data_processor.dump_async_data()
|
|
237
330
|
self.data_collector.write_json()
|
|
238
331
|
JitDump.jit_dump_switch = False
|
|
239
332
|
|
|
@@ -244,8 +337,16 @@ class Service:
|
|
|
244
337
|
return True
|
|
245
338
|
return False
|
|
246
339
|
|
|
247
|
-
def
|
|
248
|
-
|
|
340
|
+
def should_execute_hook(self, hook_type, cell, is_forward):
|
|
341
|
+
is_cell_hook = hook_type == BaseScope.Module_Type_Module
|
|
342
|
+
if is_cell_hook and not self.switch:
|
|
343
|
+
return False
|
|
344
|
+
elif not is_cell_hook and is_forward and not self.switch:
|
|
345
|
+
return False
|
|
346
|
+
elif not is_cell_hook and not is_forward and not cell.forward_data_collected:
|
|
347
|
+
return False
|
|
348
|
+
|
|
349
|
+
if self.inner_switch:
|
|
249
350
|
return False
|
|
250
351
|
if not self.data_collector or self.data_collector.data_processor.is_terminated:
|
|
251
352
|
return False
|
|
@@ -255,6 +356,12 @@ class Service:
|
|
|
255
356
|
create_directory(self.config.dump_path)
|
|
256
357
|
self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
|
|
257
358
|
cur_rank = self.current_rank if self.current_rank is not None else ''
|
|
359
|
+
if self.config.level == Const.LEVEL_L2:
|
|
360
|
+
create_directory(self.dump_iter_dir)
|
|
361
|
+
kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
|
|
362
|
+
self.config.kernel_config_path = kernel_config_path
|
|
363
|
+
return
|
|
364
|
+
|
|
258
365
|
dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
|
|
259
366
|
create_directory(dump_dir)
|
|
260
367
|
if self.config.task in self.data_collector.tasks_need_tensor_data:
|
|
@@ -267,37 +374,96 @@ class Service:
|
|
|
267
374
|
stack_file_path = os.path.join(dump_dir, "stack.json")
|
|
268
375
|
construct_file_path = os.path.join(dump_dir, "construct.json")
|
|
269
376
|
self.data_collector.update_dump_paths(
|
|
270
|
-
dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None
|
|
377
|
+
dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None
|
|
378
|
+
)
|
|
379
|
+
self.data_collector.initialize_json_file(
|
|
380
|
+
framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
|
|
381
|
+
)
|
|
271
382
|
|
|
272
383
|
def empty(self, *args, **kwargs):
|
|
273
384
|
pass
|
|
274
385
|
|
|
275
|
-
def
|
|
276
|
-
|
|
277
|
-
|
|
386
|
+
def register_api_hook(self):
|
|
387
|
+
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
|
|
388
|
+
logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")
|
|
278
389
|
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
|
|
279
390
|
api_register.api_set_hook_func()
|
|
280
|
-
if self.model and self.config.task in Const.DUMP_DATA_COLLECTION_LIST:
|
|
281
|
-
self.register_primitive_hooks()
|
|
282
391
|
|
|
392
|
+
def get_cells_and_names(self):
|
|
393
|
+
cells_and_names_with_index = {}
|
|
394
|
+
|
|
395
|
+
def get_cell_or_module(model):
|
|
396
|
+
return model.named_modules() if is_mindtorch() else model.cells_and_names()
|
|
397
|
+
|
|
398
|
+
if isinstance(self.model, (list, tuple)):
|
|
399
|
+
for index, model in enumerate(self.model):
|
|
400
|
+
cells_and_names_with_index[str(index)] = get_cell_or_module(model)
|
|
401
|
+
else:
|
|
402
|
+
cells_and_names_with_index["-1"] = get_cell_or_module(self.model)
|
|
403
|
+
return cells_and_names_with_index
|
|
404
|
+
|
|
405
|
+
def register_primitive_hook(self):
|
|
406
|
+
if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
407
|
+
return
|
|
408
|
+
if not self.model or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST:
|
|
409
|
+
return
|
|
410
|
+
|
|
411
|
+
primitive_set = set()
|
|
412
|
+
cells_and_names_with_index = self.get_cells_and_names()
|
|
413
|
+
for cells_and_names in cells_and_names_with_index.values():
|
|
414
|
+
for _, cell in cells_and_names:
|
|
415
|
+
for attribute, value in vars(cell).items():
|
|
416
|
+
if isinstance(value, Primitive):
|
|
417
|
+
primitive_set.add((attribute, value))
|
|
418
|
+
|
|
419
|
+
for pname, primitive in primitive_set:
|
|
420
|
+
primitive_class_name = primitive.__class__.__name__
|
|
421
|
+
primitive_combined_name = pname + Const.SEP + primitive_class_name
|
|
422
|
+
new_primitive = type('NewPrimitive', (primitive.__class__,),
|
|
423
|
+
{'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
|
|
424
|
+
primitive_combined_name)})
|
|
425
|
+
primitive.__class__ = new_primitive
|
|
426
|
+
|
|
427
|
+
def register_cell_hook(self):
|
|
283
428
|
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]:
|
|
429
|
+
logger.info(f"The cell {self.config.task} hook function is successfully mounted to the model.")
|
|
284
430
|
if not self.model:
|
|
285
431
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
286
432
|
f"The current level is {self.config.level}, the model cannot be None")
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
433
|
+
model_type = Const.MODULE if is_mindtorch() else Const.CELL
|
|
434
|
+
cells_and_names_with_index = self.get_cells_and_names()
|
|
435
|
+
|
|
436
|
+
for index, cells_and_names in cells_and_names_with_index.items():
|
|
437
|
+
model = self.model if index == "-1" else self.model[int(index)]
|
|
438
|
+
for name, cell in cells_and_names:
|
|
439
|
+
if cell == model:
|
|
440
|
+
continue
|
|
441
|
+
cell_index = (index + Const.SEP) if index != "-1" else ""
|
|
442
|
+
prefix = (model_type + Const.SEP + cell_index + name +
|
|
443
|
+
Const.SEP + cell.__class__.__name__ + Const.SEP)
|
|
444
|
+
_, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix)
|
|
445
|
+
cell.register_forward_hook(forward_hook)
|
|
446
|
+
cell.register_forward_pre_hook(
|
|
447
|
+
self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
|
|
448
|
+
cell.register_forward_hook(
|
|
449
|
+
self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
|
|
450
|
+
|
|
451
|
+
register_backward_hook_functions["full"](cell, backward_hook)
|
|
452
|
+
register_backward_hook_functions["pre"](
|
|
453
|
+
cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
|
|
454
|
+
register_backward_hook_functions["full"](
|
|
455
|
+
cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
|
|
456
|
+
|
|
457
|
+
def reset_status(self):
|
|
458
|
+
self.primitive_hook_service.primitive_counters.clear()
|
|
459
|
+
self.data_collector.data_writer.reset_cache()
|
|
460
|
+
JitDump.jit_count = defaultdict(int)
|
|
461
|
+
self.params_grad_info.clear()
|
|
462
|
+
|
|
463
|
+
if self.config.level == Const.LEVEL_L2:
|
|
464
|
+
self.data_collector.data_processor.reset_status()
|
|
465
|
+
return
|
|
466
|
+
if self.config.step and self.current_iter not in self.config.step:
|
|
467
|
+
return
|
|
468
|
+
if self.config.rank and self.current_rank not in self.config.rank:
|
|
469
|
+
return
|
msprobe/msprobe.py
CHANGED
|
@@ -16,10 +16,12 @@
|
|
|
16
16
|
import argparse
|
|
17
17
|
import sys
|
|
18
18
|
import importlib.util
|
|
19
|
-
|
|
19
|
+
|
|
20
|
+
from msprobe.core.common.const import Const
|
|
20
21
|
from msprobe.core.common.log import logger
|
|
22
|
+
from msprobe.core.compare.utils import _compare_parser
|
|
21
23
|
from msprobe.core.compare.compare_cli import compare_cli
|
|
22
|
-
from msprobe.core.
|
|
24
|
+
from msprobe.core.compare.merge_result.merge_result_cli import _merge_result_parser, merge_result_cli
|
|
23
25
|
|
|
24
26
|
|
|
25
27
|
def is_module_available(module_name):
|
|
@@ -45,10 +47,15 @@ def main():
|
|
|
45
47
|
multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut')
|
|
46
48
|
api_precision_compare_cmd_parser = subparsers.add_parser('api_precision_compare')
|
|
47
49
|
run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check')
|
|
50
|
+
code_mapping_cmd_parser = subparsers.add_parser('code_mapping')
|
|
48
51
|
graph_service_cmd_parser = subparsers.add_parser('graph')
|
|
52
|
+
op_generate_cmd_parser = subparsers.add_parser('op_generate')
|
|
53
|
+
merge_result_parser = subparsers.add_parser('merge_result')
|
|
49
54
|
_compare_parser(compare_cmd_parser)
|
|
55
|
+
_merge_result_parser(merge_result_parser)
|
|
56
|
+
|
|
50
57
|
is_torch_available = is_module_available("torch")
|
|
51
|
-
|
|
58
|
+
|
|
52
59
|
if len(sys.argv) < 4:
|
|
53
60
|
parser.print_help()
|
|
54
61
|
sys.exit(0)
|
|
@@ -62,6 +69,8 @@ def main():
|
|
|
62
69
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
|
|
63
70
|
_run_overflow_check_command
|
|
64
71
|
from msprobe.visualization.graph_service import _pt_graph_service_parser, _pt_graph_service_command
|
|
72
|
+
from msprobe.pytorch.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \
|
|
73
|
+
_run_operator_generate_commond
|
|
65
74
|
|
|
66
75
|
_run_ut_parser(run_ut_cmd_parser)
|
|
67
76
|
_run_ut_parser(multi_run_ut_cmd_parser)
|
|
@@ -70,12 +79,15 @@ def main():
|
|
|
70
79
|
_api_precision_compare_parser(api_precision_compare_cmd_parser)
|
|
71
80
|
_run_overflow_check_parser(run_overflow_check_cmd_parser)
|
|
72
81
|
_pt_graph_service_parser(graph_service_cmd_parser)
|
|
82
|
+
_op_generator_parser(op_generate_cmd_parser)
|
|
73
83
|
elif framework_args.framework == Const.MS_FRAMEWORK:
|
|
74
84
|
from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument
|
|
75
85
|
from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command
|
|
76
86
|
add_api_accuracy_checker_argument(run_ut_cmd_parser)
|
|
77
87
|
from msprobe.mindspore.api_accuracy_checker.cmd_parser import multi_add_api_accuracy_checker_argument
|
|
78
88
|
multi_add_api_accuracy_checker_argument(multi_run_ut_cmd_parser)
|
|
89
|
+
from msprobe.mindspore.code_mapping.cmd_parser import add_ir_parser_arguments
|
|
90
|
+
add_ir_parser_arguments(code_mapping_cmd_parser)
|
|
79
91
|
|
|
80
92
|
_ms_graph_service_parser(graph_service_cmd_parser)
|
|
81
93
|
|
|
@@ -97,17 +109,23 @@ def main():
|
|
|
97
109
|
_run_overflow_check_command(args)
|
|
98
110
|
elif sys.argv[3] == "graph":
|
|
99
111
|
_pt_graph_service_command(args)
|
|
112
|
+
elif sys.argv[3] == 'op_generate':
|
|
113
|
+
_run_operator_generate_commond(args)
|
|
100
114
|
elif sys.argv[3] == "compare":
|
|
101
115
|
if args.cell_mapping is not None or args.api_mapping is not None:
|
|
102
116
|
logger.error("Argument -cm or -am is not supported in PyTorch framework")
|
|
103
117
|
raise Exception("Argument -cm or -am is not supported in PyTorch framework")
|
|
104
118
|
compare_cli(args)
|
|
119
|
+
elif sys.argv[3] == "merge_result":
|
|
120
|
+
merge_result_cli(args)
|
|
105
121
|
else:
|
|
106
122
|
if not is_module_available(Const.MS_FRAMEWORK):
|
|
107
123
|
logger.error("MindSpore does not exist, please install MindSpore library")
|
|
108
124
|
raise Exception("MindSpore does not exist, please install MindSpore library")
|
|
109
125
|
if sys.argv[3] == "compare":
|
|
110
126
|
compare_cli(args)
|
|
127
|
+
elif sys.argv[3] == "merge_result":
|
|
128
|
+
merge_result_cli(args)
|
|
111
129
|
elif sys.argv[3] == "run_ut":
|
|
112
130
|
from msprobe.mindspore.api_accuracy_checker.main import api_checker_main
|
|
113
131
|
api_checker_main(args)
|
|
@@ -116,6 +134,9 @@ def main():
|
|
|
116
134
|
mul_api_checker_main(args)
|
|
117
135
|
elif sys.argv[3] == "graph":
|
|
118
136
|
_ms_graph_service_command(args)
|
|
137
|
+
elif sys.argv[3] == "code_mapping":
|
|
138
|
+
from msprobe.mindspore.code_mapping.main import code_mapping_main
|
|
139
|
+
code_mapping_main(args)
|
|
119
140
|
|
|
120
141
|
|
|
121
142
|
if __name__ == "__main__":
|
msprobe/pytorch/__init__.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
|
1
|
-
|
|
2
|
-
# -*- coding: utf-8 -*-
|
|
3
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
4
2
|
# All rights reserved.
|
|
5
3
|
#
|
|
6
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -16,9 +14,12 @@
|
|
|
16
14
|
# limitations under the License.
|
|
17
15
|
|
|
18
16
|
|
|
19
|
-
|
|
17
|
+
import torch
|
|
20
18
|
from .compare.distributed_compare import compare_distributed
|
|
21
19
|
from .compare.pt_compare import compare
|
|
22
20
|
from .common.utils import seed_all
|
|
23
|
-
from .debugger.precision_debugger import PrecisionDebugger
|
|
24
|
-
|
|
21
|
+
from .debugger.precision_debugger import PrecisionDebugger, module_dump, module_dump_end
|
|
22
|
+
|
|
23
|
+
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
24
|
+
if torch_version_above_or_equal_2:
|
|
25
|
+
from msprobe.pytorch.monitor.module_hook import TrainerMon
|