mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
- mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +16 -21
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +185 -11
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +4 -0
- msprobe/core/common/utils.py +42 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +381 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +199 -69
- msprobe/core/data_dump/data_collector.py +100 -25
- msprobe/core/data_dump/data_processor/base.py +130 -28
- msprobe/core/data_dump/data_processor/factory.py +8 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
- msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
- msprobe/core/data_dump/json_writer.py +54 -8
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +121 -17
- msprobe/docs/02.config_introduction.md +18 -16
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +107 -58
- msprobe/docs/06.data_dump_MindSpore.md +95 -34
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +310 -220
- msprobe/docs/21.visualization_PyTorch.md +125 -35
- msprobe/docs/22.visualization_MindSpore.md +149 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +525 -0
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +11 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +68 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +28 -2
- msprobe/mindspore/debugger/precision_debugger.py +100 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +13 -3
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +347 -107
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -7
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +116 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +42 -17
- msprobe/pytorch/debugger/precision_debugger.py +56 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +98 -28
- msprobe/pytorch/monitor/csv2tb.py +164 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +543 -318
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +76 -56
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +84 -48
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +264 -115
- msprobe/visualization/builder/graph_builder.py +93 -10
- msprobe/visualization/builder/msprobe_adapter.py +30 -6
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +15 -19
- msprobe/visualization/graph/distributed_analyzer.py +395 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +100 -27
- msprobe/visualization/utils.py +24 -31
- mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
- msprobe/pytorch/functional/module_dump.py +0 -84
- msprobe/pytorch/module_processer.py +0 -150
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -21,7 +21,9 @@ from msprobe.core.common.exceptions import MsprobeException
|
|
|
21
21
|
from msprobe.core.common.file_utils import FileChecker
|
|
22
22
|
from msprobe.core.common.utils import get_real_step_or_rank
|
|
23
23
|
from msprobe.pytorch.common.log import logger
|
|
24
|
+
from msprobe.pytorch.common.utils import check_save_param
|
|
24
25
|
from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
|
|
26
|
+
from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper
|
|
25
27
|
from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
|
|
26
28
|
from msprobe.pytorch.pt_config import parse_json_config
|
|
27
29
|
from msprobe.pytorch.service import Service
|
|
@@ -49,7 +51,7 @@ class PrecisionDebugger:
|
|
|
49
51
|
dump_path=None,
|
|
50
52
|
level=None,
|
|
51
53
|
model=None,
|
|
52
|
-
step=None
|
|
54
|
+
step=None
|
|
53
55
|
):
|
|
54
56
|
if not hasattr(self, "initialized"):
|
|
55
57
|
config_params = ConfigParameters(config_path,
|
|
@@ -59,7 +61,6 @@ class PrecisionDebugger:
|
|
|
59
61
|
model)
|
|
60
62
|
self.check_input_params(config_params)
|
|
61
63
|
|
|
62
|
-
self.api_origin = False
|
|
63
64
|
self.initialized = True
|
|
64
65
|
self.model = model
|
|
65
66
|
common_config, task_config = parse_json_config(config_path, task)
|
|
@@ -67,12 +68,13 @@ class PrecisionDebugger:
|
|
|
67
68
|
if self.task == Const.GRAD_PROBE:
|
|
68
69
|
self.gm = GradientMonitor(common_config, task_config)
|
|
69
70
|
return
|
|
70
|
-
if step:
|
|
71
|
+
if step is not None:
|
|
71
72
|
common_config.step = get_real_step_or_rank(step, Const.STEP)
|
|
72
73
|
self.config = DebuggerConfig(
|
|
73
74
|
common_config, task_config, task, dump_path, level
|
|
74
75
|
)
|
|
75
76
|
self.service = Service(self.config)
|
|
77
|
+
self.module_dumper = ModuleDumper(self.service)
|
|
76
78
|
self.enable_dataloader = self.config.enable_dataloader
|
|
77
79
|
if self.enable_dataloader:
|
|
78
80
|
logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
|
|
@@ -105,9 +107,11 @@ class PrecisionDebugger:
|
|
|
105
107
|
raise MsprobeException(
|
|
106
108
|
MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
|
|
107
109
|
|
|
108
|
-
if args.model is not None
|
|
109
|
-
|
|
110
|
-
|
|
110
|
+
if args.model is not None:
|
|
111
|
+
logger.warning_on_rank_0(
|
|
112
|
+
"The 'model' parameter in the PrecisionDebugger will be deprecated in the future."
|
|
113
|
+
"It is recommended to pass the 'model' parameter in the start interface instead."
|
|
114
|
+
)
|
|
111
115
|
|
|
112
116
|
@classmethod
|
|
113
117
|
def start(cls, model=None):
|
|
@@ -120,15 +124,12 @@ class PrecisionDebugger:
|
|
|
120
124
|
if instance.enable_dataloader:
|
|
121
125
|
logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
|
|
122
126
|
else:
|
|
123
|
-
instance.service.start(instance.model
|
|
124
|
-
instance.api_origin = False
|
|
127
|
+
instance.service.start(instance.model)
|
|
125
128
|
|
|
126
|
-
# 指定代码段dump前反向结束符,之后的计算过程数据将被忽略,无法被dump
|
|
127
129
|
@classmethod
|
|
128
130
|
def forward_backward_dump_end(cls):
|
|
129
131
|
instance = cls._instance
|
|
130
|
-
instance.
|
|
131
|
-
instance.api_origin = True
|
|
132
|
+
instance.stop()
|
|
132
133
|
|
|
133
134
|
@classmethod
|
|
134
135
|
def stop(cls):
|
|
@@ -158,6 +159,49 @@ class PrecisionDebugger:
|
|
|
158
159
|
return
|
|
159
160
|
cls._instance.gm.monitor(model)
|
|
160
161
|
|
|
162
|
+
@classmethod
def save(cls, variable, name, save_backward=True):
    """Dump an arbitrary ``variable`` under ``name`` through the active PrecisionDebugger.

    Args:
        variable: the object to dump; validated by ``check_save_param``.
        name: label under which the variable is recorded.
        save_backward: forwarded unchanged to ``check_save_param`` and ``service.save``.

    Raises:
        Exception: with ``MsgConst.NOT_CREATED_INSTANCE`` when no PrecisionDebugger exists.

    The call silently does nothing when the task is not ``Const.TENSOR`` or
    ``Const.STATISTICS``, or when the level is not ``Const.LEVEL_DEBUG``.
    It also does nothing when ``check_save_param`` raises ``ValueError``.
    """
    debugger = cls._instance
    if not debugger:
        raise Exception(MsgConst.NOT_CREATED_INSTANCE)
    # Short-circuit keeps ``config`` untouched for tasks (e.g. grad probe) that never create it.
    is_dump_task = debugger.task in [Const.TENSOR, Const.STATISTICS]
    if not is_dump_task or debugger.config.level != Const.LEVEL_DEBUG:
        return
    try:
        check_save_param(variable, name, save_backward)
    except ValueError:
        # Invalid arguments are ignored rather than propagated.
        # NOTE(review): presumably check_save_param logs the reason itself — confirm.
        return
    debugger.service.save(variable, name, save_backward)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def module_dump(module, dump_name):
    """Start dumping the data of a single module under a user-chosen name.

    Args:
        module: the ``torch.nn.Module`` whose data should be dumped.
        dump_name: string embedded in the dump prefix for this module.

    Raises:
        MsprobeException: ``INVALID_PARAM_ERROR`` if ``module`` is not a ``torch.nn.Module``
            or ``dump_name`` is not a ``str``; ``INTERFACE_USAGE_ERROR`` if no
            PrecisionDebugger has been instantiated yet.
    """
    # Validate arguments before touching the debugger, so bad input is reported first.
    if not isinstance(module, torch.nn.Module):
        raise MsprobeException(
            MsprobeException.INVALID_PARAM_ERROR,
            "the module argument in module_dump must be a torch.nn.Module subclass"
        )
    if not isinstance(dump_name, str):
        raise MsprobeException(
            MsprobeException.INVALID_PARAM_ERROR,
            "the dump_name argument in module_dump must be a str type"
        )
    instance = PrecisionDebugger._instance
    if not instance:
        raise MsprobeException(
            MsprobeException.INTERFACE_USAGE_ERROR,
            "PrecisionDebugger must be instantiated before using module_dump interface"
        )
    instance.module_dumper.start_module_dump(module, dump_name)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def module_dump_end():
    """Stop the module dump started by ``module_dump``.

    Delegates to ``ModuleDumper.stop_module_dump`` on the active PrecisionDebugger.

    Raises:
        MsprobeException: ``INTERFACE_USAGE_ERROR`` if no PrecisionDebugger has been
            instantiated yet.
    """
    instance = PrecisionDebugger._instance
    if not instance:
        raise MsprobeException(
            MsprobeException.INTERFACE_USAGE_ERROR,
            "PrecisionDebugger must be instantiated before using module_dump_end interface"
        )
    instance.module_dumper.stop_module_dump()
|
|
204
|
+
|
|
161
205
|
|
|
162
206
|
def iter_tracer(func):
|
|
163
207
|
def func_wrapper(*args, **kwargs):
|
|
File without changes
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from msprobe.core.common.const import Const
|
|
18
|
+
from msprobe.core.data_dump.scope import BaseScope
|
|
19
|
+
from msprobe.pytorch.common.log import logger
|
|
20
|
+
from msprobe.pytorch.hook_module.api_registry import api_register
|
|
21
|
+
|
|
22
|
+
# Compare the integer major version: a plain string comparison is lexicographic and
# would wrongly report e.g. "10.0" < "2.0". The major component is always numeric
# (e.g. "2.1.0+cu118", "1.13.1", "2.4.0a0+git...").
torch_version_above_or_equal_2 = int(torch.__version__.split('.')[0]) >= 2
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ModuleDumper:
    """Dump one user-selected module on demand (backs the ``module_dump`` interface).

    ``start_module_dump`` calls ``api_register.api_originality()`` and then attaches
    hooks to the chosen module. ``stop_module_dump`` calls ``api_register.api_modularity()``
    and removes every hook this object registered.
    NOTE(review): api_originality/api_modularity presumably switch API-level wrapping
    off/on — confirm in api_registry.
    """

    def __init__(self, service):
        self.service = service
        # Handles of every hook added by register_hook; removed in stop_module_dump.
        self.hook_handle_list = []

    def start_module_dump(self, module, dump_name):
        """Begin dumping ``module`` under ``dump_name``."""
        api_register.api_originality()
        self.register_hook(module, dump_name)

    def stop_module_dump(self):
        """Stop the current module dump and detach all hooks registered by this dumper."""
        api_register.api_modularity()
        for hook_handle in self.hook_handle_list:
            if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
                hook_handle.remove()
        self.hook_handle_list.clear()

    def register_hook(self, module, dump_name):
        """Register data-dump and hierarchy-tracking hooks on ``module`` and keep their handles.

        Args:
            module: the ``torch.nn.Module`` to dump.
            dump_name: user-provided name; the hook prefix is
                ``Module_Type_Module + SEP + dump_name + SEP + <class name> + SEP``.
        """
        prefix_name = (
            BaseScope.Module_Type_Module + Const.SEP +
            dump_name + Const.SEP +
            module.__class__.__name__ + Const.SEP
        )
        module_processor = self.service.module_processor
        _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.service.build_hook(
            BaseScope.Module_Type_Module,
            prefix_name
        )

        # PyTorch does not allow full backward hooks on a module that already uses the
        # deprecated register_backward_hook, so backward dumping is skipped in that case.
        # Registering full hooks below never turns this condition on, so evaluate it once.
        skip_backward = module_processor.has_register_backward_hook(module)
        if skip_backward:
            logger.warning(
                f"The {dump_name} module has registered deprecated register_backward_hook, "
                "which may cause abnormal data dump. The backward data dump for this module will be skipped."
            )

        # Data-dump hooks.
        if torch_version_above_or_equal_2:
            forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
        else:
            if not skip_backward:
                self.hook_handle_list.append(module.register_full_backward_hook(
                    module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
                ))
            forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
        self.hook_handle_list.append(forward_hook_handle)
        if not skip_backward:
            self.hook_handle_list.append(module.register_full_backward_hook(backward_hook))

        # Hierarchy-tracking hooks, which open/close the module on the processor's stack and scope.
        self.hook_handle_list.append(module.register_forward_pre_hook(
            module_processor.node_hook(prefix_name + Const.FORWARD, Const.START)
        ))
        self.hook_handle_list.append(module.register_forward_hook(
            module_processor.node_hook(prefix_name + Const.FORWARD, Const.STOP)
        ))
        if torch_version_above_or_equal_2 and not skip_backward:
            self.hook_handle_list.append(module.register_full_backward_pre_hook(
                module_processor.node_hook(prefix_name + Const.BACKWARD, Const.START)
            ))
            self.hook_handle_list.append(module.register_full_backward_hook(
                module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
            ))
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from functools import wraps
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from msprobe.core.common.const import Const
|
|
20
|
+
from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
|
|
21
|
+
from msprobe.pytorch.common.log import logger
|
|
22
|
+
from msprobe.pytorch.common.utils import replace_last_occurrence
|
|
23
|
+
from torch.utils.checkpoint import checkpoint as origin_checkpoint
|
|
24
|
+
from torch.utils.checkpoint import set_checkpoint_early_stop
|
|
25
|
+
from torch.utils.hooks import BackwardHook
|
|
26
|
+
|
|
27
|
+
# Compare the integer major version: a plain string comparison is lexicographic and
# would wrongly report e.g. "10.0" < "2.0". The major component is always numeric
# (e.g. "2.1.0+cu118", "1.13.1", "2.4.0a0+git...").
torch_version_above_or_equal_2 = int(torch.__version__.split('.')[0]) >= 2
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def checkpoint_without_early_stop(*args, **kwargs):
|
|
31
|
+
with set_checkpoint_early_stop(False):
|
|
32
|
+
return origin_checkpoint(*args, **kwargs)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def replace_checkpoint():
    """Monkey-patch ``torch.utils.checkpoint.checkpoint`` with ``checkpoint_without_early_stop``.

    Only attribute lookups made after this call see the replacement. Code that already
    bound ``checkpoint`` via ``from torch.utils.checkpoint import checkpoint`` keeps the
    original function.
    """
    checkpoint_module = torch.utils.checkpoint
    checkpoint_module.checkpoint = checkpoint_without_early_stop
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class ModuleProcesser:
    """Track the nn.Module call hierarchy for msprobe dumps and install per-module hooks.

    All bookkeeping lives in class attributes shared by every instance:

    * ``module_count``: hook name prefix -> zero-based number of full names produced
      for it (0 on first use, then 1, 2, ...).
    * ``module_stack``: full names of the modules currently entered by a START hook,
      innermost last.
    * ``api_parent_node``: full name of the module on top of ``module_stack``
      (``None`` once the stack empties, ``""`` after a reset).
    * ``module_node``: module full name -> full name of its parent module, or ``None``.
    """
    module_count = {}
    module_stack = []
    api_parent_node = ""
    module_node = {}

    # Marker attribute set on BackwardHook methods already wrapped by clone_return_value.
    # It prevents repeated ModuleProcesser construction from stacking wrappers, where each
    # extra layer would clone the hook tensors again.
    _CLONE_WRAPPED_FLAG = "_msprobe_clone_wrapped"

    def __init__(self, scope):
        """Create a processor bound to ``scope``.

        ``scope`` is kept only if it is a ModuleRangeScope or MixRangeScope. It then receives
        ``begin_module``/``end_module`` calls from the hooks.

        Construction also has two process-wide side effects:
        * it wraps ``BackwardHook.setup_input_hook``/``setup_output_hook`` once;
        * it replaces ``torch.utils.checkpoint.checkpoint`` via ``replace_checkpoint``.
        """
        self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
        ModuleProcesser._patch_backward_hook_setup()
        replace_checkpoint()

    @staticmethod
    def _patch_backward_hook_setup():
        """Make BackwardHook's setup_input_hook/setup_output_hook return cloned tensors.

        The patch is idempotent: methods that already carry ``_CLONE_WRAPPED_FLAG`` are left
        untouched. NOTE(review): the cloning presumably isolates hook outputs from later
        in-place ops — confirm.
        """
        for attr_name in ("setup_input_hook", "setup_output_hook"):
            current = getattr(BackwardHook, attr_name)
            if getattr(current, ModuleProcesser._CLONE_WRAPPED_FLAG, False):
                continue
            wrapped = ModuleProcesser.clone_return_value(current)
            setattr(wrapped, ModuleProcesser._CLONE_WRAPPED_FLAG, True)
            setattr(BackwardHook, attr_name, wrapped)

    @staticmethod
    def clone_return_value(func):
        """Wrap ``func`` so that tensors in its return value are cloned (see ``clone_if_tensor``)."""
        @wraps(func)
        def clone_return_value_func(*args, **kwargs):
            result = func(*args, **kwargs)
            return ModuleProcesser.clone_if_tensor(result)

        return clone_return_value_func

    @staticmethod
    def clone_if_tensor(result):
        """Return ``result`` with every tensor cloned, recursing into exact tuple/list/dict containers.

        Subclasses of tuple/list/dict (e.g. namedtuples) and all other objects are returned as-is.
        """
        if isinstance(result, torch.Tensor):
            return result.clone()
        if type(result) is tuple:
            return tuple(ModuleProcesser.clone_if_tensor(x) for x in result)
        if type(result) is list:
            return [ModuleProcesser.clone_if_tensor(x) for x in result]
        if type(result) is dict:
            return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()}
        return result

    @staticmethod
    def module_count_func(module_name):
        """Return the next zero-based call index for ``module_name`` and record it."""
        counts = ModuleProcesser.module_count
        counts[module_name] = counts.get(module_name, -1) + 1
        return counts[module_name]

    @staticmethod
    def has_register_backward_hook(module):
        """Return True if ``module`` has hooks registered via the deprecated ``register_backward_hook``."""
        return hasattr(module, '_backward_hooks') and \
            len(module._backward_hooks) > 0 and \
            module._is_full_backward_hook is False

    @staticmethod
    def get_modules_and_names(models):
        """Map a model index to its ``named_modules()`` iterator.

        A list/tuple of models is keyed by ``"0"``, ``"1"``, ...; a single model is keyed by ``"-1"``.
        """
        modules_and_names_with_index = {}
        if isinstance(models, (list, tuple)):
            for index, model in enumerate(models):
                modules_and_names_with_index[str(index)] = model.named_modules()
        else:
            modules_and_names_with_index["-1"] = models.named_modules()
        return modules_and_names_with_index

    @classmethod
    def reset_module_stats(cls):
        """Reset all shared hierarchy bookkeeping to its initial empty state."""
        cls.module_count = {}
        cls.module_stack = []
        cls.api_parent_node = ""
        cls.module_node = {}

    def register_module_hook(self, models, build_hook):
        """Attach data-dump and hierarchy hooks to every submodule of ``models``.

        The root model(s) themselves are skipped.

        Args:
            models: a ``torch.nn.Module`` or a list/tuple of them.
            build_hook: callable ``(module_type, prefix_name)`` returning the 4-tuple
                ``(pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2)``.
        """
        logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.")
        modules_and_names_with_index = self.get_modules_and_names(models)
        for index, modules_and_names in modules_and_names_with_index.items():
            model = models if index == "-1" else models[int(index)]
            # Multi-model inputs get their index in the prefix; a single model gets none.
            module_index = (index + Const.SEP) if index != "-1" else ""
            for name, module in modules_and_names:
                if module is model:
                    continue
                prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
                               name + Const.SEP + module.__class__.__name__ + Const.SEP)
                _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook(
                    BaseScope.Module_Type_Module,
                    prefix_name
                )

                # Full backward hooks cannot coexist with the deprecated register_backward_hook,
                # so backward handling is skipped for such modules. Registering full hooks
                # below never turns this condition on, so it is evaluated once.
                skip_backward = self.has_register_backward_hook(module)
                if skip_backward:
                    logger.warning(
                        f"The {prefix_name[:-1]} has registered deprecated register_backward_hook, "
                        "which may cause abnormal data dump. The backward data dump for this module will be skipped."
                    )
                if torch_version_above_or_equal_2:
                    module.register_forward_hook(forward_hook, with_kwargs=True)
                else:
                    if not skip_backward:
                        module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
                    module.register_forward_hook(forward_hook_torch_version_below_2)
                if not skip_backward:
                    module.register_full_backward_hook(backward_hook)

                module.register_forward_pre_hook(self.node_hook(prefix_name + Const.FORWARD, Const.START))
                module.register_forward_hook(self.node_hook(prefix_name + Const.FORWARD, Const.STOP))
                if torch_version_above_or_equal_2 and not skip_backward:
                    module.register_full_backward_pre_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.START))
                    module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))

    def node_hook(self, name_prefix, start_or_stop, **kwargs):
        """Build a hierarchy-tracking hook for ``name_prefix``.

        Args:
            name_prefix: hook name prefix (containing ``Const.FORWARD`` or ``Const.BACKWARD``).
            start_or_stop: string checked for ``Const.START`` to choose an "enter" hook;
                otherwise (e.g. ``Const.STOP``) an "exit" hook is returned.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            A hook with signature ``(module, module_input, output=None)``, suitable for
            forward (pre-)hooks and full backward (pre-)hooks.
        """

        def _push_full_name(module):
            # Allocate the next call index for this prefix. The resulting full name is
            # remembered on the module so the matching end_hook can pop it.
            full_name = name_prefix + Const.SEP + str(ModuleProcesser.module_count_func(name_prefix))
            if not getattr(module, "mindstudio_reserved_name", None):
                module.mindstudio_reserved_name = []
            module.mindstudio_reserved_name.append(full_name)
            return full_name

        def pre_hook(module, module_input, output=None):
            full_name = _push_full_name(module)
            stack = ModuleProcesser.module_stack
            ModuleProcesser.module_node[full_name] = stack[-1] if stack else None
            stack.append(full_name)
            # The module just entered is now the innermost one.
            ModuleProcesser.api_parent_node = full_name
            if self.scope:
                self.scope.begin_module(full_name)

        def end_hook(module, module_input, output=None):
            stack = ModuleProcesser.module_stack
            if stack:
                stack.pop()
            ModuleProcesser.api_parent_node = stack[-1] if stack else None
            if not getattr(module, "mindstudio_reserved_name", None):
                raise RuntimeError("module reserve name is None when pop")
            current_name = module.mindstudio_reserved_name.pop()
            if self.scope:
                self.scope.end_module(current_name)

        def backward_hook(module, module_input, output=None):
            full_name = _push_full_name(module)
            # The backward node's parent mirrors the parent recorded for the matching forward node.
            forward_full_name = replace_last_occurrence(full_name, Const.BACKWARD, Const.FORWARD)
            ModuleProcesser.module_node[full_name] = replace_last_occurrence(
                ModuleProcesser.module_node.get(forward_full_name), Const.FORWARD, Const.BACKWARD)
            ModuleProcesser.api_parent_node = None
            if self.scope:
                self.scope.begin_module(full_name)

        if torch_version_above_or_equal_2:
            return pre_hook if Const.START in start_or_stop else end_hook
        # torch < 2 has no full backward pre-hook, so backward entry is handled by backward_hook.
        if Const.FORWARD in name_prefix and Const.START in start_or_stop:
            return pre_hook
        if Const.BACKWARD in name_prefix:
            return backward_hook
        return end_hook
