mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
- msprobe/README.md +27 -22
- msprobe/core/common/const.py +129 -60
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +43 -33
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +16 -9
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +30 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +94 -10
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
- msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
- msprobe/core/data_dump/json_writer.py +61 -40
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +27 -1
- msprobe/docs/02.config_introduction.md +27 -23
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +103 -16
- msprobe/docs/06.data_dump_MindSpore.md +76 -32
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +332 -273
- msprobe/docs/21.visualization_PyTorch.md +42 -13
- msprobe/docs/22.visualization_MindSpore.md +43 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +301 -27
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
- msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +48 -18
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +31 -6
- msprobe/mindspore/debugger/precision_debugger.py +45 -14
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +21 -15
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +873 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +309 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +114 -34
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +97 -4
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +24 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +8 -2
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +18 -14
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +238 -193
- msprobe/pytorch/monitor/module_metric.py +9 -6
- msprobe/pytorch/monitor/optimizer_collect.py +100 -67
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +76 -44
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +30 -29
- msprobe/pytorch/service.py +114 -32
- msprobe/visualization/builder/graph_builder.py +75 -10
- msprobe/visualization/builder/msprobe_adapter.py +7 -6
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +11 -3
- msprobe/visualization/graph/distributed_analyzer.py +71 -3
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +4 -3
- msprobe/visualization/graph_service.py +4 -5
- msprobe/visualization/utils.py +12 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -19,8 +19,9 @@ import torch
|
|
|
19
19
|
from msprobe.core.common.const import Const, FileCheckConst, MsgConst
|
|
20
20
|
from msprobe.core.common.exceptions import MsprobeException
|
|
21
21
|
from msprobe.core.common.file_utils import FileChecker
|
|
22
|
-
from msprobe.core.common.utils import get_real_step_or_rank
|
|
22
|
+
from msprobe.core.common.utils import get_real_step_or_rank, check_init_step
|
|
23
23
|
from msprobe.pytorch.common.log import logger
|
|
24
|
+
from msprobe.pytorch.common.utils import check_save_param
|
|
24
25
|
from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
|
|
25
26
|
from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper
|
|
26
27
|
from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
|
|
@@ -158,6 +159,28 @@ class PrecisionDebugger:
|
|
|
158
159
|
return
|
|
159
160
|
cls._instance.gm.monitor(model)
|
|
160
161
|
|
|
162
|
+
@classmethod
|
|
163
|
+
def save(cls, variable, name, save_backward=True):
|
|
164
|
+
instance = cls._instance
|
|
165
|
+
if not instance:
|
|
166
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
167
|
+
if instance.task not in [Const.TENSOR, Const.STATISTICS] or instance.config.level != Const.LEVEL_DEBUG:
|
|
168
|
+
return
|
|
169
|
+
try:
|
|
170
|
+
check_save_param(variable, name, save_backward)
|
|
171
|
+
except ValueError:
|
|
172
|
+
return
|
|
173
|
+
instance.service.save(variable, name, save_backward)
|
|
174
|
+
|
|
175
|
+
@classmethod
|
|
176
|
+
def set_init_step(cls, step):
|
|
177
|
+
instance = cls._instance
|
|
178
|
+
if not instance:
|
|
179
|
+
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
180
|
+
check_init_step(step)
|
|
181
|
+
instance.service.init_step = step
|
|
182
|
+
instance.service.loop = 0
|
|
183
|
+
|
|
161
184
|
|
|
162
185
|
def module_dump(module, dump_name):
|
|
163
186
|
if not isinstance(module, torch.nn.Module):
|
|
@@ -17,7 +17,7 @@ import torch
|
|
|
17
17
|
from msprobe.core.common.const import Const
|
|
18
18
|
from msprobe.core.data_dump.scope import BaseScope
|
|
19
19
|
from msprobe.pytorch.common.log import logger
|
|
20
|
-
from msprobe.pytorch.hook_module.
|
|
20
|
+
from msprobe.pytorch.hook_module.api_register import get_api_register
|
|
21
21
|
|
|
22
22
|
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
23
23
|
|
|
@@ -26,13 +26,14 @@ class ModuleDumper:
|
|
|
26
26
|
def __init__(self, service):
|
|
27
27
|
self.service = service
|
|
28
28
|
self.hook_handle_list = []
|
|
29
|
+
self.api_register = get_api_register()
|
|
29
30
|
|
|
30
31
|
def start_module_dump(self, module, dump_name):
|
|
31
|
-
api_register.
|
|
32
|
+
self.api_register.restore_all_api()
|
|
32
33
|
self.register_hook(module, dump_name)
|
|
33
34
|
|
|
34
35
|
def stop_module_dump(self):
|
|
35
|
-
api_register.
|
|
36
|
+
self.api_register.register_all_api()
|
|
36
37
|
for hook_handle in self.hook_handle_list:
|
|
37
38
|
if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
|
|
38
39
|
hook_handle.remove()
|
|
@@ -16,14 +16,17 @@
|
|
|
16
16
|
from functools import wraps
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
|
+
from torch.utils.hooks import BackwardHook
|
|
20
|
+
|
|
19
21
|
from msprobe.core.common.const import Const
|
|
22
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
20
23
|
from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
|
|
21
24
|
from msprobe.pytorch.common.log import logger
|
|
22
|
-
from
|
|
23
|
-
from torch.utils.checkpoint import set_checkpoint_early_stop
|
|
24
|
-
from torch.utils.hooks import BackwardHook
|
|
25
|
+
from msprobe.pytorch.common.utils import replace_last_occurrence, is_float8_tensor
|
|
25
26
|
|
|
26
27
|
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
28
|
+
if torch_version_above_or_equal_2:
|
|
29
|
+
from torch.utils.checkpoint import checkpoint as origin_checkpoint, set_checkpoint_early_stop
|
|
27
30
|
|
|
28
31
|
|
|
29
32
|
def checkpoint_without_early_stop(*args, **kwargs):
|
|
@@ -32,7 +35,8 @@ def checkpoint_without_early_stop(*args, **kwargs):
|
|
|
32
35
|
|
|
33
36
|
|
|
34
37
|
def replace_checkpoint():
|
|
35
|
-
|
|
38
|
+
if torch_version_above_or_equal_2:
|
|
39
|
+
torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop
|
|
36
40
|
|
|
37
41
|
|
|
38
42
|
class ModuleProcesser:
|
|
@@ -45,29 +49,8 @@ class ModuleProcesser:
|
|
|
45
49
|
self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
|
|
46
50
|
BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
|
|
47
51
|
BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
|
|
48
|
-
BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook)
|
|
49
52
|
replace_checkpoint()
|
|
50
53
|
|
|
51
|
-
@staticmethod
|
|
52
|
-
def filter_tensor_and_tuple(func):
|
|
53
|
-
@wraps(func)
|
|
54
|
-
def wrap_by_filter_tensor_and_tuple(*args, **kwargs):
|
|
55
|
-
# setup_output_hook传入非tensor数据,工具后续dump会报错,处理方式是解析非tensor数据的属性,对tensor属性挂hook
|
|
56
|
-
# setup_output_hook定义为setup_output_hook(self, args),因此处理第二个位置参数,即*args[1]
|
|
57
|
-
if not isinstance(args[1], (torch.Tensor, tuple)):
|
|
58
|
-
for item_str in dir(args[1]):
|
|
59
|
-
item = getattr(args[1], item_str)
|
|
60
|
-
# 处理tensor或者只包含tensor的元组
|
|
61
|
-
if isinstance(item, torch.Tensor) or \
|
|
62
|
-
(isinstance(item, tuple) and all(isinstance(x, torch.Tensor) for x in item)):
|
|
63
|
-
args_new = (args[0], item)
|
|
64
|
-
result = func(*args_new, **kwargs)
|
|
65
|
-
setattr(args[1], item_str, result)
|
|
66
|
-
return args[1]
|
|
67
|
-
return func(*args, **kwargs)
|
|
68
|
-
|
|
69
|
-
return wrap_by_filter_tensor_and_tuple
|
|
70
|
-
|
|
71
54
|
@staticmethod
|
|
72
55
|
def clone_return_value(func):
|
|
73
56
|
@wraps(func)
|
|
@@ -78,14 +61,15 @@ class ModuleProcesser:
|
|
|
78
61
|
return clone_return_value_func
|
|
79
62
|
|
|
80
63
|
@staticmethod
|
|
64
|
+
@recursion_depth_decorator("ModuleDump: ModuleProcesser.clone_if_tensor", max_depth=Const.DUMP_MAX_DEPTH)
|
|
81
65
|
def clone_if_tensor(result):
|
|
82
|
-
if isinstance(result, torch.Tensor):
|
|
66
|
+
if isinstance(result, torch.Tensor) and not is_float8_tensor(result):
|
|
83
67
|
return result.clone()
|
|
84
|
-
elif
|
|
68
|
+
elif type(result) is tuple:
|
|
85
69
|
return tuple(ModuleProcesser.clone_if_tensor(x) for x in result)
|
|
86
|
-
elif
|
|
70
|
+
elif type(result) is list:
|
|
87
71
|
return list(ModuleProcesser.clone_if_tensor(x) for x in result)
|
|
88
|
-
elif
|
|
72
|
+
elif type(result) is dict:
|
|
89
73
|
return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()}
|
|
90
74
|
else:
|
|
91
75
|
return result
|
|
@@ -103,7 +87,7 @@ class ModuleProcesser:
|
|
|
103
87
|
return hasattr(module, '_backward_hooks') and \
|
|
104
88
|
len(module._backward_hooks) > 0 and \
|
|
105
89
|
module._is_full_backward_hook is False
|
|
106
|
-
|
|
90
|
+
|
|
107
91
|
@staticmethod
|
|
108
92
|
def get_modules_and_names(models):
|
|
109
93
|
modules_and_names_with_index = {}
|
|
@@ -129,9 +113,11 @@ class ModuleProcesser:
|
|
|
129
113
|
for name, module in modules_and_names:
|
|
130
114
|
if module == model:
|
|
131
115
|
continue
|
|
116
|
+
if module.__class__.__name__ == "FullyShardedDataParallel":
|
|
117
|
+
continue
|
|
132
118
|
module_index = (index + Const.SEP) if index != "-1" else ""
|
|
133
|
-
prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
|
|
134
|
-
|
|
119
|
+
prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
|
|
120
|
+
name + Const.SEP + module.__class__.__name__ + Const.SEP)
|
|
135
121
|
pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook(
|
|
136
122
|
BaseScope.Module_Type_Module,
|
|
137
123
|
prefix_name
|
|
@@ -203,9 +189,9 @@ class ModuleProcesser:
|
|
|
203
189
|
if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
|
|
204
190
|
module.mindstudio_reserved_name = []
|
|
205
191
|
module.mindstudio_reserved_name.append(full_name)
|
|
206
|
-
forward_full_name = full_name
|
|
207
|
-
ModuleProcesser.module_node[full_name] =
|
|
208
|
-
Const.FORWARD, Const.BACKWARD)
|
|
192
|
+
forward_full_name = replace_last_occurrence(full_name, Const.BACKWARD, Const.FORWARD)
|
|
193
|
+
ModuleProcesser.module_node[full_name] = replace_last_occurrence(
|
|
194
|
+
ModuleProcesser.module_node.get(forward_full_name), Const.FORWARD, Const.BACKWARD)
|
|
209
195
|
ModuleProcesser.api_parent_node = None
|
|
210
196
|
if self.scope:
|
|
211
197
|
self.scope.begin_module(full_name)
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
import torch
|
|
18
18
|
from msprobe.core.common.exceptions import FreeBenchmarkException
|
|
19
|
-
from msprobe.core.common.
|
|
19
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
20
20
|
from msprobe.pytorch.free_benchmark.common.enums import DeviceType
|
|
21
21
|
|
|
22
22
|
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
import math
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
|
-
from msprobe.core.common.
|
|
19
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
20
20
|
from msprobe.pytorch.free_benchmark import logger
|
|
21
21
|
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
22
22
|
from msprobe.pytorch.free_benchmark.common.utils import TorchC
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import torch
|
|
17
|
-
from msprobe.core.common.
|
|
17
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
18
18
|
from msprobe.pytorch.free_benchmark import logger
|
|
19
19
|
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
20
20
|
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
@@ -95,13 +95,13 @@ class AddNoiseLayer(NpuBaseLayer):
|
|
|
95
95
|
except Exception:
|
|
96
96
|
logger.warning_on_rank_0(
|
|
97
97
|
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
98
|
-
f"when
|
|
98
|
+
f"when calculating the maximum value, the tensor is changed to float32."
|
|
99
99
|
)
|
|
100
100
|
max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
|
|
101
101
|
if max_val < abs_tol:
|
|
102
102
|
logger.warning_on_rank_0(
|
|
103
103
|
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
104
|
-
f"
|
|
104
|
+
f"maximum value is less than the minimum threshold. Cancel adding noise."
|
|
105
105
|
)
|
|
106
106
|
return False
|
|
107
107
|
return True
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import torch
|
|
17
|
-
from msprobe.core.common.
|
|
17
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
18
18
|
from msprobe.pytorch.free_benchmark import logger
|
|
19
19
|
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
20
20
|
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
@@ -100,13 +100,13 @@ class BitNoiseLayer(NpuBaseLayer):
|
|
|
100
100
|
except Exception:
|
|
101
101
|
logger.warning_on_rank_0(
|
|
102
102
|
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
103
|
-
f"when calculate
|
|
103
|
+
f"when calculate the maximum value, the tensor is changed to float32."
|
|
104
104
|
)
|
|
105
105
|
max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
|
|
106
106
|
if max_val < abs_tol:
|
|
107
107
|
logger.warning_on_rank_0(
|
|
108
108
|
f"[msprobe] Free Benchmark: For {self.api_name}, "
|
|
109
|
-
f"
|
|
109
|
+
f"maximum value is less than the minimum threshold. Cancel adding noise."
|
|
110
110
|
)
|
|
111
111
|
return False
|
|
112
112
|
return True
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import torch
|
|
17
|
-
from msprobe.core.common.
|
|
17
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
18
18
|
from msprobe.pytorch.free_benchmark import logger
|
|
19
19
|
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
20
20
|
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import torch
|
|
17
17
|
from msprobe.core.common.const import Const
|
|
18
|
-
from msprobe.core.common.
|
|
18
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
19
19
|
from msprobe.pytorch.free_benchmark import logger
|
|
20
20
|
from msprobe.pytorch.free_benchmark.common.constant import CommonField
|
|
21
21
|
from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
|
|
@@ -49,6 +49,6 @@ class CheckerHandler(FuzzHandler):
|
|
|
49
49
|
except Exception as e:
|
|
50
50
|
logger.warning_on_rank_0(
|
|
51
51
|
f"[msprobe] Free Benchmark: For {self.params.api_name}, "
|
|
52
|
-
f"when
|
|
52
|
+
f"when comparing the results, an exception is raised: {e}"
|
|
53
53
|
)
|
|
54
54
|
return data_params.original_result
|
|
@@ -27,6 +27,11 @@ from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotar
|
|
|
27
27
|
from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \
|
|
28
28
|
npu_scaled_masked_softmax_backward
|
|
29
29
|
from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward
|
|
30
|
+
from msprobe.pytorch.bench_functions.apply_adam import npu_apply_adam
|
|
31
|
+
from msprobe.pytorch.bench_functions.group_norm_silu import npu_group_norm_silu
|
|
32
|
+
from msprobe.pytorch.bench_functions.mish import npu_mish
|
|
33
|
+
from msprobe.pytorch.bench_functions.moe_gating_top_k_softmax import npu_moe_gating_top_k_softmax
|
|
34
|
+
from msprobe.pytorch.bench_functions.sort_v2 import npu_sort_v2
|
|
30
35
|
from msprobe.pytorch.common.utils import logger
|
|
31
36
|
|
|
32
37
|
|
|
@@ -65,7 +70,7 @@ class Register(dict):
|
|
|
65
70
|
|
|
66
71
|
def add_register_item(key, value):
|
|
67
72
|
if key in self._dict:
|
|
68
|
-
logger.warning(f"{value.__name__} has been registered before, so we will
|
|
73
|
+
logger.warning(f"{value.__name__} has been registered before, so we will override it.")
|
|
69
74
|
self[key] = value
|
|
70
75
|
return value
|
|
71
76
|
|
|
@@ -79,7 +84,8 @@ class Register(dict):
|
|
|
79
84
|
npu_custom_functions = Register()
|
|
80
85
|
npu_custom_functions([
|
|
81
86
|
npu_apply_adam_w, npu_confusion_transpose, npu_fast_gelu, npu_layer_norm_eval, npu_linear, npu_fusion_attention,
|
|
82
|
-
npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu, gpu_fusion_attention
|
|
87
|
+
npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu, gpu_fusion_attention, npu_apply_adam,
|
|
88
|
+
npu_group_norm_silu, npu_mish, npu_moe_gating_top_k_softmax, npu_sort_v2
|
|
83
89
|
])
|
|
84
90
|
|
|
85
91
|
# register for npu custom backward bench functions
|
|
@@ -46,7 +46,7 @@ class GradientMonitor:
|
|
|
46
46
|
if not os.path.exists(self._output_path):
|
|
47
47
|
create_directory(self._output_path)
|
|
48
48
|
else:
|
|
49
|
-
logger.warning(f"the file in {self._output_path} will be
|
|
49
|
+
logger.warning(f"the file in {self._output_path} will be deleted")
|
|
50
50
|
self._step = -1
|
|
51
51
|
self._param2name = defaultdict(str)
|
|
52
52
|
|
|
@@ -97,7 +97,7 @@ class GradientMonitor:
|
|
|
97
97
|
create_directory(output_dirpath)
|
|
98
98
|
output_path = os.path.join(output_dirpath, f"grad_summary_{self._step}.csv")
|
|
99
99
|
if os.path.exists(output_path):
|
|
100
|
-
logger.warning(f"{output_path} will be
|
|
100
|
+
logger.warning(f"{output_path} will be deleted")
|
|
101
101
|
remove_path(output_path)
|
|
102
102
|
header_result = GradStatCsv.generate_csv_header(self._level_adp, self._bounds)
|
|
103
103
|
output_lines.insert(0, header_result)
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import functools
|
|
17
|
+
import os
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
import torch.distributed as dist
|
|
21
|
+
|
|
22
|
+
from msprobe.core.common.const import Const
|
|
23
|
+
from msprobe.core.data_dump.api_registry import ApiRegistry
|
|
24
|
+
from msprobe.pytorch.common.utils import (
|
|
25
|
+
torch_without_guard_version, is_gpu, torch_device_guard, parameter_adapter
|
|
26
|
+
)
|
|
27
|
+
from msprobe.pytorch.function_factory import npu_custom_functions
|
|
28
|
+
from msprobe.pytorch.hook_module.hook_module import HOOKModule
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
|
|
32
|
+
|
|
33
|
+
_api_types = {
|
|
34
|
+
Const.PT_FRAMEWORK: {
|
|
35
|
+
Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
|
|
36
|
+
Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)),
|
|
37
|
+
Const.PT_API_TYPE_TORCH: (torch, (torch,)),
|
|
38
|
+
Const.PT_API_TYPE_VF: (torch._C._VariableFunctionsClass, (torch._VF,)),
|
|
39
|
+
Const.PT_API_TYPE_DIST: (dist, (dist, dist.distributed_c10d))
|
|
40
|
+
}
|
|
41
|
+
}
|
|
42
|
+
if not is_gpu:
|
|
43
|
+
import torch_npu
|
|
44
|
+
if torch_without_guard_version:
|
|
45
|
+
_api_types.get(Const.PT_FRAMEWORK).update(
|
|
46
|
+
{
|
|
47
|
+
Const.PT_API_TYPE_NPU: (torch.ops.npu, (torch_npu, torch.ops.npu))
|
|
48
|
+
}
|
|
49
|
+
)
|
|
50
|
+
else:
|
|
51
|
+
_api_types.get(Const.PT_FRAMEWORK).update(
|
|
52
|
+
{Const.PT_API_TYPE_NPU: (torch_npu._C._VariableFunctionsClass, (torch_npu,))}
|
|
53
|
+
)
|
|
54
|
+
_api_types.get(Const.PT_FRAMEWORK).update(
|
|
55
|
+
{
|
|
56
|
+
Const.PT_API_TYPE_NPU_DIST: (torch_npu.distributed, (torch_npu.distributed,
|
|
57
|
+
torch_npu.distributed.distributed_c10d))
|
|
58
|
+
}
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
_inner_used_api = {}
|
|
62
|
+
_supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
|
|
63
|
+
_cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@parameter_adapter
|
|
67
|
+
def tensor_module_forward(module, *args, **kwargs):
|
|
68
|
+
return module.api_func(*args, **kwargs)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def dist_module_forward(module, *args, **kwargs):
|
|
72
|
+
handle = module.api_func(*args, **kwargs)
|
|
73
|
+
if kwargs.get("async_op") or module.api_name in ["isend", "irecv"]:
|
|
74
|
+
if handle and hasattr(handle, 'wait'):
|
|
75
|
+
handle.wait()
|
|
76
|
+
if module.api_name == "batch_isend_irecv":
|
|
77
|
+
if isinstance(handle, list):
|
|
78
|
+
for req in handle:
|
|
79
|
+
req.wait()
|
|
80
|
+
return handle
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def npu_module_forward(module, *args, **kwargs):
|
|
84
|
+
if not module.need_hook:
|
|
85
|
+
if module.api_name not in npu_custom_functions:
|
|
86
|
+
raise Exception(f'There is not bench function {module.api_name}')
|
|
87
|
+
if module.device == Const.CUDA_LOWERCASE:
|
|
88
|
+
module.api_name = _cuda_func_mapping.get(module.api_name, module.api_name)
|
|
89
|
+
if module.device in [Const.CUDA_LOWERCASE, Const.CPU_LOWERCASE]:
|
|
90
|
+
return npu_custom_functions[module.api_name](*args, **kwargs)
|
|
91
|
+
return module.api_func(*args, **kwargs)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
forward_methods = {
|
|
95
|
+
"Tensor": tensor_module_forward,
|
|
96
|
+
"Distributed": dist_module_forward,
|
|
97
|
+
"NPU": npu_module_forward
|
|
98
|
+
}
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class ApiTemplate(HOOKModule):
|
|
102
|
+
def __init__(self, api_name, api_func, prefix, hook_build_func, need_hook=True, device=Const.CPU_LOWERCASE):
|
|
103
|
+
self.api_name = api_name
|
|
104
|
+
self.api_func = api_func
|
|
105
|
+
self.prefix = prefix
|
|
106
|
+
self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
|
|
107
|
+
self.need_hook = need_hook
|
|
108
|
+
self.device = device
|
|
109
|
+
if self.need_hook:
|
|
110
|
+
super().__init__(hook_build_func)
|
|
111
|
+
if prefix == Const.DIST_API_TYPE_PREFIX:
|
|
112
|
+
self.op_is_distributed = True
|
|
113
|
+
|
|
114
|
+
@torch_device_guard
|
|
115
|
+
def forward(self, *args, **kwargs):
|
|
116
|
+
exec_func = forward_methods.get(self.prefix)
|
|
117
|
+
exec_func = functools.partial(exec_func, self) if exec_func else self.api_func
|
|
118
|
+
return exec_func(*args, **kwargs)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
api_register = None
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def get_api_register(return_new=False):
|
|
125
|
+
if return_new:
|
|
126
|
+
return ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
|
|
127
|
+
|
|
128
|
+
global api_register
|
|
129
|
+
if api_register is None:
|
|
130
|
+
api_register = ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
|
|
131
|
+
return api_register
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -21,6 +21,8 @@ import torch
|
|
|
21
21
|
import torch.nn as nn
|
|
22
22
|
import torch.utils.hooks as full_hooks
|
|
23
23
|
|
|
24
|
+
from msprobe.pytorch.common.utils import is_float8_tensor
|
|
25
|
+
|
|
24
26
|
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
25
27
|
|
|
26
28
|
|
|
@@ -28,28 +30,27 @@ class HOOKModule(nn.Module):
|
|
|
28
30
|
module_count = defaultdict(int)
|
|
29
31
|
inner_stop_hook = {}
|
|
30
32
|
|
|
31
|
-
def __init__(self,
|
|
33
|
+
def __init__(self, hook_build_func) -> None:
|
|
32
34
|
super(HOOKModule, self).__init__()
|
|
33
35
|
self.has_overflow = False
|
|
34
|
-
self.prefix = ""
|
|
35
36
|
self.current_thread = threading.current_thread().ident
|
|
36
37
|
if self.current_thread not in HOOKModule.inner_stop_hook:
|
|
37
38
|
HOOKModule.inner_stop_hook[self.current_thread] = False
|
|
38
39
|
self.stop_hook = HOOKModule.inner_stop_hook.get(self.current_thread, False)
|
|
39
40
|
|
|
40
41
|
if not self.stop_hook:
|
|
41
|
-
if hasattr(self, "prefix_op_name_"):
|
|
42
|
-
self.prefix = self.prefix_op_name_
|
|
43
|
-
|
|
44
42
|
self.forward_data_collected = False
|
|
45
|
-
|
|
46
|
-
if
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
43
|
+
|
|
44
|
+
prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
|
|
45
|
+
if callable(hook_build_func):
|
|
46
|
+
forward_pre_hook, forward_hook, backward_hook, _ = hook_build_func(prefix)
|
|
47
|
+
if torch_version_above_or_equal_2:
|
|
48
|
+
self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
|
|
49
|
+
self.register_forward_hook(forward_hook, with_kwargs=True)
|
|
50
|
+
else:
|
|
51
|
+
self.register_forward_pre_hook(forward_pre_hook)
|
|
52
|
+
self.register_forward_hook(forward_hook)
|
|
53
|
+
self.register_backward_hook(backward_hook)
|
|
53
54
|
|
|
54
55
|
def __call__(self, *args, **kwargs):
|
|
55
56
|
changed = False
|
|
@@ -111,6 +112,10 @@ class HOOKModule(nn.Module):
|
|
|
111
112
|
return result
|
|
112
113
|
else:
|
|
113
114
|
return result
|
|
115
|
+
|
|
116
|
+
if is_float8_tensor(var) or not (var.requires_grad and torch.is_grad_enabled()):
|
|
117
|
+
return result
|
|
118
|
+
|
|
114
119
|
grad_fn = var.grad_fn
|
|
115
120
|
if grad_fn is not None:
|
|
116
121
|
for hook in non_full_backward_hooks:
|
|
@@ -32,8 +32,9 @@ def register_optimizer_hook(data_collector):
|
|
|
32
32
|
def patch_clip_grad(func):
|
|
33
33
|
def wrapper(*args, **kwargs):
|
|
34
34
|
data_collector.optimizer_status = Const.CLIP_GRAD
|
|
35
|
-
func(*args, **kwargs)
|
|
35
|
+
result = func(*args, **kwargs)
|
|
36
36
|
data_collector.optimizer_status = Const.END_PREFIX + Const.CLIP_GRAD
|
|
37
|
+
return result
|
|
37
38
|
|
|
38
39
|
return wrapper
|
|
39
40
|
|