mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
- msprobe/README.md +25 -20
- msprobe/core/common/const.py +110 -66
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/utils.py +30 -34
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -2
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +20 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_processor/base.py +2 -2
- msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
- msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
- msprobe/core/data_dump/json_writer.py +38 -35
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +2 -1
- msprobe/docs/02.config_introduction.md +17 -15
- msprobe/docs/05.data_dump_PyTorch.md +70 -2
- msprobe/docs/06.data_dump_MindSpore.md +33 -12
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +124 -62
- msprobe/docs/21.visualization_PyTorch.md +32 -13
- msprobe/docs/22.visualization_MindSpore.md +32 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +31 -19
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +6 -4
- msprobe/mindspore/debugger/precision_debugger.py +22 -10
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +14 -9
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +354 -302
- msprobe/mindspore/monitor/utils.py +46 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +23 -17
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/common/utils.py +29 -7
- msprobe/pytorch/debugger/precision_debugger.py +10 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
- msprobe/pytorch/monitor/csv2tb.py +8 -2
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +131 -105
- msprobe/pytorch/monitor/module_metric.py +3 -0
- msprobe/pytorch/monitor/optimizer_collect.py +55 -4
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +68 -1
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +11 -7
- msprobe/pytorch/service.py +11 -8
- msprobe/visualization/builder/graph_builder.py +44 -5
- msprobe/visualization/builder/msprobe_adapter.py +0 -1
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +8 -1
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +1 -1
- msprobe/visualization/utils.py +2 -33
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/parse.py +0 -19
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -70,6 +70,67 @@ class Const:
|
|
|
70
70
|
}
|
|
71
71
|
|
|
72
72
|
|
|
73
|
+
class MsCompareConst:
|
|
74
|
+
# api_info field
|
|
75
|
+
MINT = "Mint"
|
|
76
|
+
MINT_FUNCTIONAL = "MintFunctional"
|
|
77
|
+
TENSOR_API = "Tensor"
|
|
78
|
+
FUNCTIONAL_API = "Functional"
|
|
79
|
+
FUSION_API = "FUSION"
|
|
80
|
+
|
|
81
|
+
API_NAME_STR_LENGTH = 4
|
|
82
|
+
MAX_RECURSION_DEPTH = 20
|
|
83
|
+
|
|
84
|
+
# Mindtorch api_info field
|
|
85
|
+
MINDTORCH_TENSOR = "Tensor"
|
|
86
|
+
MINDTORCH = "Torch"
|
|
87
|
+
MINDTORCH_FUNC = "Functional"
|
|
88
|
+
MINDTORCH_NPU = "NPU"
|
|
89
|
+
MINDTORCH_DIST = "Distributed"
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
MT_VALID_API_TYPES = [
|
|
94
|
+
MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR
|
|
95
|
+
]
|
|
96
|
+
SUPPORTED_FUSION_LIST = ["flash_attention_score"]
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
TASK_FIELD = "task"
|
|
100
|
+
STATISTICS_TASK = "statistics"
|
|
101
|
+
FRAMEWORK = "framework"
|
|
102
|
+
TENSOR_TASK = "tensor"
|
|
103
|
+
DUMP_DATA_DIR_FIELD = "dump_data_dir"
|
|
104
|
+
DATA_FIELD = "data"
|
|
105
|
+
|
|
106
|
+
# supported api yaml
|
|
107
|
+
SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
|
|
108
|
+
SUPPORTED_TENSOR_LIST_KEY = "tensor"
|
|
109
|
+
|
|
110
|
+
# detail_csv
|
|
111
|
+
DETAIL_CSV_API_NAME = "API Name"
|
|
112
|
+
DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
|
|
113
|
+
DETAIL_CSV_TESTED_DTYPE = "Tested Dtype"
|
|
114
|
+
DETAIL_CSV_SHAPE = "Shape"
|
|
115
|
+
DETAIL_CSV_PASS_STATUS = "Status"
|
|
116
|
+
DETAIL_CSV_MESSAGE = "Message"
|
|
117
|
+
DETAIL_CSV_FILE_NAME = "accuracy_checking_details"
|
|
118
|
+
|
|
119
|
+
# result_csv
|
|
120
|
+
RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success"
|
|
121
|
+
RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success"
|
|
122
|
+
RESULT_CSV_FILE_NAME = "accuracy_checking_result"
|
|
123
|
+
|
|
124
|
+
EPSILON = 1e-8
|
|
125
|
+
|
|
126
|
+
class ProcessStatus:
|
|
127
|
+
SUCCESS = "success"
|
|
128
|
+
API_NOT_FOUND = "api_not_found"
|
|
129
|
+
EXCEPTION_SKIP = "exception_skip"
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
|
|
73
134
|
class FreeBenchmarkConst:
|
|
74
135
|
ADD_NOISE = "add_noise"
|
|
75
136
|
BIT_NOISE = "bit_noise"
|
|
@@ -25,7 +25,31 @@ from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
|
25
25
|
from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy
|
|
26
26
|
from msprobe.core.common.log import logger
|
|
27
27
|
from msprobe.core.common.const import Const
|
|
28
|
-
from msprobe.core.common.utils import CompareException, check_seed_all
|
|
28
|
+
from msprobe.core.common.utils import CompareException, check_seed_all, is_save_variable_valid
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class MsprobeStep(ms.train.Callback):
|
|
32
|
+
def __init__(self, debugger):
|
|
33
|
+
super(MsprobeStep, self).__init__()
|
|
34
|
+
self.debugger = debugger
|
|
35
|
+
|
|
36
|
+
def on_train_step_begin(self, run_context):
|
|
37
|
+
self.debugger.start()
|
|
38
|
+
|
|
39
|
+
def on_train_step_end(self, run_context):
|
|
40
|
+
self.debugger.stop()
|
|
41
|
+
self.debugger.step()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class MsprobeInitStep(ms.train.Callback):
|
|
45
|
+
def on_train_begin(self, run_context):
|
|
46
|
+
try:
|
|
47
|
+
from ms._c_expression import _set_init_iter
|
|
48
|
+
except ImportError:
|
|
49
|
+
logger.warning('MsprobeInitStep does not work on this version of MindSpore.')
|
|
50
|
+
return
|
|
51
|
+
cb_params = run_context.original_args()
|
|
52
|
+
_set_init_iter(cb_params.cur_step_num)
|
|
29
53
|
|
|
30
54
|
|
|
31
55
|
def get_rank_if_initialized():
|
|
@@ -93,20 +117,6 @@ def seed_all(seed=1234, mode=False, rm_dropout=True):
|
|
|
93
117
|
remove_dropout()
|
|
94
118
|
|
|
95
119
|
|
|
96
|
-
class MsprobeStep(ms.train.Callback):
|
|
97
|
-
|
|
98
|
-
def __init__(self, debugger):
|
|
99
|
-
super(MsprobeStep, self).__init__()
|
|
100
|
-
self.debugger = debugger
|
|
101
|
-
|
|
102
|
-
def on_train_step_begin(self, run_context):
|
|
103
|
-
self.debugger.start()
|
|
104
|
-
|
|
105
|
-
def on_train_step_end(self, run_context):
|
|
106
|
-
self.debugger.stop()
|
|
107
|
-
self.debugger.step()
|
|
108
|
-
|
|
109
|
-
|
|
110
120
|
class Dropout(ops.Dropout):
|
|
111
121
|
def __init__(self, keep_prob=0.5, seed0=0, seed1=1):
|
|
112
122
|
super().__init__(1., seed0, seed1)
|
|
@@ -169,7 +179,7 @@ def set_register_backward_hook_functions():
|
|
|
169
179
|
from msprobe.mindspore.mindtorch import (_call_impl,
|
|
170
180
|
register_full_backward_pre_hook,
|
|
171
181
|
register_full_backward_hook)
|
|
172
|
-
if not hasattr(torch, "register_full_backward_hook"):
|
|
182
|
+
if not hasattr(torch.nn.Module, "register_full_backward_hook"):
|
|
173
183
|
setattr(torch.nn.Module, "_call_impl", _call_impl)
|
|
174
184
|
setattr(torch.nn.Module, "register_full_backward_pre_hook", register_full_backward_pre_hook)
|
|
175
185
|
setattr(torch.nn.Module, "register_full_backward_hook", register_full_backward_hook)
|
|
@@ -182,9 +192,11 @@ def set_register_backward_hook_functions():
|
|
|
182
192
|
|
|
183
193
|
def check_save_param(variable, name, save_backward):
|
|
184
194
|
# try catch this api to skip invalid call
|
|
185
|
-
|
|
195
|
+
valid_data_types = tuple([ms.Tensor, int, float, str])
|
|
196
|
+
if not is_save_variable_valid(variable, valid_data_types):
|
|
197
|
+
valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
|
|
186
198
|
logger.warning("PrecisionDebugger.save variable type not valid, "
|
|
187
|
-
"should be one of
|
|
199
|
+
f"should be one of {valid_data_types_with_nested_types}"
|
|
188
200
|
"Skip current save process.")
|
|
189
201
|
raise ValueError
|
|
190
202
|
if not isinstance(name, str):
|
|
@@ -196,4 +208,4 @@ def check_save_param(variable, name, save_backward):
|
|
|
196
208
|
logger.warning("PrecisionDebugger.save_backward name not valid, "
|
|
197
209
|
"should be bool. "
|
|
198
210
|
"Skip current save process.")
|
|
199
|
-
raise ValueError
|
|
211
|
+
raise ValueError
|
|
@@ -22,10 +22,10 @@ import pandas as pd
|
|
|
22
22
|
|
|
23
23
|
from msprobe.core.common.const import CompareConst, Const
|
|
24
24
|
from msprobe.core.common.exceptions import FileCheckException
|
|
25
|
-
from msprobe.core.common.file_utils import
|
|
25
|
+
from msprobe.core.common.file_utils import create_directory, load_json, load_npy, load_yaml
|
|
26
26
|
from msprobe.core.common.log import logger
|
|
27
27
|
from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, \
|
|
28
|
-
check_op_str_pattern_valid, get_dump_mode, set_dump_path
|
|
28
|
+
check_op_str_pattern_valid, get_dump_mode, set_dump_path, detect_framework_by_dump_json
|
|
29
29
|
from msprobe.core.compare.acc_compare import Comparator, ModeConfig
|
|
30
30
|
from msprobe.core.compare.check import dtype_mapping
|
|
31
31
|
from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping
|
|
@@ -78,6 +78,11 @@ class MSComparator(Comparator):
|
|
|
78
78
|
raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
|
|
79
79
|
f"{type(self.data_mapping)}")
|
|
80
80
|
|
|
81
|
+
@staticmethod
|
|
82
|
+
def process_data_name(result):
|
|
83
|
+
result['data_name_x'] = result.apply(lambda row: [row['data_name_x'], row['data_name_y']], axis=1)
|
|
84
|
+
return result
|
|
85
|
+
|
|
81
86
|
def calc_accuracy(self, result_df, header):
|
|
82
87
|
condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
|
|
83
88
|
result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
|
|
@@ -120,12 +125,13 @@ class MSComparator(Comparator):
|
|
|
120
125
|
result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
|
|
121
126
|
elif self.dump_mode == Const.SUMMARY:
|
|
122
127
|
warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']]
|
|
123
|
-
warning_flag = pd.DataFrame(warning_list).
|
|
128
|
+
warning_flag = pd.DataFrame(warning_list).any()
|
|
124
129
|
result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
|
|
125
130
|
result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
|
|
126
131
|
result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
|
|
127
132
|
else:
|
|
128
|
-
fill_cols = [CompareConst.COSINE, CompareConst.
|
|
133
|
+
fill_cols = [CompareConst.COSINE, CompareConst.EUC_DIST,
|
|
134
|
+
CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
|
|
129
135
|
CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
|
|
130
136
|
CompareConst.ERROR_MESSAGE]
|
|
131
137
|
result_df.loc[~condition_no_bench, fill_cols] = ''
|
|
@@ -139,6 +145,8 @@ class MSComparator(Comparator):
|
|
|
139
145
|
header.append(CompareConst.STACK)
|
|
140
146
|
if self.dump_mode == Const.ALL:
|
|
141
147
|
header.append(CompareConst.DATA_NAME)
|
|
148
|
+
result = self.process_data_name(result)
|
|
149
|
+
|
|
142
150
|
result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
|
|
143
151
|
'op_name_y': CompareConst.BENCH_NAME,
|
|
144
152
|
'dtype_x': CompareConst.NPU_DTYPE,
|
|
@@ -169,6 +177,7 @@ class MSComparator(Comparator):
|
|
|
169
177
|
|
|
170
178
|
result[npu_summary] = result['summary_x'].apply(set_summary).tolist()
|
|
171
179
|
result[bench_summary] = result['summary_y'].apply(set_summary).tolist()
|
|
180
|
+
|
|
172
181
|
result_df = pd.DataFrame(columns=header)
|
|
173
182
|
for h in header:
|
|
174
183
|
if h in result.columns:
|
|
@@ -269,15 +278,15 @@ class MSComparator(Comparator):
|
|
|
269
278
|
bench_dtype = match_result['dtype_y']
|
|
270
279
|
if self.cross_frame:
|
|
271
280
|
npu_dtype = npu_dtype.map(dtype_mapping).fillna(npu_dtype)
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
(
|
|
276
|
-
|
|
277
|
-
(
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
+
|
|
282
|
+
equal_condition = npu_dtype == bench_dtype
|
|
283
|
+
match_condition = (
|
|
284
|
+
(npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[0]) & bench_dtype.isin(
|
|
285
|
+
CompareConst.DTYPE_MATCH_GROUPS[0])) |
|
|
286
|
+
(npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[1]) & bench_dtype.isin(
|
|
287
|
+
CompareConst.DTYPE_MATCH_GROUPS[1]))
|
|
288
|
+
)
|
|
289
|
+
return equal_condition | match_condition
|
|
281
290
|
|
|
282
291
|
match_result.loc[~gen_dtype_condition(), [i + '_y' for i in bench_df.columns]] = CompareConst.N_A
|
|
283
292
|
return self.make_result_df(match_result)
|
|
@@ -382,12 +391,11 @@ class MSComparator(Comparator):
|
|
|
382
391
|
|
|
383
392
|
|
|
384
393
|
def check_cross_framework(bench_json_path):
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
return False
|
|
394
|
+
framework = detect_framework_by_dump_json(bench_json_path)
|
|
395
|
+
if framework == Const.PT_FRAMEWORK:
|
|
396
|
+
return True
|
|
397
|
+
else:
|
|
398
|
+
return False
|
|
391
399
|
|
|
392
400
|
|
|
393
401
|
def ms_compare(input_param, output_path, **kwargs):
|
|
@@ -195,11 +195,12 @@ class GraphMSComparator:
|
|
|
195
195
|
if not error_flag:
|
|
196
196
|
result_list, err_msg = compare_ops_apply(n_value, b_value, False, "")
|
|
197
197
|
result_dict[CompareConst.COSINE] = result_list[0]
|
|
198
|
-
result_dict[CompareConst.
|
|
199
|
-
result_dict[CompareConst.
|
|
200
|
-
result_dict[CompareConst.
|
|
201
|
-
result_dict[CompareConst.
|
|
202
|
-
result_dict[CompareConst.
|
|
198
|
+
result_dict[CompareConst.EUC_DIST] = result_list[1]
|
|
199
|
+
result_dict[CompareConst.MAX_ABS_ERR] = result_list[2]
|
|
200
|
+
result_dict[CompareConst.MAX_RELATIVE_ERR] = result_list[3]
|
|
201
|
+
result_dict[CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result_list[4]
|
|
202
|
+
result_dict[CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result_list[5]
|
|
203
|
+
result_dict[CompareConst.ACCURACY] = check_accuracy(result_list[0], result_list[2])
|
|
203
204
|
result_dict[CompareConst.ERROR_MESSAGE] = err_msg
|
|
204
205
|
|
|
205
206
|
return pd.Series(result_dict)
|
|
@@ -53,11 +53,13 @@ class DebuggerConfig:
|
|
|
53
53
|
self.stage = FreeBenchmarkConst.DEFAULT_STAGE if not task_config.fuzz_stage else task_config.fuzz_stage
|
|
54
54
|
if self.handler_type == FreeBenchmarkConst.FIX and \
|
|
55
55
|
self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE:
|
|
56
|
-
|
|
57
|
-
|
|
56
|
+
logger.error("pert_mode must be improve_precision or empty when handler_type is fix, "
|
|
57
|
+
f"but got {self.pert_type}.")
|
|
58
|
+
raise ValueError
|
|
58
59
|
if self.stage == Const.BACKWARD and self.handler_type == FreeBenchmarkConst.FIX:
|
|
59
|
-
|
|
60
|
-
|
|
60
|
+
logger.error("handler_type must be check or empty when fuzz_stage is backward, "
|
|
61
|
+
f"but got {self.handler_type}.")
|
|
62
|
+
raise ValueError
|
|
61
63
|
self.dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
|
|
62
64
|
|
|
63
65
|
def check(self):
|
|
@@ -22,12 +22,12 @@ from mindspore._c_expression import MSContext
|
|
|
22
22
|
from msprobe.core.common.const import Const, FileCheckConst, MsgConst
|
|
23
23
|
from msprobe.core.common.exceptions import MsprobeException
|
|
24
24
|
from msprobe.core.common.file_utils import FileChecker
|
|
25
|
-
from msprobe.core.common.utils import get_real_step_or_rank
|
|
25
|
+
from msprobe.core.common.utils import get_real_step_or_rank, check_init_step
|
|
26
26
|
from msprobe.mindspore.cell_processor import CellProcessor
|
|
27
27
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
28
28
|
from msprobe.mindspore.common.utils import set_register_backward_hook_functions, check_save_param
|
|
29
29
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
30
|
-
from msprobe.mindspore.dump.hook_cell.
|
|
30
|
+
from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
|
|
31
31
|
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
32
32
|
from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor
|
|
33
33
|
from msprobe.mindspore.ms_config import parse_json_config
|
|
@@ -84,7 +84,7 @@ class PrecisionDebugger:
|
|
|
84
84
|
common_config.dump_path = dump_path if dump_path else common_config.dump_path
|
|
85
85
|
self.config = DebuggerConfig(common_config, task_config)
|
|
86
86
|
|
|
87
|
-
if _msprobe_c:
|
|
87
|
+
if self._need_msprobe_c() and _msprobe_c:
|
|
88
88
|
_msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path)
|
|
89
89
|
|
|
90
90
|
self.config.execution_mode = self._get_execution_mode()
|
|
@@ -151,7 +151,7 @@ class PrecisionDebugger:
|
|
|
151
151
|
instance = cls._instance
|
|
152
152
|
if not instance:
|
|
153
153
|
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
154
|
-
if _msprobe_c:
|
|
154
|
+
if cls._need_msprobe_c() and _msprobe_c:
|
|
155
155
|
_msprobe_c._PrecisionDebugger().start()
|
|
156
156
|
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
157
157
|
return
|
|
@@ -163,7 +163,7 @@ class PrecisionDebugger:
|
|
|
163
163
|
instance.service.start(model)
|
|
164
164
|
else:
|
|
165
165
|
if not instance.first_start:
|
|
166
|
-
|
|
166
|
+
get_api_register().restore_all_api()
|
|
167
167
|
handler = TaskHandlerFactory.create(instance.config)
|
|
168
168
|
handler.handle()
|
|
169
169
|
|
|
@@ -180,8 +180,6 @@ class PrecisionDebugger:
|
|
|
180
180
|
instance = cls._instance
|
|
181
181
|
if not instance:
|
|
182
182
|
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
183
|
-
if _msprobe_c:
|
|
184
|
-
_msprobe_c._PrecisionDebugger().stop()
|
|
185
183
|
if instance.task == Const.GRAD_PROBE:
|
|
186
184
|
instance.gm.stop()
|
|
187
185
|
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
@@ -195,8 +193,6 @@ class PrecisionDebugger:
|
|
|
195
193
|
instance = cls._instance
|
|
196
194
|
if not instance:
|
|
197
195
|
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
198
|
-
if _msprobe_c:
|
|
199
|
-
_msprobe_c._PrecisionDebugger().step()
|
|
200
196
|
if instance.task in PrecisionDebugger.task_not_need_service:
|
|
201
197
|
return
|
|
202
198
|
if instance.service:
|
|
@@ -233,6 +229,15 @@ class PrecisionDebugger:
|
|
|
233
229
|
instance.service = Service(instance.config)
|
|
234
230
|
instance.service.save(variable, name, save_backward)
|
|
235
231
|
|
|
232
|
+
@classmethod
|
|
233
|
+
def set_init_step(cls, step):
|
|
234
|
+
instance = cls._instance
|
|
235
|
+
if not instance:
|
|
236
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
237
|
+
check_init_step(step)
|
|
238
|
+
instance.service.init_step = step
|
|
239
|
+
instance.service.loop = 0
|
|
240
|
+
|
|
236
241
|
@classmethod
|
|
237
242
|
def _need_service(cls):
|
|
238
243
|
instance = cls._instance
|
|
@@ -241,4 +246,11 @@ class PrecisionDebugger:
|
|
|
241
246
|
if instance.config.execution_mode != MsConst.PYNATIVE_MODE:
|
|
242
247
|
return False
|
|
243
248
|
else:
|
|
244
|
-
return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config)
|
|
249
|
+
return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config)
|
|
250
|
+
|
|
251
|
+
@classmethod
|
|
252
|
+
def _need_msprobe_c(cls):
|
|
253
|
+
instance = cls._instance
|
|
254
|
+
if not instance:
|
|
255
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
256
|
+
return instance.config.level_ori == Const.LEVEL_L2
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
from msprobe.mindspore.common.const import Const
|
|
17
|
+
from msprobe.core.common.log import logger
|
|
17
18
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
18
19
|
from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
|
|
19
20
|
from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump
|
|
@@ -47,6 +48,7 @@ class DumpToolFactory:
|
|
|
47
48
|
raise Exception("Valid level is needed.")
|
|
48
49
|
tool = tool.get(config.execution_mode)
|
|
49
50
|
if not tool:
|
|
50
|
-
|
|
51
|
-
|
|
51
|
+
logger.error(f"Data dump is not supported in {config.execution_mode} mode "
|
|
52
|
+
f"when dump level is {config.level}.")
|
|
53
|
+
raise ValueError
|
|
52
54
|
return tool(config)
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
from mindspore import Tensor, ops, mint
|
|
19
|
+
from mindspore.mint.nn import functional
|
|
20
|
+
from mindspore.communication import comm_func
|
|
21
|
+
|
|
22
|
+
from msprobe.core.common.file_utils import load_yaml
|
|
23
|
+
from msprobe.core.common.utils import Const
|
|
24
|
+
from msprobe.core.data_dump.api_registry import ApiRegistry
|
|
25
|
+
from msprobe.mindspore.common.const import Const as MsConst
|
|
26
|
+
from msprobe.mindspore.common.utils import is_mindtorch
|
|
27
|
+
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
stub_tensor_existed = True
|
|
31
|
+
try:
|
|
32
|
+
from mindspore.common._stub_tensor import StubTensor
|
|
33
|
+
except ImportError:
|
|
34
|
+
stub_tensor_existed = False
|
|
35
|
+
|
|
36
|
+
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
37
|
+
if not is_mindtorch():
|
|
38
|
+
_api_types = {
|
|
39
|
+
Const.MS_FRAMEWORK: {
|
|
40
|
+
Const.MS_API_TYPE_OPS: (ops, (ops,)),
|
|
41
|
+
Const.MS_API_TYPE_TENSOR: (Tensor, (Tensor,)),
|
|
42
|
+
Const.MS_API_TYPE_MINT: (mint, (mint,)),
|
|
43
|
+
Const.MS_API_TYPE_MINT_FUNC: (functional, (functional,)),
|
|
44
|
+
Const.MS_API_TYPE_COM: (comm_func, (comm_func,))
|
|
45
|
+
}
|
|
46
|
+
}
|
|
47
|
+
if stub_tensor_existed:
|
|
48
|
+
_api_types.get(Const.MS_FRAMEWORK).update(
|
|
49
|
+
{Const.MS_API_TYPE_STUB_TENSOR: (StubTensor, (StubTensor,))}
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
_supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),)
|
|
53
|
+
else:
|
|
54
|
+
import torch
|
|
55
|
+
import torch_npu
|
|
56
|
+
_api_types = {
|
|
57
|
+
Const.MT_FRAMEWORK: {
|
|
58
|
+
Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
|
|
59
|
+
Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)),
|
|
60
|
+
Const.PT_API_TYPE_TORCH: (torch, (torch,)),
|
|
61
|
+
Const.PT_API_TYPE_NPU: (torch_npu, (torch_npu,)),
|
|
62
|
+
Const.PT_API_TYPE_DIST: (torch.distributed, (torch.distributed, torch.distributed.distributed_c10d))
|
|
63
|
+
}
|
|
64
|
+
}
|
|
65
|
+
_supported_api_list_path = (os.path.join(cur_path, '../../../pytorch/hook_module',
|
|
66
|
+
MsConst.SUPPORTED_API_LIST_FILE),)
|
|
67
|
+
|
|
68
|
+
_inner_used_api = {
|
|
69
|
+
Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: (
|
|
70
|
+
ops, "norm", "square", "sqrt", "is_complex", "stack", "is_floating_point"
|
|
71
|
+
),
|
|
72
|
+
Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_TENSOR: (
|
|
73
|
+
Tensor, "to", "numel"
|
|
74
|
+
),
|
|
75
|
+
Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_MINT: (
|
|
76
|
+
mint, "max", "min", "mean", "norm"
|
|
77
|
+
)
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class ApiTemplate(HOOKCell):
|
|
82
|
+
def __init__(self, api_name, api_func, prefix, hook_build_func):
|
|
83
|
+
self.api_name = api_name
|
|
84
|
+
self.api_func = api_func
|
|
85
|
+
self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
|
|
86
|
+
super().__init__(hook_build_func)
|
|
87
|
+
|
|
88
|
+
@staticmethod
|
|
89
|
+
def async_to_sync(output):
|
|
90
|
+
# Fake handle, used to return after the CommHandle executes the wait method
|
|
91
|
+
fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()
|
|
92
|
+
if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"):
|
|
93
|
+
output[1].wait()
|
|
94
|
+
output = (output[0], fake_handle)
|
|
95
|
+
elif hasattr(output, "wait"):
|
|
96
|
+
output.wait()
|
|
97
|
+
output = fake_handle
|
|
98
|
+
return output
|
|
99
|
+
|
|
100
|
+
def construct(self, *args, **kwargs):
|
|
101
|
+
if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
|
|
102
|
+
return args[0] if args else kwargs.get(Const.INPUT)
|
|
103
|
+
|
|
104
|
+
output = self.api_func(*args, **kwargs)
|
|
105
|
+
|
|
106
|
+
if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX):
|
|
107
|
+
if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]:
|
|
108
|
+
output = self.async_to_sync(output)
|
|
109
|
+
return output
|
|
110
|
+
|
|
111
|
+
def forward(self, *args, **kwargs):
|
|
112
|
+
if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
|
|
113
|
+
return args[0] if args else kwargs.get(Const.INPUT)
|
|
114
|
+
return self.api_func(*args, **kwargs)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
api_register = None
|
|
118
|
+
stub_tensor_set = False
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def get_api_register(return_new=False):
|
|
122
|
+
global stub_tensor_set
|
|
123
|
+
|
|
124
|
+
def stub_method(method):
|
|
125
|
+
def wrapped_method(*args, **kwargs):
|
|
126
|
+
return method(*args, **kwargs)
|
|
127
|
+
return wrapped_method
|
|
128
|
+
if not is_mindtorch() and stub_tensor_existed and not stub_tensor_set:
|
|
129
|
+
api_names = load_yaml(_supported_api_list_path[0]).get(Const.MS_API_TYPE_TENSOR, [])
|
|
130
|
+
for attr_name in dir(StubTensor):
|
|
131
|
+
attr = getattr(StubTensor, attr_name)
|
|
132
|
+
if attr_name in api_names and callable(attr):
|
|
133
|
+
setattr(StubTensor, attr_name, stub_method(attr))
|
|
134
|
+
stub_tensor_set = True
|
|
135
|
+
|
|
136
|
+
if return_new:
|
|
137
|
+
return ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
|
|
138
|
+
|
|
139
|
+
global api_register
|
|
140
|
+
if api_register is None:
|
|
141
|
+
api_register = ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
|
|
142
|
+
return api_register
|
|
@@ -28,23 +28,22 @@ def get_cell_count(name):
|
|
|
28
28
|
return HOOKCell.cell_count[name]
|
|
29
29
|
|
|
30
30
|
|
|
31
|
-
def __init__(self,
|
|
31
|
+
def __init__(self, hook_build_func) -> None:
|
|
32
32
|
super(HOOKCell, self).__init__()
|
|
33
33
|
self.changed_status = False
|
|
34
34
|
self.input_kwargs = {}
|
|
35
|
-
self.prefix = ""
|
|
36
35
|
if not HOOKCell.g_stop_hook:
|
|
37
36
|
HOOKCell.g_stop_hook = True
|
|
38
37
|
self.changed_status = True
|
|
39
|
-
if hasattr(self, "prefix_api_name"):
|
|
40
|
-
self.prefix = self.prefix_api_name
|
|
41
|
-
|
|
42
38
|
self.forward_data_collected = False
|
|
43
|
-
|
|
44
|
-
self.
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
39
|
+
|
|
40
|
+
prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
|
|
41
|
+
if callable(hook_build_func):
|
|
42
|
+
forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = hook_build_func(prefix)
|
|
43
|
+
self.register_forward_pre_hook(forward_pre_hook)
|
|
44
|
+
self.register_forward_hook(forward_hook)
|
|
45
|
+
register_backward_hook_functions["full"](self, backward_hook)
|
|
46
|
+
register_backward_hook_functions["pre"](self, backward_pre_hook)
|
|
48
47
|
|
|
49
48
|
|
|
50
49
|
# 重载call,加全局标志。
|