|
|
@@ -39,7 +39,6 @@ class DataParams:
|
|
|
39
39
|
origin_func: Optional[Callable] = None
|
|
40
40
|
api_type: Optional[str] = None
|
|
41
41
|
fuzz_stage: Optional[str] = None
|
|
42
|
-
grad_unequal_flag: Optional[bool] = True
|
|
43
42
|
|
|
44
43
|
|
|
45
44
|
@dataclass
|
|
@@ -127,6 +126,8 @@ def make_unequal_row(
|
|
|
127
126
|
)
|
|
128
127
|
if isinstance(ratio, float):
|
|
129
128
|
row.max_rel = ratio - 1
|
|
129
|
+
if isinstance(ratio, str):
|
|
130
|
+
row.max_rel = ratio
|
|
130
131
|
origin_tensor = data_params.original_result
|
|
131
132
|
perturbed_tensor = data_params.perturbed_result
|
|
132
133
|
if index is not None:
|
|
@@ -124,6 +124,7 @@ class TorchC:
|
|
|
124
124
|
abs = torch._C._VariableFunctionsClass.abs
|
|
125
125
|
where = torch._C._VariableFunctionsClass.where
|
|
126
126
|
div = torch._C._VariableFunctionsClass.div
|
|
127
|
+
mul = torch._C._VariableFunctionsClass.mul
|
|
127
128
|
max = torch._C._VariableFunctionsClass.max
|
|
128
129
|
min = torch._C._VariableFunctionsClass.min
|
|
129
130
|
gt = torch._C._VariableFunctionsClass.gt
|
|
@@ -138,3 +139,5 @@ class TorchC:
|
|
|
138
139
|
tensor_split = torch._C._VariableFunctionsClass.tensor_split
|
|
139
140
|
stack = torch._C._VariableFunctionsClass.stack
|
|
140
141
|
reshape = torch._C._VariableFunctionsClass.reshape
|
|
142
|
+
nan_to_num = torch._C._VariableFunctionsClass.nan_to_num
|
|
143
|
+
aminmax = torch._C._VariableFunctionsClass.aminmax
|
|
@@ -82,13 +82,11 @@ class GradSaver:
|
|
|
82
82
|
data_params = DataParams()
|
|
83
83
|
data_params.original_result = origin_grad
|
|
84
84
|
data_params.perturbed_result = perturbed_grad
|
|
85
|
-
data_params.grad_unequal_flag = False
|
|
86
85
|
data_params.valid_input_index = index
|
|
87
86
|
try:
|
|
88
87
|
handler.handle(data_params)
|
|
89
88
|
if not data_params.is_consistent:
|
|
90
89
|
self.is_compare = False
|
|
91
|
-
data_params.grad_unequal_flag = True
|
|
92
90
|
data_params.is_consistent = True
|
|
93
91
|
data_params.perturbed_result = self.perturbed_grad_input
|
|
94
92
|
data_params.original_result = self.origin_grad_input
|
|
@@ -89,12 +89,6 @@ class FuzzHandler(ABC):
|
|
|
89
89
|
)
|
|
90
90
|
return origin_output_chunks, perturbed_output_chunks
|
|
91
91
|
|
|
92
|
-
@staticmethod
|
|
93
|
-
def convert_overflow_ratio_to_consistent(ratio):
|
|
94
|
-
if math.isnan(ratio) or math.isinf(ratio):
|
|
95
|
-
return ThresholdConfig.COMP_CONSISTENT
|
|
96
|
-
return ratio
|
|
97
|
-
|
|
98
92
|
@abstractmethod
|
|
99
93
|
def get_threshold(self, dtype):
|
|
100
94
|
pass
|
|
@@ -107,10 +101,10 @@ class FuzzHandler(ABC):
|
|
|
107
101
|
self, origin_output, perturbed_output, norm_type, abs_tol
|
|
108
102
|
):
|
|
109
103
|
if norm_type == NormType.ENDLESS_NORM:
|
|
110
|
-
return self.
|
|
104
|
+
return self.calculate_max_ratio(origin_output, perturbed_output, abs_tol)
|
|
111
105
|
return ThresholdConfig.COMP_CONSISTENT
|
|
112
106
|
|
|
113
|
-
def
|
|
107
|
+
def calculate_max_ratio(self, origin_output, perturbed_output, abs_tol):
|
|
114
108
|
origin_output_chunks, perturbed_output_chunks = (
|
|
115
109
|
self.tensor_split_for_error_calculate(origin_output, perturbed_output)
|
|
116
110
|
)
|
|
@@ -122,42 +116,30 @@ class FuzzHandler(ABC):
|
|
|
122
116
|
raise FreeBenchmarkException(
|
|
123
117
|
FreeBenchmarkException.OutputIndexError, err_msg
|
|
124
118
|
)
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
norm3 = np.inf
|
|
119
|
+
|
|
120
|
+
max_ratio = ThresholdConfig.COMP_CONSISTENT
|
|
128
121
|
for i, chunk_origin in enumerate(origin_output_chunks):
|
|
129
122
|
if chunk_origin.nelement() == 0:
|
|
130
123
|
break
|
|
131
124
|
chunk_perturbed = perturbed_output_chunks[i]
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
TorchC.
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
)
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
TorchC.div(
|
|
143
|
-
|
|
144
|
-
TorchC.clamp(chunk_origin, min=abs_tol),
|
|
145
|
-
),
|
|
146
|
-
1,
|
|
125
|
+
# 如果乘积最小值 < 极小值乘积的负值,认为存在非极小值符号相反的情况
|
|
126
|
+
if TorchC.lt(
|
|
127
|
+
TorchC.min(TorchC.mul(chunk_origin, chunk_perturbed)), -(abs_tol**2)
|
|
128
|
+
):
|
|
129
|
+
return ThresholdConfig.SYMBOL_FLIPPING
|
|
130
|
+
# 求A/B B/A的比值前,将值限制在大于极小值范围内
|
|
131
|
+
clamp_origin = TorchC.clamp(TorchC.abs(chunk_origin), min=abs_tol)
|
|
132
|
+
clamp_perturbed = TorchC.clamp(TorchC.abs(chunk_perturbed), min=abs_tol)
|
|
133
|
+
# 对于计算结果为nan的情况,认为两者没有差异
|
|
134
|
+
ratio_tensor = TorchC.nan_to_num(
|
|
135
|
+
TorchC.div(clamp_origin, clamp_perturbed),
|
|
136
|
+
nan=ThresholdConfig.COMP_CONSISTENT,
|
|
147
137
|
)
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
|
|
154
|
-
norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1))
|
|
155
|
-
|
|
156
|
-
if norm3 < 0:
|
|
157
|
-
ratio = ThresholdConfig.SYMBOL_FLIPPING
|
|
158
|
-
else:
|
|
159
|
-
ratio = max(norm1, norm2)
|
|
160
|
-
return ratio
|
|
138
|
+
# 求A/B 和 B/A比值最大值,其中 B/A的最大值为 A/B的最小值的倒数
|
|
139
|
+
min_ratio, max_ratio = TorchC.stack([*TorchC.aminmax(ratio_tensor)]).tolist()
|
|
140
|
+
min_ratio_reciprocal = np.inf if min_ratio == 0 else 1 / min_ratio
|
|
141
|
+
max_ratio = max(max_ratio, min_ratio_reciprocal)
|
|
142
|
+
return max_ratio
|
|
161
143
|
|
|
162
144
|
def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float:
|
|
163
145
|
try:
|
|
@@ -220,10 +202,12 @@ class FuzzHandler(ABC):
|
|
|
220
202
|
)
|
|
221
203
|
npu_consistent = is_consistent
|
|
222
204
|
max_fuzz_ratio = (
|
|
223
|
-
max_fuzz_ratio
|
|
205
|
+
max_fuzz_ratio
|
|
206
|
+
if not isinstance(ratio, (int, float))
|
|
207
|
+
else max(max_fuzz_ratio, ratio)
|
|
224
208
|
)
|
|
225
|
-
data_params.is_consistent = is_consistent
|
|
226
|
-
if not is_consistent
|
|
209
|
+
data_params.is_consistent = is_consistent
|
|
210
|
+
if not is_consistent:
|
|
227
211
|
self.unequal_rows.append(
|
|
228
212
|
make_unequal_row(data_params, self.params, ratio=ratio)
|
|
229
213
|
)
|
|
@@ -235,12 +219,12 @@ class FuzzHandler(ABC):
|
|
|
235
219
|
)
|
|
236
220
|
npu_consistent = npu_consistent and is_consistent
|
|
237
221
|
max_fuzz_ratio = (
|
|
238
|
-
max_fuzz_ratio
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
is_consistent and data_params.is_consistent
|
|
222
|
+
max_fuzz_ratio
|
|
223
|
+
if not isinstance(ratio, (int, float))
|
|
224
|
+
else max(max_fuzz_ratio, ratio)
|
|
242
225
|
)
|
|
243
|
-
|
|
226
|
+
data_params.is_consistent = is_consistent
|
|
227
|
+
if not is_consistent:
|
|
244
228
|
self.unequal_rows.append(
|
|
245
229
|
make_unequal_row(
|
|
246
230
|
data_params, self.params, ratio=ratio, index=index_
|
|
@@ -75,10 +75,6 @@ class PreheatHandler(FuzzHandler):
|
|
|
75
75
|
if self.params.preheat_config.get("preheat_step") <= self.params.step:
|
|
76
76
|
return data_params.original_result
|
|
77
77
|
|
|
78
|
-
if not data_params.grad_unequal_flag:
|
|
79
|
-
data_params.grad_unequal_flag = True
|
|
80
|
-
data_params.is_consistent = False
|
|
81
|
-
return data_params.original_result
|
|
82
78
|
preheat_counter.add_api_called_time(self.pure_name)
|
|
83
79
|
|
|
84
80
|
if not self._is_take_a_sample():
|
|
@@ -27,6 +27,11 @@ from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotar
|
|
|
27
27
|
from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \
|
|
28
28
|
npu_scaled_masked_softmax_backward
|
|
29
29
|
from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward
|
|
30
|
+
from msprobe.pytorch.bench_functions.apply_adam import npu_apply_adam
|
|
31
|
+
from msprobe.pytorch.bench_functions.group_norm_silu import npu_group_norm_silu
|
|
32
|
+
from msprobe.pytorch.bench_functions.mish import npu_mish
|
|
33
|
+
from msprobe.pytorch.bench_functions.moe_gating_top_k_softmax import npu_moe_gating_top_k_softmax
|
|
34
|
+
from msprobe.pytorch.bench_functions.sort_v2 import npu_sort_v2
|
|
30
35
|
from msprobe.pytorch.common.utils import logger
|
|
31
36
|
|
|
32
37
|
|
|
@@ -79,7 +84,8 @@ class Register(dict):
|
|
|
79
84
|
npu_custom_functions = Register()
|
|
80
85
|
npu_custom_functions([
|
|
81
86
|
npu_apply_adam_w, npu_confusion_transpose, npu_fast_gelu, npu_layer_norm_eval, npu_linear, npu_fusion_attention,
|
|
82
|
-
npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu, gpu_fusion_attention
|
|
87
|
+
npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu, gpu_fusion_attention, npu_apply_adam,
|
|
88
|
+
npu_group_norm_silu, npu_mish, npu_moe_gating_top_k_softmax, npu_sort_v2
|
|
83
89
|
])
|
|
84
90
|
|
|
85
91
|
# register for npu custom backward bench functions
|