mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0

msprobe/pytorch/hook_module/api_register.py
@@ -14,22 +14,25 @@
 # limitations under the License.

 import functools
-import os
 import inspect
+import os

 import torch
 import torch.distributed as dist

 from msprobe.core.common.const import Const
+from msprobe.core.common.file_utils import load_yaml
 from msprobe.core.data_dump.api_registry import ApiRegistry
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import (
-    torch_without_guard_version,
+    torch_without_guard_version,
+    is_gpu,
+    torch_device_guard,
+    parameter_adapter
 )
 from msprobe.pytorch.function_factory import npu_custom_functions
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.hook_module.utils import dynamic_import_op
-from msprobe.core.common.file_utils import load_yaml

 try:
     import mindspeed.ops
@@ -38,42 +41,46 @@ except ImportError:
 else:
     mindspeed_enable = True

-
 torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'

 _inner_used_api = {}
 _supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
 _cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
+dist_data_collect_func = {}
+dist_batch_data_collect_func = []

 _api_types = {
     Const.PT_FRAMEWORK: {
-        Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
-        Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)),
-        Const.PT_API_TYPE_TORCH: (torch, (torch,)),
-        Const.PT_API_TYPE_VF: (torch._C._VariableFunctionsClass, (torch._VF,)),
-        Const.PT_API_TYPE_DIST: (dist, (dist, dist.distributed_c10d))
+        Const.PT_API_TYPE_FUNCTIONAL: ((torch.nn.functional,), (torch.nn.functional,)),
+        Const.PT_API_TYPE_TENSOR: ((torch.Tensor,), (torch.Tensor,)),
+        Const.PT_API_TYPE_TORCH: ((torch,), (torch,)),
+        Const.PT_API_TYPE_VF: ((torch._C._VariableFunctionsClass,), (torch._VF,)),
+        Const.PT_API_TYPE_DIST: ((dist,), (dist, dist.distributed_c10d))
     }
 }
 if not is_gpu:
     import torch_npu
+
     if torch_without_guard_version:
         _api_types.get(Const.PT_FRAMEWORK).update(
             {
-                Const.PT_API_TYPE_NPU: (torch.ops.npu, (torch_npu, torch.ops.npu))
+                Const.PT_API_TYPE_NPU: ((torch.ops.npu, torch_npu), (torch_npu, torch.ops.npu)),
             }
         )
     else:
         _api_types.get(Const.PT_FRAMEWORK).update(
-            {Const.PT_API_TYPE_NPU: (torch_npu._C._VariableFunctionsClass, (torch_npu,))}
+            {Const.PT_API_TYPE_NPU: ((torch_npu._C._VariableFunctionsClass,), (torch_npu,))}
         )
     _api_types.get(Const.PT_FRAMEWORK).update(
         {
-            Const.PT_API_TYPE_NPU_DIST: (
-
+            Const.PT_API_TYPE_NPU_DIST: (
+                (torch_npu.distributed,),
+                (torch_npu.distributed, torch_npu.distributed.distributed_c10d)
+            )
         }
     )
 if mindspeed_enable:
-    _api_types.get(Const.PT_FRAMEWORK).update({Const.PT_API_TYPE_MINDSPEED: (mindspeed.ops, (mindspeed.ops,))})
+    _api_types.get(Const.PT_FRAMEWORK).update({Const.PT_API_TYPE_MINDSPEED: ((mindspeed.ops,), (mindspeed.ops,))})
     mindspeed_op_list = load_yaml(_supported_api_list_path[0]).get(Const.PT_API_TYPE_MINDSPEED)
     mindspeed_op_file_list = [op.split(Const.SEP)[0] + Const.PY_SUFFIX for op in mindspeed_op_list]
     dynamic_import_op(mindspeed.ops, mindspeed_op_file_list)
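The restructured `_api_types` table is worth a note: each entry now pairs a tuple of candidate namespaces (from which the original functions appear to be looked up) with the tuple of modules whose attributes get wrapped, instead of a single namespace, which is what lets entries such as `PT_API_TYPE_NPU` list both `torch.ops.npu` and `torch_npu`. A minimal sketch of that lookup idea, assuming this reading of the tuple layout; the helper below is hypothetical, not msprobe code:

```python
# Illustrative only: resolve an op from whichever candidate namespace defines it.
def resolve_original(api_name, lookup_namespaces):
    """Return the first attribute named api_name found across the candidate namespaces."""
    for ns in lookup_namespaces:            # e.g. (torch.ops.npu, torch_npu) after this change
        func = getattr(ns, api_name, None)
        if func is not None:
            return func
    raise AttributeError(f"{api_name} not found in any candidate namespace")
```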
@@ -94,16 +101,48 @@ def dist_module_forward(module, *args, **kwargs):
         use_async_op_flag = False
         logger.warning(f"fail to get dist api's func signature because {e}, no wait")

-
-
-
-
-
-
-
+    def create_async_callback_func(catch_func):
+        full_name = module.full_forward_name if hasattr(module, "full_forward_name") else None
+
+        def store_data():
+            catch_func(module, full_name, args, kwargs, handle)
+
+        return store_data
+
+    if use_async_op_flag or module.api_name in ['isend', 'irecv']:
+        dist_data_collect_func[handle] = create_async_callback_func(module.distributed_forward_hook)
+    if module.api_name == 'batch_isend_irecv':
+        dist_batch_data_collect_func.append([handle, create_async_callback_func(module.distributed_forward_hook)])
     return handle


+def redirect_wait():
+    if hasattr(dist, "Work"):
+        from torch.distributed import Work
+    else:
+        from torch._C._distributed_c10d import Work
+    origin_wait = Work.wait
+
+    def wrapped_wait(work):
+        def wrapped_wait(*args, **kwargs):
+            origin_wait(*args, **kwargs)
+            if args[0] in dist_data_collect_func:
+                store_func = dist_data_collect_func.pop(args[0])
+                store_func()
+                return
+            for value in dist_batch_data_collect_func:
+                if args[0] in value[0]:
+                    value[0].remove(args[0])
+                    if len(value[0]) == 0:
+                        store_func = value[1]
+                        store_func()
+                    return
+
+        return wrapped_wait
+
+    Work.wait = wrapped_wait(Work)
+
+
 def npu_module_forward(module, *args, **kwargs):
     if not module.need_hook:
         if module.api_name not in npu_custom_functions:
@@ -125,15 +164,14 @@ forward_methods = {
 class ApiTemplate(HOOKModule):
     def __init__(self, api_name, api_func, prefix, hook_build_func, need_hook=True, device=Const.CPU_LOWERCASE):
         self.api_name = api_name
-        self.api_func = api_func
         self.prefix = prefix
         self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
         self.need_hook = need_hook
         self.device = device
+        self.op_is_distributed = prefix == Const.DIST_API_TYPE_PREFIX
         if self.need_hook:
             super().__init__(hook_build_func)
-
-            self.op_is_distributed = True
+        self.api_func = api_func

     @torch_device_guard
     def forward(self, *args, **kwargs):
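Taken together, the api_register.py changes defer data collection for asynchronous collectives: `isend`/`irecv` and `batch_isend_irecv` return handles whose results are not available until `wait()` is called, so the forward hook stores a callback keyed by the handle and `redirect_wait` wraps `Work.wait` to fire it. A self-contained sketch of that pattern, using toy stand-ins rather than `torch.distributed.Work` or msprobe's hooks:

```python
# Toy illustration of the deferred-collection pattern; not msprobe code.
pending_callbacks = {}

class FakeWork:
    """Stands in for an async communication handle."""
    def wait(self):
        pass

_original_wait = FakeWork.wait

def _wrapped_wait(self, *args, **kwargs):
    _original_wait(self, *args, **kwargs)
    callback = pending_callbacks.pop(self, None)
    if callback is not None:
        callback()          # collect the op's data only once the result is ready

FakeWork.wait = _wrapped_wait

work = FakeWork()
pending_callbacks[work] = lambda: print("dump isend/irecv data now")
work.wait()                 # prints: dump isend/irecv data now
```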

msprobe/pytorch/hook_module/hook_module.py
@@ -14,50 +14,30 @@
 # limitations under the License.

 import functools
-import threading
 from collections import defaultdict

 import torch
 import torch.nn as nn
 import torch.utils.hooks as full_hooks

-from msprobe.
-from msprobe.core.common.utils import ThreadSafe
-from msprobe.pytorch.common.utils import register_forward_pre_hook, register_forward_hook
+from msprobe.pytorch.common.utils import register_forward_pre_hook


 class HOOKModule(nn.Module):
     module_count = defaultdict(int)
-    inner_stop_hook = defaultdict(bool)

     def __init__(self, hook_build_func) -> None:
         super(HOOKModule, self).__init__()
-        self.
-        self.
-
-
-
-
-
-        if not Runtime.is_running:
-            return
-        prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
-        ThreadSafe.acquire()
-        if callable(hook_build_func):
-            hook_set = hook_build_func(prefix)
-            register_forward_pre_hook(self, hook_set.forward_pre_hook)
-            register_forward_hook(self, hook_set.forward_hook)
-            self.register_backward_hook(hook_set.backward_hook)
+        prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
+        op_is_distributed = self.op_is_distributed if hasattr(self, "op_is_distributed") else False
+        if callable(hook_build_func):
+            hook_set = hook_build_func(prefix)
+            register_forward_pre_hook(self, hook_set.forward_pre_hook)
+            if op_is_distributed:
+                self.distributed_forward_hook = hook_set.distributed_forward_hook

     def __call__(self, *args, **kwargs):
-
-        if not self.stop_hook:
-            HOOKModule.inner_stop_hook[self.tid] = True
-            changed = True
-        result = self._call_func(*args, **kwargs)
-        if changed:
-            HOOKModule.inner_stop_hook[self.tid] = False
-        return result
+        return self._call_func(*args, **kwargs)

     @staticmethod
     def reset_module_stats():
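After this change `HOOKModule.__init__` only registers a forward pre-hook (plus an optional distributed forward hook) and `__call__` delegates straight to `_call_func`; the thread-local stop-hook bookkeeping is gone from this class. For background, the sketch below shows the standard PyTorch forward-pre-hook mechanism this relies on; it is a toy example, not msprobe code:

```python
# A forward pre-hook runs before Module.forward and receives (module, positional inputs).
import torch
import torch.nn as nn

def trace_pre_hook(module, inputs):
    print(f"about to run {module.__class__.__name__} with {len(inputs)} positional input(s)")

layer = nn.Linear(4, 2)
layer.register_forward_pre_hook(trace_pre_hook)
layer(torch.randn(1, 4))   # prints the trace line, then runs the normal forward
```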

msprobe/pytorch/hook_module/pt_hook_manager.py
@@ -13,13 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
+import functools
+import threading
 from contextlib import nullcontext

+import torch
+
 from msprobe.core.common.const import Const
-from msprobe.core.common.
+from msprobe.core.common.runtime import Runtime
+from msprobe.core.common.utils import replace_last_occurrence, ThreadSafe
+from msprobe.core.data_dump.data_processor.base import (ModuleForwardInputsOutputs)
 from msprobe.core.hook_manager import BaseHookManager, HookSet
-from msprobe.pytorch.common.utils import is_recomputation, torch_version_above_or_equal_2
+from msprobe.pytorch.common.utils import is_recomputation, torch_version_above_or_equal_2, register_forward_hook
 from msprobe.pytorch.hook_module.hook_module import HOOKModule


@@ -37,23 +42,65 @@ class PytorchHookManager(BaseHookManager):
         HOOKModule.add_module_count(name)

     @staticmethod
-    def
-
-
+    def _get_count(name):
+        return HOOKModule.get_module_count(name)
+
+    @staticmethod
+    def _process_kwargs_and_output(module, tid, hook_type, kwargs_or_output, output_or_kwargs):
+        if hook_type == Const.API:
+            kwargs = kwargs_or_output
+            output = output_or_kwargs
+        else:
+            kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {}
+            output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
         return kwargs, output

     def build_hook(self, hook_type, name):
         if hook_type == Const.API:
-
+            hook_set = HookSet(
+                forward_pre_hook=self._build_forward_pre_hook(hook_type, name),
+                distributed_forward_hook=self._build_distributed_forward_hook()
+            )
         else:
-
-
-
-
-
-
-
-
+            full_backward_name = replace_last_occurrence(name, Const.FORWARD, Const.BACKWARD)
+            hook_set = HookSet(
+                forward_hook=self._build_forward_hook(hook_type, name),
+                backward_hook=self._build_backward_hook(hook_type, full_backward_name)
+            )
+        return hook_set
+
+    def _register_forward_hook(self, module, api_name):
+        if not hasattr(module, 'msprobe_forward_hook'):
+            register_forward_hook(module, self._build_forward_hook(Const.API, api_name))
+            setattr(module, 'msprobe_forward_hook', True)
+
+    def _register_backward_hook(self, module, full_backward_name, args):
+        pass
+
+    def _register_backward_pre_hook(self, module, full_backward_name, output):
+        var = output
+        while not isinstance(var, torch.Tensor):
+            if isinstance(var, dict):
+                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
+            elif isinstance(var, (list, tuple)):
+                if var:
+                    var = var[0]
+                else:
+                    return output
+            else:
+                return output
+
+        if not (var.requires_grad and torch.is_grad_enabled()):
+            return output
+
+        grad_fn = var.grad_fn
+        if grad_fn is not None:
+            backward_hook = self._build_backward_hook(Const.API, full_backward_name)
+            wrapper = functools.partial(backward_hook, module)
+            functools.update_wrapper(wrapper, backward_hook)
+            grad_fn.register_hook(wrapper)
+
+        return output

     def _need_exchange(self, module):
         return True
@@ -66,3 +113,25 @@ class PytorchHookManager(BaseHookManager):
             for key, value in module.named_parameters(recurse=False)
         }
         return params_dict
+
+    def _build_distributed_forward_hook(self):
+        def distributed_forward_hook(module, full_name, args, kwargs, output):
+            if not full_name or not Runtime.is_running:
+                return
+
+            tid = threading.get_ident()
+            with ThreadSafe():
+                BaseHookManager.inner_switch[tid] = True
+                self.data_collector.update_api_or_module_name(full_name)
+                module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
+                with self._no_grad_context():
+                    self.data_collector.forward_output_data_collect(
+                        full_name,
+                        module,
+                        self._pid,
+                        module_input_output,
+                        self._is_recompute
+                    )
+                BaseHookManager.inner_switch[tid] = False
+
+        return distributed_forward_hook
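`_register_backward_pre_hook` above walks the forward output until it finds a tensor that requires grad and attaches a hook to its `grad_fn`, so a backward-time data dump fires when autograd reaches that op. The sketch below shows the general autograd-hook idea using a hook on the tensor itself for brevity (a toy example; the manager hooks the output's `grad_fn` node rather than the tensor):

```python
import torch

x = torch.randn(3, requires_grad=True)
out = (x * 2).sum()

def backward_callback(grad):
    # runs while autograd is unwinding, i.e. the point where backward data can be dumped
    print("backward reached this op, grad norm =", grad.norm().item())
    return grad

out.register_hook(backward_callback)  # attach the callback to the op's output
out.backward()
```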

msprobe/pytorch/hook_module/script_wrapper.py (new file)
@@ -0,0 +1,140 @@
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import importlib
+import types
+
+import torch
+
+from msprobe.core.common.log import logger
+from msprobe.pytorch.common.utils import torch_version_above_or_equal_2
+from msprobe.pytorch.hook_module.api_register import get_api_register
+
+if torch_version_above_or_equal_2:
+    from torch._dynamo.convert_frame import convert_frame as _orig_convert_frame, Hooks
+
+
+def wrap_jit_script_func():
+    def patched_script(*args, **kwargs):
+        all_api_registered = api_register.all_api_registered
+        if all_api_registered:
+            api_register.restore_all_api()
+        result = original_script(*args, **kwargs)
+        if all_api_registered:
+            api_register.register_all_api()
+        return result
+
+    original_script = torch.jit.script
+    api_register = get_api_register()
+    torch.jit.script = patched_script
+
+
+def wrap_compile_script_func():
+    def _patched_convert_frame(compiler_fn, hooks):
+        """
+        Restore the original APIs before invoking the _convert_frame produced by
+        the original convert_frame, and re-register all APIs after the call returns.
+        """
+        # Get the original inner _convert_frame
+        inner_convert = _orig_convert_frame(compiler_fn, hooks)
+
+        def _wrapped(frame: types.FrameType, cache_size: int, hooks: Hooks, frame_state):
+            reg = get_api_register()
+            # Restore the original APIs before entering
+            reg.restore_all_api()
+            try:
+                result = inner_convert(frame, cache_size, hooks, frame_state)
+            except Exception:
+                # Make sure the APIs are re-registered even on exception
+                reg.register_all_api()
+                raise
+            # Re-register after a normal return
+            reg.register_all_api()
+            return result
+
+        # Keep the original attributes for compatibility
+        _wrapped._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
+        _wrapped._clone_with_backend = lambda backend: _patched_convert_frame(backend,
+                                                                              hooks)  # type: ignore[attr-defined]
+        return _wrapped
+
+    import torch._dynamo.convert_frame as _cf_mod
+    _cf_mod.convert_frame = _patched_convert_frame
+
+
+def patch_dynamo_compile():
+    cf = importlib.import_module("torch._dynamo.convert_frame")
+    if not hasattr(cf, "_compile"):
+        logger.warning("No found torch._dynamo.convert_frame._compile")
+
+    original = cf._compile
+    if getattr(original, "__msprobe_patched__", False):
+        return
+
+    @functools.wraps(original)
+    def wrapped(*args, **kwargs):
+        result = None
+        try:
+            reg = get_api_register()
+            reg.restore_all_api()
+        except Exception as e:
+            logger.warning(f"[msprobe] Pre restore_all_api failed: {e}")
+            return result
+
+        try:
+            result = original(*args, **kwargs)
+        except Exception:
+            logger.warning("[msprobe] _compile execution failed (returning None)")
+            result = None
+        finally:
+            try:
+                reg = get_api_register()
+                reg.register_all_api()  # switch back to registering the hooks
+            except Exception as e:
+                logger.warning(f"[msprobe] Post register_all_api failed: {e}")
+        return result
+    wrapped.__msprobe_patched__ = True
+    wrapped.__msprobe_original__ = original
+    cf._compile = wrapped
+
+
+def unpatch_dynamo_compile() -> bool:
+    # Reserved interface for removing the patch
+    cf = importlib.import_module("torch._dynamo.convert_frame")
+    current = getattr(cf, "_compile", None)
+    if current is None:
+        return False
+    original = getattr(current, "__msprobe_original__", None)
+    if original is None:
+        return False
+    cf._compile = original
+    return True
+
+
+def preprocess_func():
+    try:
+        from torch.utils._device import _device_constructors
+        _device_constructors()
+    except ImportError:
+        pass
+    except Exception as e:
+        logger.warning(f"Failed to execute _device_constructors. Error Details: {str(e)}")
+
+
+def wrap_script_func():
+    wrap_jit_script_func()
+    if torch_version_above_or_equal_2:
+        patch_dynamo_compile()
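The new script_wrapper.py temporarily removes the msprobe wrappers while `torch.jit.script` or Dynamo's `_compile` runs and re-installs them afterwards, so scripting and compilation see the original, unwrapped APIs. A generic sketch of that "unhook, call, rehook" pattern with toy stand-ins (msprobe's real registry object comes from `get_api_register()`; the classes and names below are illustrative only):

```python
import functools

class ToyRegistry:
    def restore_all_api(self):
        print("hooks removed")
    def register_all_api(self):
        print("hooks re-installed")

def guard_with_registry(func, registry):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        registry.restore_all_api()        # let the compiler see the original, unwrapped APIs
        try:
            return func(*args, **kwargs)
        finally:
            registry.register_all_api()   # re-enable dumping afterwards, even on error
    return wrapper

compile_fn = guard_with_registry(lambda src: f"compiled({src})", ToyRegistry())
print(compile_fn("model"))
```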

msprobe/pytorch/hook_module/support_wrap_ops.yaml
@@ -1260,6 +1260,12 @@ torch_npu:
   - npu_scatter_nd_update
   - npu_prefetch
   - npu_dynamic_block_quant
+  - npu_add_rms_norm
+  - _npu_flash_attention
+  - _npu_rotary_embedding
+  - _npu_reshape_and_cache
+  - _npu_paged_attention
+  - npu_moe_gating_top_k

 aten:
   - signbit

msprobe/pytorch/monitor/csv2tb.py
@@ -79,7 +79,7 @@ def write_step(output_dirpath, parse_step_result, rank, data_type):
         for op, value in ops.items():
             tag = f"{vpp_name}/{op}"
             writer.add_scalar(tag, value, step)
-    writer.
+    writer.close()


 @recursion_depth_decorator("update_dict", max_depth=50)

msprobe/pytorch/monitor/features.py
@@ -111,3 +111,97 @@ def cal_histc(tensor_cal, bins_total, min_val, max_val):
 @torch.no_grad()
 def get_nans(t):
     return torch.isnan(t).sum()
+
+
+def check_tensor_dim(tensor, n):
+    """Check whether the tensor has at least n dimensions.
+    """
+    if not isinstance(tensor, torch.Tensor):
+        raise TypeError(
+            f"Input must be a PyTorch tensor. Got {type(tensor)} instead. "
+            f"Consider using torch.tensor() for conversion."
+        )
+
+    if tensor.dim() < n:
+        raise ValueError(
+            f"Tensor must have at least {n} dimensions. "
+            f"Got shape: {tuple(tensor.shape)} with {tensor.dim()} dims."
+        )
+
+
+@torch.no_grad()
+def max_eigenvalue(input_tensor: torch.Tensor, num_iterations=3):
+    input_tensor = input_tensor.float()
+    try:
+        check_tensor_dim(input_tensor, 2)
+    except (TypeError, ValueError) as e:
+        logger.warning(f"Calculate max eigenvalue failed: {e}")
+        return torch.tensor(0)
+    in_features = input_tensor.shape[1]
+    u_tensor = torch.randn(in_features).to(input_tensor.device)
+    u_norm = u_tensor.norm()
+    if u_norm.item() == 0:
+        return torch.tensor(0)
+    u_tensor = u_tensor / u_tensor.norm()
+    input_seq = torch.matmul(input_tensor.T, input_tensor)
+    for _ in range(num_iterations):
+        v_tensor = torch.matmul(input_seq, u_tensor)
+        spectral_norm = torch.matmul(v_tensor.T, u_tensor)
+        v_norm = v_tensor.norm()
+        if v_norm > 0:
+            u_tensor = v_tensor / v_norm
+        else:
+            spectral_norm = torch.tensor(0)
+            break
+    return spectral_norm.sqrt()
+
+
+@torch.no_grad()
+def cal_entropy(qk_tensor, mask=None):
+    try:
+        check_tensor_dim(qk_tensor, 2)
+    except (TypeError, ValueError) as e:
+        logger.warning(f"Calculate max eigenvalue failed: {e}")
+        return torch.tensor(0), torch.tensor(0)
+    if mask is None:
+        mask = torch.tril(torch.ones(qk_tensor.shape[1], qk_tensor.shape[1])).to(
+            qk_tensor.device)
+    qk_tensor = qk_tensor - torch.amax(qk_tensor, dim=1, keepdim=True)
+    qk_tensor = qk_tensor.masked_fill(mask == 0, float('-inf'))
+    softmax_qkt = torch.nn.functional.softmax(qk_tensor.float(), dim=1)
+    # take the per-row maximum of the softmaxed QK matrix
+    softmax_max = torch.mean(torch.amax(softmax_qkt, dim=1))
+    entropy = torch.mean(-torch.nansum(softmax_qkt *
+                                       torch.log(softmax_qkt), dim=1))
+    return entropy, softmax_max
+
+
+@torch.no_grad()
+def cal_qkt(q_h, k_h, order="s,b,h,d"):
+    # q_h shape is [s, b, h, d]
+    try:
+        check_tensor_dim(q_h, 4)
+        check_tensor_dim(k_h, 4)
+    except (TypeError, ValueError) as e:
+        logger.warning(f"Calculate qk tensor failed: {e}")
+        return torch.tensor(0)
+
+    if order == "s,b,h,d":
+        qkt = torch.matmul(
+            q_h[:, 0, 0, :], k_h[:, 0, 0, :].t()) / q_h.shape[-1] ** 0.5
+    elif order == "b,s,h,d":
+        qkt = torch.matmul(
+            q_h[0, :, 0, :], k_h[0, :, 0, :].t()) / q_h.shape[-1] ** 0.5
+    else:
+        logger.warning("Calculate qk tensor failed: Order unsupported.")
+        qkt = torch.tensor(0)
+    return qkt
+
+
+@torch.no_grad()
+def cal_stable_rank(weight: torch.Tensor):
+    eig = max_eigenvalue(weight)
+    if eig == torch.tensor(0):
+        return torch.tensor(0), torch.tensor(0)
+    f_norm = torch.norm(weight, p="fro")
+    return f_norm / eig, eig