mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
- msprobe/README.md +27 -22
- msprobe/core/common/const.py +129 -60
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +43 -33
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +16 -9
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +30 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +94 -10
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
- msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
- msprobe/core/data_dump/json_writer.py +61 -40
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +27 -1
- msprobe/docs/02.config_introduction.md +27 -23
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +103 -16
- msprobe/docs/06.data_dump_MindSpore.md +76 -32
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +332 -273
- msprobe/docs/21.visualization_PyTorch.md +42 -13
- msprobe/docs/22.visualization_MindSpore.md +43 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +301 -27
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
- msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +48 -18
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +31 -6
- msprobe/mindspore/debugger/precision_debugger.py +45 -14
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +21 -15
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +873 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +309 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +114 -34
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +97 -4
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +24 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +8 -2
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +18 -14
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +238 -193
- msprobe/pytorch/monitor/module_metric.py +9 -6
- msprobe/pytorch/monitor/optimizer_collect.py +100 -67
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +76 -44
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +30 -29
- msprobe/pytorch/service.py +114 -32
- msprobe/visualization/builder/graph_builder.py +75 -10
- msprobe/visualization/builder/msprobe_adapter.py +7 -6
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +11 -3
- msprobe/visualization/graph/distributed_analyzer.py +71 -3
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +4 -3
- msprobe/visualization/graph_service.py +4 -5
- msprobe/visualization/utils.py +12 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
--- a/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ /dev/null
@@ -1,205 +0,0 @@
-# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from mindspore import Tensor, ops, mint
-from mindspore.mint.nn import functional
-from mindspore.common._stub_tensor import StubTensor
-from mindspore.communication import comm_func
-
-from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP,
-                                                       HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP,
-                                                       HOOKTorchOP, HOOKTorchTensor, HOOKTorchFunctionalOP,
-                                                       HOOKTorchDistributedOP, HOOKTorchNpuOP,
-                                                       get_wrap_api_list, get_wrap_torch_api_list, setup_hooks)
-from msprobe.core.common.utils import Const
-from msprobe.mindspore.common.utils import is_mindtorch
-
-if is_mindtorch():
-    import torch
-    import torch_npu
-
-
-def stub_method(method):
-    def wrapped_method(*args, **kwargs):
-        return method(*args, **kwargs)
-    return wrapped_method
-
-
-class ApiRegistry:
-    def __init__(self):
-        self.tensor_ori_attr = {}
-        self.stub_tensor_ori_attr = {}
-        self.functional_ori_attr = {}
-        self.mint_ops_ori_attr = {}
-        self.mint_func_ops_ori_attr = {}
-        self.distributed_ori_attr = {}
-        self.norm_inner_ops_ori_attr = {}
-
-        self.torch_ori_attr = {}
-        self.torch_tensor_ori_attr = {}
-        self.torch_functional_ori_attr = {}
-        self.torch_distributed_ori_attr = {}
-        self.torch_npu_ori_attr = {}
-
-        self.tensor_hook_attr = {}
-        self.stub_tensor_hook_attr = {}
-        self.functional_hook_attr = {}
-        self.mint_ops_hook_attr = {}
-        self.mint_func_ops_hook_attr = {}
-        self.distibuted_hook_attr = {}
-        self.norm_inner_ops_hook_attr = {}
-
-        self.torch_hook_attr = {}
-        self.torch_tensor_hook_attr = {}
-        self.torch_functional_hook_attr = {}
-        self.torch_distributed_hook_attr = {}
-        self.torch_npu_hook_attr = {}
-
-        self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
-
-    @staticmethod
-    def store_ori_attr(ori_api_group, api_list, api_ori_attr):
-        for api in api_list:
-            if Const.SEP in api:
-                sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
-                sub_module = getattr(ori_api_group, sub_module_name)
-                ori_api_func = getattr(sub_module, sub_op)
-            else:
-                ori_api_func = getattr(ori_api_group, api)
-            if ori_api_group == StubTensor:
-                api_ori_attr[api] = stub_method(ori_api_func)
-                continue
-            api_ori_attr[api] = ori_api_func
-
-    @staticmethod
-    def set_api_attr(api_group, attr_dict):
-        for api, api_attr in attr_dict.items():
-            if Const.SEP in api:
-                sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
-                sub_module = getattr(api_group, sub_module_name, None)
-                if sub_module is not None:
-                    setattr(sub_module, sub_op, api_attr)
-            else:
-                setattr(api_group, api, api_attr)
-
-    def norm_inner_op_set_hook_func(self):
-        self.set_api_attr(ops, self.norm_inner_ops_hook_attr)
-
-    def norm_inner_op_set_ori_func(self):
-        self.set_api_attr(ops, self.norm_inner_ops_ori_attr)
-
-    def api_set_hook_func(self):
-        if is_mindtorch():
-            self.set_api_attr(torch, self.torch_hook_attr)
-            self.set_api_attr(torch.Tensor, self.torch_tensor_hook_attr)
-            self.set_api_attr(torch.nn.functional, self.torch_functional_hook_attr)
-            self.set_api_attr(torch.distributed, self.torch_distributed_hook_attr)
-            self.set_api_attr(torch_npu, self.torch_npu_hook_attr)
-        else:
-            self.set_api_attr(Tensor, self.tensor_hook_attr)
-            self.set_api_attr(StubTensor, self.stub_tensor_hook_attr)
-            self.set_api_attr(ops, self.functional_hook_attr)
-            self.set_api_attr(mint, self.mint_ops_hook_attr)
-            self.set_api_attr(functional, self.mint_func_ops_hook_attr)
-            self.set_api_attr(comm_func, self.distibuted_hook_attr)
-
-    def api_set_ori_func(self):
-        if is_mindtorch():
-            self.set_api_attr(torch, self.torch_ori_attr)
-            self.set_api_attr(torch.Tensor, self.torch_tensor_ori_attr)
-            self.set_api_attr(torch.nn.functional, self.torch_functional_ori_attr)
-            self.set_api_attr(torch.distributed, self.torch_distributed_ori_attr)
-            self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
-        else:
-            self.set_api_attr(Tensor, self.tensor_ori_attr)
-            self.set_api_attr(StubTensor, self.stub_tensor_ori_attr)
-            self.set_api_attr(ops, self.functional_ori_attr)
-            self.set_api_attr(mint, self.mint_ops_ori_attr)
-            self.set_api_attr(functional, self.mint_func_ops_ori_attr)
-            self.set_api_attr(comm_func, self.distributed_ori_attr)
-
-    def initialize_hook(self, hook):
-        setup_hooks(hook)
-        if is_mindtorch():
-            wrap_torch_api_name = get_wrap_torch_api_list()
-            self.store_ori_attr(torch,
-                                wrap_torch_api_name.torch_api_names, self.torch_ori_attr)
-            self.store_ori_attr(torch.Tensor,
-                                wrap_torch_api_name.tensor_api_names, self.torch_tensor_ori_attr)
-            self.store_ori_attr(torch.nn.functional,
-                                wrap_torch_api_name.functional_api_names, self.torch_functional_ori_attr)
-            self.store_ori_attr(torch.distributed,
-                                wrap_torch_api_name.distributed_api_names, self.torch_distributed_ori_attr)
-            self.store_ori_attr(torch_npu,
-                                wrap_torch_api_name.npu_api_names, self.torch_npu_ori_attr)
-            for attr_name in dir(HOOKTorchOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_hook_attr[api_name] = getattr(HOOKTorchOP, attr_name)
-            for attr_name in dir(HOOKTorchTensor):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_tensor_hook_attr[api_name] = getattr(HOOKTorchTensor, attr_name)
-            for attr_name in dir(HOOKTorchFunctionalOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_functional_hook_attr[api_name] = getattr(HOOKTorchFunctionalOP, attr_name)
-            for attr_name in dir(HOOKTorchDistributedOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_distributed_hook_attr[api_name] = getattr(HOOKTorchDistributedOP, attr_name)
-            for attr_name in dir(HOOKTorchNpuOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_npu_hook_attr[api_name] = getattr(HOOKTorchNpuOP, attr_name)
-            return
-
-        wrap_api_name = get_wrap_api_list()
-        self.store_ori_attr(Tensor, wrap_api_name.tensor_api_names, self.tensor_ori_attr)
-        self.store_ori_attr(StubTensor, wrap_api_name.stub_tensor_api_names, self.stub_tensor_ori_attr)
-        self.store_ori_attr(ops, wrap_api_name.ops_api_names, self.functional_ori_attr)
-        self.store_ori_attr(mint, wrap_api_name.mint_api_names, self.mint_ops_ori_attr)
-        self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr)
-        self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr)
-        self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
-        for attr_name in dir(HOOKTensor):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.tensor_hook_attr[api_name] = getattr(HOOKTensor, attr_name)
-        for attr_name in dir(HOOKStubTensor):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.stub_tensor_hook_attr[api_name] = getattr(HOOKStubTensor, attr_name)
-        for attr_name in dir(HOOKFunctionalOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.functional_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name)
-                if api_name in self.norm_inner_ops:
-                    self.norm_inner_ops_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name)
-        for attr_name in dir(HOOKMintOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.mint_ops_hook_attr[api_name] = getattr(HOOKMintOP, attr_name)
-        for attr_name in dir(HOOKMintNNFunctionalOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.mint_func_ops_hook_attr[api_name] = getattr(HOOKMintNNFunctionalOP, attr_name)
-        for attr_name in dir(HOOKDistributedOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.distibuted_hook_attr[api_name] = getattr(HOOKDistributedOP, attr_name)
-
-
-api_register = ApiRegistry()
--- a/msprobe/mindspore/dump/hook_cell/wrap_api.py
+++ /dev/null
@@ -1,212 +0,0 @@
-# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-
-from mindspore import Tensor, mint, ops
-from mindspore.common._stub_tensor import StubTensor
-from mindspore.communication import comm_func
-from mindspore.mint.nn import functional
-
-from msprobe.core.common.const import Const
-from msprobe.core.common.file_utils import load_yaml
-from msprobe.mindspore.common.const import Const as MsConst
-from msprobe.mindspore.common.utils import is_mindtorch
-from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
-
-if is_mindtorch():
-    import torch
-    import torch_npu
-
-cur_path = os.path.dirname(os.path.realpath(__file__))
-yaml_path = os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE)
-torch_yaml_path = os.path.join(cur_path, "../../../pytorch/hook_module", MsConst.SUPPORTED_API_LIST_FILE)
-
-
-class HOOKTensor(object):
-    pass
-
-
-class HOOKStubTensor(object):
-    pass
-
-
-class HOOKFunctionalOP(object):
-    pass
-
-
-class HOOKMintOP(object):
-    pass
-
-
-class HOOKMintNNFunctionalOP(object):
-    pass
-
-
-class HOOKDistributedOP(object):
-    pass
-
-
-class HOOKTorchOP(object):
-    pass
-
-
-class HOOKTorchTensor(object):
-    pass
-
-
-class HOOKTorchFunctionalOP(object):
-    pass
-
-
-class HOOKTorchDistributedOP(object):
-    pass
-
-
-class HOOKTorchNpuOP(object):
-    pass
-
-
-class ApiTemplate(HOOKCell):
-    def __init__(self, api_name, api_dict, prefix, hook):
-        self.api_name = api_name
-        self.api_func = api_dict[api_name]
-        self.prefix_api_name = prefix + str(api_name.split(Const.SEP)[-1]) + Const.SEP
-        super().__init__(hook)
-
-    @staticmethod
-    def async_to_sync(output):
-        # Fake handle, used to return after the CommHandle executes the wait method
-        fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()
-        if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"):
-            output[1].wait()
-            output = (output[0], fake_handle)
-        elif hasattr(output, "wait"):
-            output.wait()
-            output = fake_handle
-        return output
-
-    def construct(self, *args, **kwargs):
-        if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
-            return args[0] if args else kwargs.get(Const.INPUT)
-
-        output = self.api_func(*args, **kwargs)
-
-        if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX):
-            if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]:
-                output = self.async_to_sync(output)
-        return output
-
-    def forward(self, *args, **kwargs):
-        if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
-            return args[0] if args else kwargs.get(Const.INPUT)
-        return self.api_func(*args, **kwargs)
-
-
-class WrapApiName:
-    def __init__(self, tensor_api_names, stub_tensor_api_names, ops_api_names, mint_api_names, mint_nn_func_api_names,
-                 distributed_api_names):
-        self.tensor_api_names = tensor_api_names
-        self.stub_tensor_api_names = stub_tensor_api_names
-        self.ops_api_names = ops_api_names
-        self.mint_api_names = mint_api_names
-        self.mint_nn_func_api_names = mint_nn_func_api_names
-        self.distributed_api_names = distributed_api_names
-
-
-class WrapTorchApiName:
-    def __init__(self, torch_api_names, tensor_api_names, functional_api_names, distributed_api_names, npu_api_names):
-        self.torch_api_names = torch_api_names
-        self.tensor_api_names = tensor_api_names
-        self.functional_api_names = functional_api_names
-        self.distributed_api_names = distributed_api_names
-        self.npu_api_names = npu_api_names
-
-
-def get_wrap_api_list():
-    api_list = load_yaml(yaml_path)
-    tensor_api = api_list.get(MsConst.SUPPORTED_TENSOR_LIST_KEY)
-    ops_api = api_list.get(MsConst.SUPPORTED_OPS_LIST_KEY)
-    mint_api = api_list.get(MsConst.SUPPORTED_MINT_LIST_KEY)
-    mint_nn_func_api = api_list.get(MsConst.SUPPORTED__MINT_NN_FUNC_LIST_KEY)
-    distributed_api = api_list.get(MsConst.SUPPORTED_COMM_LIST_KEY)
-    wrap_api_name = WrapApiName(set(tensor_api) & set(dir(Tensor)),
-                                set(tensor_api) & set(dir(StubTensor)),
-                                set(ops_api) & set(dir(ops)),
-                                set(mint_api) & set(dir(mint)),
-                                set(mint_nn_func_api) & set(dir(functional)),
-                                set(distributed_api) & set(dir(comm_func)))
-    return wrap_api_name
-
-
-def get_wrap_torch_api_list():
-    api_list = load_yaml(torch_yaml_path)
-    torch_api = api_list.get("torch")
-    tensor_api = api_list.get("tensor")
-    functional_api = api_list.get("functional")
-    distributed_api = api_list.get("distributed")
-    npu_api = api_list.get("torch_npu")
-    wrap_api_name = WrapTorchApiName(set(torch_api) & set(dir(torch)),
-                                     set(tensor_api) & set(dir(torch.Tensor)),
-                                     set(functional_api) & set(dir(torch.nn.functional)),
-                                     set(distributed_api) & set(dir(torch.distributed)),
-                                     set(npu_api) & set(dir(torch_npu)))
-    return wrap_api_name
-
-
-def wrap_api_func(api_name, api_dict, prefix, hook):
-    def api_function(*args, **kwargs):
-        return ApiTemplate(api_name, api_dict, prefix, hook)(*args, **kwargs)
-    return api_function
-
-
-def wrap_api_func_and_bind(api_list, api_dict, prefix, hook, hook_class):
-    for api_name in api_list:
-        if callable(api_dict[api_name]):
-            setattr(hook_class, Const.ATTR_NAME_PREFIX + api_name, wrap_api_func(api_name, api_dict, prefix, hook))
-
-
-def setup_hooks(hook):
-    if is_mindtorch():
-        torch_wrap_api_name = get_wrap_torch_api_list()
-        wrap_api_func_and_bind(torch_wrap_api_name.torch_api_names,
-                               {f: getattr(torch, f) for f in dir(torch)},
-                               MsConst.TORCH_DATA_PREFIX, hook, HOOKTorchOP)
-        wrap_api_func_and_bind(torch_wrap_api_name.tensor_api_names,
-                               {f: getattr(torch.Tensor, f) for f in dir(torch.Tensor)},
-                               MsConst.TENSOR_DATA_PREFIX, hook, HOOKTorchTensor)
-        wrap_api_func_and_bind(torch_wrap_api_name.functional_api_names,
-                               {f: getattr(torch.nn.functional, f) for f in dir(torch.nn.functional)},
-                               MsConst.OPS_DATA_PREFIX, hook, HOOKTorchFunctionalOP)
-        wrap_api_func_and_bind(torch_wrap_api_name.distributed_api_names,
-                               {f: getattr(torch.distributed, f) for f in dir(torch.distributed)},
-                               MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKTorchDistributedOP)
-        wrap_api_func_and_bind(torch_wrap_api_name.npu_api_names, {f: getattr(torch_npu, f) for f in dir(torch_npu)},
-                               MsConst.TORCH_NPU_DATA_PREFIX, hook, HOOKTorchNpuOP)
-        return
-
-    wrap_api_name = get_wrap_api_list()
-    wrap_api_func_and_bind(wrap_api_name.tensor_api_names, {f: getattr(Tensor, f) for f in dir(Tensor)},
-                           MsConst.TENSOR_DATA_PREFIX, hook, HOOKTensor)
-    wrap_api_func_and_bind(wrap_api_name.stub_tensor_api_names, {f: getattr(StubTensor, f) for f in dir(StubTensor)},
-                           MsConst.STUB_TENSOR_DATA_PREFIX, hook, HOOKStubTensor)
-    wrap_api_func_and_bind(wrap_api_name.ops_api_names, {f: getattr(ops, f) for f in dir(ops)},
-                           MsConst.OPS_DATA_PREFIX, hook, HOOKFunctionalOP)
-    wrap_api_func_and_bind(wrap_api_name.mint_api_names, {f: getattr(mint, f) for f in dir(mint)},
-                           MsConst.MINT_DATA_PREFIX, hook, HOOKMintOP)
-    wrap_api_func_and_bind(wrap_api_name.mint_nn_func_api_names, {f: getattr(functional, f) for f in dir(functional)},
-                           MsConst.MINT_NN_FUNC_DATA_PREFIX, hook, HOOKMintNNFunctionalOP)
-    wrap_api_func_and_bind(wrap_api_name.distributed_api_names, {f: getattr(comm_func, f) for f in dir(comm_func)},
-                           MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKDistributedOP)
--- a/msprobe/pytorch/hook_module/api_registry.py
+++ /dev/null
@@ -1,166 +0,0 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.distributed as dist
-
-from msprobe.pytorch.hook_module import wrap_torch, wrap_functional, wrap_tensor, wrap_vf, wrap_distributed, wrap_aten
-from msprobe.pytorch.hook_module.wrap_aten import get_aten_ops
-from msprobe.pytorch.hook_module.wrap_distributed import get_distributed_ops
-from msprobe.pytorch.hook_module.wrap_functional import get_functional_ops
-from msprobe.pytorch.hook_module.wrap_tensor import get_tensor_ops
-from msprobe.pytorch.hook_module.wrap_torch import get_torch_ops
-from msprobe.pytorch.hook_module.wrap_vf import get_vf_ops
-from msprobe.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu
-from msprobe.core.common.const import Const
-
-torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
-
-if not is_gpu:
-    import torch_npu
-    from . import wrap_npu_custom
-    from .wrap_npu_custom import get_npu_ops
-
-
-class ApiRegistry:
-    def __init__(self):
-        self.tensor_ori_attr = {}
-        self.torch_ori_attr = {}
-        self.functional_ori_attr = {}
-        self.distributed_ori_attr = {}
-        self.npu_distributed_ori_attr = {}
-        self.vf_ori_attr = {}
-        self.aten_ori_attr = {}
-        self.torch_npu_ori_attr = {}
-
-        self.tensor_hook_attr = {}
-        self.torch_hook_attr = {}
-        self.functional_hook_attr = {}
-        self.distributed_hook_attr = {}
-        self.npu_distributed_hook_attr = {}
-        self.vf_hook_attr = {}
-        self.aten_hook_attr = {}
-        self.torch_npu_hook_attr = {}
-
-    @staticmethod
-    def store_ori_attr(ori_api_group, api_list, api_ori_attr):
-        for api in api_list:
-            if '.' in api:
-                sub_module_name, sub_op = api.rsplit('.', 1)
-                sub_module = getattr(ori_api_group, sub_module_name)
-                api_ori_attr[api] = getattr(sub_module, sub_op)
-            else:
-                api_ori_attr[api] = getattr(ori_api_group, api)
-
-    @staticmethod
-    def set_api_attr(api_group, attr_dict):
-        for api, api_attr in attr_dict.items():
-            if '.' in api:
-                sub_module_name, sub_op = api.rsplit('.', 1)
-                sub_module = getattr(api_group, sub_module_name, None)
-                if sub_module is not None:
-                    setattr(sub_module, sub_op, api_attr)
-            else:
-                setattr(api_group, api, api_attr)
-
-    def api_modularity(self):
-        self.set_api_attr(torch.Tensor, self.tensor_hook_attr)
-        self.set_api_attr(torch, self.torch_hook_attr)
-        self.set_api_attr(torch.nn.functional, self.functional_hook_attr)
-        self.set_api_attr(dist, self.distributed_hook_attr)
-        self.set_api_attr(dist.distributed_c10d, self.distributed_hook_attr)
-        if not is_gpu and not torch_without_guard_version:
-            self.set_api_attr(torch_npu.distributed, self.npu_distributed_hook_attr)
-            self.set_api_attr(torch_npu.distributed.distributed_c10d, self.npu_distributed_hook_attr)
-        if torch_version_above_2:
-            self.set_api_attr(torch.ops.aten, self.aten_hook_attr)
-        self.set_api_attr(torch._VF, self.vf_hook_attr)
-        if not is_gpu:
-            self.set_api_attr(torch_npu, self.torch_npu_hook_attr)
-
-    def api_originality(self):
-        self.set_api_attr(torch.Tensor, self.tensor_ori_attr)
-        self.set_api_attr(torch, self.torch_ori_attr)
-        self.set_api_attr(torch.nn.functional, self.functional_ori_attr)
-        self.set_api_attr(dist, self.distributed_ori_attr)
-        self.set_api_attr(dist.distributed_c10d, self.distributed_ori_attr)
-        if not is_gpu and not torch_without_guard_version:
-            self.set_api_attr(torch_npu.distributed, self.npu_distributed_ori_attr)
-            self.set_api_attr(torch_npu.distributed.distributed_c10d, self.npu_distributed_ori_attr)
-        if torch_version_above_2:
-            self.set_api_attr(torch.ops.aten, self.aten_ori_attr)
-        self.set_api_attr(torch._VF, self.vf_ori_attr)
-        if not is_gpu:
-            self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
-
-    def initialize_hook(self, hook, online_run_ut=False):
-        """
-        initialize_hook
-        Args:
-            hook (_type_): initialize_hook
-            online_run_ut (bool): default False, whether online run_ut or not.
-                If online_run_ut is True, the hook will not wrap the aten ops.
-        """
-        self.store_ori_attr(torch.Tensor, get_tensor_ops(), self.tensor_ori_attr)
-        wrap_tensor.wrap_tensor_ops_and_bind(hook)
-        for attr_name in dir(wrap_tensor.HOOKTensor):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                self.tensor_hook_attr[attr_name[5:]] = getattr(wrap_tensor.HOOKTensor, attr_name)
-
-        self.store_ori_attr(torch, get_torch_ops(), self.torch_ori_attr)
-        wrap_torch.wrap_torch_ops_and_bind(hook)
-        for attr_name in dir(wrap_torch.HOOKTorchOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                self.torch_hook_attr[attr_name[5:]] = getattr(wrap_torch.HOOKTorchOP, attr_name)
-
-        self.store_ori_attr(torch.nn.functional, get_functional_ops(), self.functional_ori_attr)
-        wrap_functional.wrap_functional_ops_and_bind(hook)
-        for attr_name in dir(wrap_functional.HOOKFunctionalOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                self.functional_hook_attr[attr_name[5:]] = getattr(wrap_functional.HOOKFunctionalOP, attr_name)
-
-        self.store_ori_attr(dist, get_distributed_ops(), self.distributed_ori_attr)
-        wrap_distributed.wrap_distributed_ops_and_bind(hook)
-        if not is_gpu and not torch_without_guard_version:
-            self.store_ori_attr(torch_npu.distributed, npu_distributed_api, self.npu_distributed_ori_attr)
-        for attr_name in dir(wrap_distributed.HOOKDistributedOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                self.distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP, attr_name)
-                if not is_gpu and not torch_without_guard_version and attr_name[5:] in npu_distributed_api:
-                    self.npu_distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP,
-                                                                            attr_name)
-
-        if torch_version_above_2 and not online_run_ut:
-            self.store_ori_attr(torch.ops.aten, get_aten_ops(), self.aten_ori_attr)
-            wrap_aten.wrap_aten_ops_and_bind(hook)
-            for attr_name in dir(wrap_aten.HOOKAtenOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    self.aten_hook_attr[attr_name[5:]] = getattr(wrap_aten.HOOKAtenOP, attr_name)
-
-        self.store_ori_attr(torch._VF, get_vf_ops(), self.vf_ori_attr)
-        wrap_vf.wrap_vf_ops_and_bind(hook)
-        for attr_name in dir(wrap_vf.HOOKVfOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                self.vf_hook_attr[attr_name[5:]] = getattr(wrap_vf.HOOKVfOP, attr_name)
-
-        if not is_gpu:
-            self.store_ori_attr(torch_npu, get_npu_ops(), self.torch_npu_ori_attr)
-            wrap_npu_custom.wrap_npu_ops_and_bind(hook)
-            for attr_name in dir(wrap_npu_custom.HOOKNpuOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    self.torch_npu_hook_attr[attr_name[5:]] = getattr(wrap_npu_custom.HOOKNpuOP, attr_name)
-
-
-api_register = ApiRegistry()
--- a/msprobe/pytorch/hook_module/wrap_distributed.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from functools import wraps
-import torch.distributed as dist
-
-from msprobe.pytorch.hook_module.hook_module import HOOKModule
-from msprobe.pytorch.common.utils import torch_device_guard
-from msprobe.core.common.const import Const
-from msprobe.core.common.file_utils import load_yaml
-
-
-cur_path = os.path.dirname(os.path.realpath(__file__))
-yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
-
-
-distributed_func = {}
-for f in dir(dist):
-    distributed_func[f] = getattr(dist, f)
-
-
-def get_distributed_ops():
-    _all_distributed_ops = dir(dist)
-    yaml_data = load_yaml(yaml_path)
-    wrap_distributed_ops = yaml_data.get('distributed')
-    return set(wrap_distributed_ops) & set(_all_distributed_ops)
-
-
-class HOOKDistributedOP(object):
-    pass
-
-
-class DistributedOPTemplate(HOOKModule):
-    def __init__(self, op_name, build_hook):
-        self.op_name_ = op_name
-        self.prefix_op_name_ = "Distributed" + Const.SEP + str(op_name) + Const.SEP
-        super().__init__(build_hook)
-        if not self.stop_hook:
-            self.op_is_distributed = True
-
-    @torch_device_guard
-    def forward(self, *args, **kwargs):
-        handle = distributed_func.get(self.op_name_)(*args, **kwargs)
-        if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]:
-            if handle and hasattr(handle, 'wait'):
-                handle.wait()
-        return handle
-
-
-def wrap_distributed_op(op_name, hook):
-    @wraps(DistributedOPTemplate)
-    def distributed_op_template(*args, **kwargs):
-        return DistributedOPTemplate(op_name, hook)(*args, **kwargs)
-
-    distributed_op_template.__name__ = op_name
-    return distributed_op_template


-def wrap_distributed_ops_and_bind(hook):
-    _distributed_ops = get_distributed_ops()
-    for op_name in _distributed_ops:
-        setattr(HOOKDistributedOP, "wrap_" + str(op_name), wrap_distributed_op(op_name, hook))