mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
- msprobe/README.md +27 -22
- msprobe/core/common/const.py +129 -60
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +43 -33
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +16 -9
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +30 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +94 -10
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
- msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
- msprobe/core/data_dump/json_writer.py +61 -40
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +27 -1
- msprobe/docs/02.config_introduction.md +27 -23
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +103 -16
- msprobe/docs/06.data_dump_MindSpore.md +76 -32
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +332 -273
- msprobe/docs/21.visualization_PyTorch.md +42 -13
- msprobe/docs/22.visualization_MindSpore.md +43 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +301 -27
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
- msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +48 -18
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +31 -6
- msprobe/mindspore/debugger/precision_debugger.py +45 -14
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +21 -15
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +873 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +309 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +114 -34
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +97 -4
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +24 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +8 -2
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +18 -14
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +238 -193
- msprobe/pytorch/monitor/module_metric.py +9 -6
- msprobe/pytorch/monitor/optimizer_collect.py +100 -67
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +76 -44
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +30 -29
- msprobe/pytorch/service.py +114 -32
- msprobe/visualization/builder/graph_builder.py +75 -10
- msprobe/visualization/builder/msprobe_adapter.py +7 -6
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +11 -3
- msprobe/visualization/graph/distributed_analyzer.py +71 -3
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +4 -3
- msprobe/visualization/graph_service.py +4 -5
- msprobe/visualization/utils.py +12 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
```diff
@@ -1,9 +1,7 @@
-
-# -*- coding: utf-8 -*-
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
-# Licensed under the Apache License, Version 2.0
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -18,8 +16,8 @@
 import os
 from collections import namedtuple
 import re
-import torch
 
+import torch
 try:
     import torch_npu
 except ImportError:
@@ -33,11 +31,9 @@ from msprobe.core.common.const import FileCheckConst, Const, CompareConst
 from msprobe.core.common.file_utils import FileChecker
 from msprobe.core.common.log import logger
 from msprobe.core.common.utils import CompareException
+from msprobe.pytorch.hook_module.api_register import ApiTemplate, get_api_register
 from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
-
-from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
-from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
-from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
+
 
 hf_32_standard_api = ["conv1d", "conv2d"]
 not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
@@ -108,17 +104,30 @@ def exec_api(exec_params):
     kwargs = exec_params.kwargs
     is_autocast = exec_params.is_autocast
     autocast_dtype = exec_params.autocast_dtype
-
-
-
-    if api_type
-
-
-
-
+    out = None
+
+    prefix_map = Const.API_DATA_PREFIX.get(Const.PT_FRAMEWORK, {})
+    if not prefix_map or api_type not in prefix_map.values() or \
+            api_type not in (
+                Const.FUNCTIONAL_API_TYPE_PREFIX,
+                Const.TENSOR_API_TYPE_PREFIX,
+                Const.TORCH_API_TYPE_PREFIX,
+                Const.ATEN_API_TYPE_PREFIX,
+                Const.NPU_API_TYPE_PREFIX
+            ):
+        return out
+
+    if api_type == Const.ATEN_API_TYPE_PREFIX:
         torch_api = AtenOPTemplate(api_name, None, False)
-
-
+    else:
+        api_register = get_api_register()
+        api_register.initialize_hook(None)
+        api_func_type = list(prefix_map.keys())[list(prefix_map.values()).index(api_type)]
+        api_func = api_register.ori_api_attr.get(Const.PT_FRAMEWORK + Const.SEP + api_func_type, {}).get(api_name)
+        if api_func is None:
+            return out
+
+        torch_api = ApiTemplate(api_name, api_func, api_type, None, need_hook=False, device=device)
     if is_autocast:
         with autocast(dtype=autocast_dtype):
             out = torch_api.forward(*args, **kwargs)
```
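The else-branch above recovers the registered function group by a reverse lookup on `prefix_map`. A toy, self-contained illustration of that lookup pattern (the dictionary contents here are invented, not msprobe's `Const.API_DATA_PREFIX`):

```python
# Toy illustration of the reverse lookup used in exec_api: given a mapping of
# api_func_type -> api_type prefix, recover the key whose value matches api_type.
prefix_map = {"functional": "Functional", "tensor": "Tensor", "torch": "Torch"}
api_type = "Tensor"
api_func_type = list(prefix_map.keys())[list(prefix_map.values()).index(api_type)]
assert api_func_type == "tensor"
```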
```diff
@@ -27,6 +27,7 @@ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import T
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
 from msprobe.core.common.file_utils import remove_path
 from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
+from msprobe.core.common.decorator import recursion_depth_decorator
 
 BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]]
 
@@ -168,11 +169,12 @@ class ATTL:
         return buffer
 
 
+@recursion_depth_decorator("move2device_exec")
 def move2device_exec(obj, device):
     if isinstance(obj, (tuple, list)):
         data_list = [move2device_exec(val, device) for val in obj]
         return data_list if isinstance(obj, list) else tuple(data_list)
-    if isinstance(obj, dict):
+    if isinstance(obj, dict):
         return {key: move2device_exec(val, device) for key, val in obj.items()}
     elif isinstance(obj, torch.Tensor):
         obj = obj.detach()
```
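`recursion_depth_decorator` comes from the new `msprobe/core/common/decorator.py` (listed above); its actual implementation is not shown in this diff. A minimal sketch of the idea, assuming it simply bounds the call depth of recursive helpers such as `move2device_exec`:

```python
import functools


def recursion_depth_guard(name, max_depth=100):
    """Illustrative stand-in for msprobe's recursion_depth_decorator (assumed behaviour)."""
    def decorator(func):
        state = {"depth": 0}

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            state["depth"] += 1
            try:
                if state["depth"] > max_depth:
                    raise RecursionError(f"{name} exceeded the maximum recursion depth {max_depth}")
                return func(*args, **kwargs)
            finally:
                state["depth"] -= 1
        return wrapper
    return decorator
```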
```diff
@@ -0,0 +1,215 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import namedtuple
+import torch
+
+
+VarParams = namedtuple('VarParams', ['var', 'lr_t', 'm_t', 'beta1_broad', 'grad', 'epsilon', 'v_t'])
+
+
+def _output_m_compute(m, beta1_broad, grad):
+    """
+    _output_m_compute
+    do compute m_t = m + (beta1 - 1) * (m - grad)
+    """
+    input_dtype = m.dtype
+
+    sneg_one = torch.ones((1), dtype=input_dtype) * -1
+    sneg_one = sneg_one.to(beta1_broad.device)
+
+    # `formula; beta1 -1`
+    vsub_beta1_1 = torch.add(beta1_broad, sneg_one)
+
+    # `formula; m - grad`
+    vsub_m_grad = torch.sub(m, grad)
+
+    # `formula; (beta1 - 1) * (m - grad)`
+    vmul_m = torch.mul(vsub_beta1_1, vsub_m_grad)
+
+    # `formula; m_t = m + (beta1 - 1) * (m - grad)`
+    m_t = torch.add(m, vmul_m)
+
+    return m_t
+
+
+def _output_v_compute(v, beta2, grad):
+    """
+    _output_v_compute
+    do compute v_t = v + (1 - beta2)*(grad*grad -v)
+    """
+    input_dtype = v.dtype
+
+    sneg_one = torch.ones((1), dtype=input_dtype) * -1
+
+    # `formula; broadcast beta2 to vector`
+    beta2_tensor = torch.tensor(beta2, dtype=input_dtype)
+    beta2_broad = beta2_tensor.expand_as(v)
+
+    # `formula; beta2 - 1`
+    vsub_beta2_1 = torch.add(beta2_broad, sneg_one)
+    vsub_beta2_1 = vsub_beta2_1.to(v.device)
+
+    # `formula; grad * grad`
+    vmul_grad_grad = torch.mul(grad, grad)
+
+    # `formula; (v - grad*grad)`
+    vsub_v_grad = torch.sub(v, vmul_grad_grad)
+
+    # `formula; (beta2 -1) * (v - grad * grad)`
+    vmul_grad = torch.mul(vsub_beta2_1, vsub_v_grad)
+
+    # `formula; v_t = v + (beta2 - 1) * (v - grad * grad)`
+    v_t = torch.add(v, vmul_grad)
+
+    return v_t
+
+
+def _inner_lr_compute(lr, beta2_power, beta1_power, compute_shape_tensor):
+    """
+    _inner_lr_compute
+    `formula; lr_t = learning_rate * (sqrt(1-beta2_power)) / (1 - beta1_power)`
+    """
+
+    input_dtype = compute_shape_tensor.dtype
+
+    s_one = torch.ones((1), dtype=input_dtype)
+
+    s_neg_one = torch.ones((1), dtype=input_dtype) * -1
+
+    # `formula; (1 - beta2_power)`
+    v_neg_beta2_power = torch.mul(beta2_power, s_neg_one)
+    v_add_beta2_power = torch.add(v_neg_beta2_power, s_one)
+
+    # `formula; sqrt(1 - beta2_power)`
+    v_sqrt_beta2_power = torch.sqrt(v_add_beta2_power)
+
+    # `formula; (1 - beta1_power)`
+    v_neg_beta1_power = torch.mul(beta1_power, s_neg_one)
+    v_add_beta1_power = torch.add(v_neg_beta1_power, s_one)
+
+    # `formula; learning_rate * (sqrt(1-beta2_power)`
+    res = torch.mul(lr, v_sqrt_beta2_power)
+
+    # `formula; learning_rate*(sqrt(1-beta2_power))/(1-beta1_power)`
+    res = torch.div(res, v_add_beta1_power)
+    return res.expand_as(compute_shape_tensor)
+
+
+def _inner_eps_add_sqrt_vt_compute(epsilon, v_t):
+    """
+    (epsilon + sqrt(v_t) )
+    """
+    # `formula; sqrt(v_t)`
+    sqrt_vt = torch.sqrt(v_t)
+
+    # `formula; broadcast epsilon to vector`
+    input_dtype = v_t.dtype
+    epsilon_tensor = torch.tensor(epsilon, dtype=input_dtype)
+    epsilon_broad = epsilon_tensor.expand_as(v_t)
+    epsilon_broad = epsilon_broad.to(sqrt_vt.device)
+
+    # `formula; epsilon + sqrt(v_t)`
+    v_add_sqrt_v = torch.add(sqrt_vt, epsilon_broad)
+
+    return v_add_sqrt_v
+
+
+def _output_var_t_compute_use_nesterov(varparams):
+    """
+    _output_var_t_compute_use_nesterov
+    `formula; var_t = var - lr_t * (m_t * beta1 + (1 - beta1) * grad) / (epsilon + sqrt(v_t))`
+    `formula; var_t = var - lr_t * (m_t * beta1 + (1 - beta1) * grad) / (epsilon + sqrt(v_t))`
+    """
+    var = varparams.var
+    lr_t = varparams.lr_t
+    m_t = varparams.m_t
+    beta1_broad = varparams.beta1_broad
+    grad = varparams.grad
+    epsilon = varparams.epsilon
+    v_t = varparams.v_t
+
+    input_dtype = var.dtype
+
+    s_one = torch.ones((1), dtype=input_dtype)
+
+    s_neg_one = torch.ones((1), dtype=input_dtype) * -1
+
+    # `formula; m_t * beta1`
+    v_muls_mt_beta1 = torch.mul(m_t, beta1_broad)
+
+    # `formula; 1 -beta1`
+    v_neg_beta1 = torch.mul(beta1_broad, s_neg_one)
+    vsub_1_beta1 = torch.add(v_neg_beta1, s_one)
+
+    # `formula; (1-beta1)* grad`
+    v_mul_grad = torch.mul(vsub_1_beta1, grad)
+
+    # `formula; (m_t*beta1 + (1 - beta1)*grad)`
+    v_div_left = torch.add(v_muls_mt_beta1, v_mul_grad)
+
+    # `formula; lr_t * (m_t*beta1 + (1 - beta1) * grad)`
+    # broadcast lr_t to vector
+
+    lrt_broad = lr_t.expand_as(var)
+    v_mul_left = torch.mul(lrt_broad, v_div_left)
+
+    # `formula; (epsilon + sqrt(v_t))`
+    v_add_sqrt_v = _inner_eps_add_sqrt_vt_compute(epsilon, v_t)
+
+    # `formula; lr_t * (m_t*beta1 + (1-beta1)*grad / (epsilon + sqrt(v_t))`
+    v_div_res = torch.div(v_mul_left, v_add_sqrt_v)
+
+    # `formula; var - lr_t * (m_t*beta1 + (1-beta1)*grad) / (epsilon + sqrt(v_t))`
+    v_t = torch.sub(var, v_div_res)
+
+    return v_t
+
+
+def _output_var_t_compute(var, lr_t, m_t, epsilon, v_t):
+    """
+    _output_var_t_compute
+    `var_t = var - lr_t * m_t / (epsilon + sqrt(v_t))`
+    """
+    # `formula; lr_t * m_t`
+    lr_t = lr_t.to(m_t.device)
+    v_mul_left = torch.mul(lr_t, m_t)
+
+    # `formula; (epsilon + sqrt(v_t))`
+    v_add_sqrt_v = _inner_eps_add_sqrt_vt_compute(epsilon, v_t)
+
+    # `formula; lr_t * m_t /(epsilon + sqrt(v_t))`
+    v_div_res = torch.div(v_mul_left, v_add_sqrt_v)
+
+    # `formula; var - lr_t * m_t / (epsilon + sqrt(v_t))`
+    v_t = torch.sub(var, v_div_res)
+
+    return v_t
+
+
+def npu_apply_adam(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, use_locking, use_nesterov, out):
+    var, m, v = out
+    input_dtype = m.dtype
+    beta1_tensor = torch.tensor(beta1, dtype=input_dtype).to(m.device)
+    beta1_broad = beta1_tensor.expand_as(m)
+    m_t = _output_m_compute(m, beta1_broad, grad)
+    v_t = _output_v_compute(v, beta2, grad)
+    lr_t = _inner_lr_compute(lr, beta2_power, beta1_power, grad)
+    if use_nesterov:
+        var_params = VarParams(var, lr_t, m_t, beta1_broad, grad, epsilon, v_t)
+        var_t = _output_var_t_compute_use_nesterov(var_params)
+    else:
+        var_t = _output_var_t_compute(var, lr_t, m_t, epsilon, v_t)
+    return var_t, m_t, v_t
```
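A hypothetical call to the CPU reference implementation above; shapes and hyper-parameters are invented for illustration, and `beta1_power`, `beta2_power` and `lr` are passed as tensors because the helpers feed them straight into `torch.mul`:

```python
import torch

# Illustrative only: out holds (var, m, v), the optimizer states the NPU operator updates.
var, m, v = torch.randn(4, 4), torch.zeros(4, 4), torch.zeros(4, 4)
grad = torch.randn(4, 4)

var_t, m_t, v_t = npu_apply_adam(
    beta1_power=torch.tensor(0.9), beta2_power=torch.tensor(0.999),
    lr=torch.tensor(1e-3), beta1=0.9, beta2=0.999, epsilon=1e-8,
    grad=grad, use_locking=False, use_nesterov=False, out=(var, m, v))
```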
```diff
@@ -0,0 +1,27 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def npu_group_norm_silu(x, gama, beta, group, eps):
+    if len(x.shape) != 4:
+        raise ValueError("x shape should be (N, C, H, W)")
+    res = torch.ops.aten.native_group_norm(x, gama, beta, x.shape[0], x.shape[1], x.shape[2] * x.shape[3], group, eps)
+    res = list(res)
+    if not res:
+        raise ValueError("run native_group_norm failed")
+    res[0] = torch.nn.functional.silu(res[0])
+    return res
```
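A small usage sketch (shapes assumed): `native_group_norm` returns `(out, mean, rstd)`, and the helper applies SiLU to the normalized output only.

```python
import torch

x = torch.randn(2, 8, 4, 4)          # NCHW input; C must be divisible by group
gamma, beta = torch.ones(8), torch.zeros(8)
out, mean, rstd = npu_group_norm_silu(x, gamma, beta, group=4, eps=1e-5)
```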
```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import torch
 
-
-
+
+def npu_mish(x):
+    mish = torch.nn.Mish()
+    return mish(x)
```
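`npu_mish` now just delegates to `torch.nn.Mish`, which computes `x * tanh(softplus(x))`; a quick sanity check of that identity (assuming the function above is in scope):

```python
import torch

x = torch.randn(10)
expected = x * torch.tanh(torch.nn.functional.softplus(x))
assert torch.allclose(npu_mish(x), expected, atol=1e-6)
```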
```diff
@@ -0,0 +1,50 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import numpy as np
+
+
+def softmax_func(x, axis=None):
+    x = x.float()
+    x_max = x.max(dim=axis, keepdims=True).values
+    x_sub = x - x_max
+    y = torch.exp(x_sub)
+    x_sum = y.sum(dim=axis, keepdims=True)
+    ans = 0 if (x_sum == 0).any() else y / x_sum
+    return ans
+
+
+def npu_moe_gating_top_k_softmax(x, finished_optional, k):
+    input_dtype = x.dtype
+    if x.dim() < 1:
+        raise ValueError("Input x must have at least 1 dimensions.")
+    num_expert = x.shape[-1]
+    softmax = softmax_func(x, -1)
+    softmax = softmax.to(input_dtype)
+    expert_idx = torch.argsort(-softmax, dim=-1, stable=True)
+    expert_idx = expert_idx[:, :k]
+    y = torch.gather(softmax, index=expert_idx, dim=-1)
+    if finished_optional is not None:
+        if finished_optional.dim() < 1:
+            raise ValueError("Finished_optional must have at least 1 dimensions.")
+        finished_optional = finished_optional.view(finished_optional.shape[0], 1)
+        finished_optional = finished_optional.expand(-1, k)
+        expert_idx = torch.where(finished_optional, num_expert, expert_idx)
+    if y.dim() < 2:
+        raise ValueError("Variable y must have at least 2 dimensions.")
+    row_idx = torch.arange(y.shape[0] * y.shape[1]).reshape(y.shape[1], y.shape[0]).t()
+
+    return y, expert_idx, row_idx
```
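Intended call shape for the bench function above, sketched with invented sizes: 2-D routing logits of `(num_tokens, num_experts)`, with `finished_optional` allowed to be `None`.

```python
import torch

logits = torch.randn(4, 8)                                  # 4 tokens, 8 experts
y, expert_idx, row_idx = npu_moe_gating_top_k_softmax(logits, None, k=2)
# y: top-k softmax weights (4, 2); expert_idx: selected expert ids (4, 2);
# row_idx: token indices laid out column-major and transposed back to (4, 2).
```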
```diff
@@ -0,0 +1,21 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def npu_sort_v2(x, dim=-1, descending=False, out=None):
+    y, _ = torch.sort(x, dim=dim, descending=descending)
+    return y
```
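`npu_sort_v2` returns only the sorted values and drops the indices from `torch.sort`:

```python
import torch

x = torch.tensor([3.0, 1.0, 2.0])
assert torch.equal(npu_sort_v2(x), torch.tensor([1.0, 2.0, 3.0]))
```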
msprobe/pytorch/common/utils.py CHANGED
```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,6 +18,7 @@ import os
 import pickle
 import random
 import stat
+import inspect
 from functools import wraps
 
 import numpy as np
@@ -27,7 +28,7 @@ from msprobe.core.common.exceptions import DistributedNotInitializedError
 from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
                                             check_file_or_directory_path, check_path_before_create, FileOpen)
 from msprobe.core.common.log import logger
-from msprobe.core.common.utils import check_seed_all
+from msprobe.core.common.utils import check_seed_all, is_save_variable_valid
 from packaging import version
 
 try:
@@ -56,7 +57,7 @@ def parameter_adapter(func):
 
     @wraps(func)
     def inner(self, *args, **kwargs):
-        if self.
+        if self.api_name == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor):
             input_tensor = args[0]
             indices = args[1]
             if indices.dtype == torch.uint8:
@@ -76,7 +77,7 @@ def parameter_adapter(func):
             else:
                 res = [input_tensor[tensor_index] for tensor_index in indices]
             return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0)
-        if self.
+        if self.api_name == "__eq__" and len(args) > 1 and args[1] is None:
             return False
         return func(self, *args, **kwargs)
 
@@ -260,6 +261,10 @@ class Const:
     NPU = 'NPU'
     DISTRIBUTED = 'Distributed'
 
+    HIFLOAT8_TYPE = "torch_npu.HiFloat8Tensor"
+    FLOAT8_E5M2_TYPE = "torch.float8_e5m2"
+    FLOAT8_E4M3FN_TYPE = "torch.float8_e4m3fn"
+
     RAISE_PRECISION = {
         torch.float16: torch.float32,
         torch.bfloat16: torch.float32,
@@ -402,3 +407,91 @@ def load_api_data(api_data_bytes):
     except Exception as e:
         raise RuntimeError(f"load api_data from bytes failed") from e
     return buffer
+
+
+def is_recomputation():
+    """Check if the current operation is in the re-computation phase.
+
+    This function inspects the current call stack to indicate whether the current operation is in the
+    re-computation phase. We use a blacklist mechanism, now supported megatron and mindspeed framework.
+    megatron: The 'backward' function is called by the 'torch/autograd/function.py' file.
+    mindspeed: The 'checkpoint_function_backward' function is called by the 'torch/autograd/function.py'
+    file or the custom module(use CheckpointWithoutOutput) with the 'recompute_fn' function is executed within the
+    'torch/utils/checkpoint.py' file.
+
+    Returns:
+        bool: True if in the re-computation phase, False otherwise.
+    """
+    backward_function_indices = []
+    try:
+        call_stack = inspect.stack()
+    except Exception as e:
+        logger.warning(f"Failed to capture stack trace, recomputation validation may be incorrect, error info: {e}.")
+        return False
+
+    # Identify the function 'backward' is being executed within the 'torch/_tensor.py' file.
+    for frame_info in call_stack:
+        if frame_info.function == "recompute_fn" and frame_info.filename.endswith('torch/utils/checkpoint.py'):
+            del call_stack
+            return True
+
+    # Identify indices in the call stack where the specific function is being executed
+    for idx, frame_info in enumerate(call_stack):
+        if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward':
+            backward_function_indices.append(idx)
+
+    # Check if the execution is within 'torch/autograd/function.py' file
+    for idx in backward_function_indices:
+        # The Megatron and MindSpeed L0&L1 scenes
+        if idx + 1 < len(call_stack) and call_stack[idx + 1].filename.endswith('torch/autograd/function.py'):
+            del call_stack
+            return True
+        # The latest MindSpeed L2 and ModelLink scenes
+        if idx + 2 < len(call_stack) and call_stack[idx + 2].filename.endswith('torch/autograd/function.py'):
+            del call_stack
+            return True
+
+    del call_stack
+    return False
+
+
+def check_save_param(variable, name, save_backward):
+    # try catch this api to skip invalid call
+    valid_data_types = tuple([torch.Tensor, int, float, str])
+    if not is_save_variable_valid(variable, valid_data_types):
+        valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
+        logger.warning("PrecisionDebugger.save variable type not valid, "
+                       f"should be one of {valid_data_types_with_nested_types}"
+                       "Skip current save process.")
+        raise ValueError
+    if not isinstance(name, str):
+        logger.warning("PrecisionDebugger.save name not valid, "
+                       "should be string. "
+                       "skip current save process.")
+        raise ValueError
+    if not isinstance(save_backward, bool):
+        logger.warning("PrecisionDebugger.save_backward name not valid, "
+                       "should be bool. "
+                       "Skip current save process.")
+        raise ValueError
+
+
+def replace_last_occurrence(text, old, new):
+    if text is None:
+        return text
+    index = text.rfind(old)
+    if index != -1:
+        return text[:index] + text[index:].replace(old, new, 1)
+    return text
+
+
+def is_hifloat8_tensor(tensor):
+    if not is_gpu and hasattr(torch_npu, "HiFloat8Tensor") and isinstance(tensor, torch_npu.HiFloat8Tensor):
+        return True
+    return False
+
+
+def is_float8_tensor(tensor):
+    if str(tensor.dtype) in [Const.FLOAT8_E5M2_TYPE, Const.FLOAT8_E4M3FN_TYPE]:
+        return True
+    return is_hifloat8_tensor(tensor)
```
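Behaviour of the new `replace_last_occurrence` helper, shown on invented inputs: only the last match is rewritten, and `None` or no-match inputs pass through unchanged.

```python
assert replace_last_occurrence("Module.layer.forward.0.forward", "forward", "backward") == \
    "Module.layer.forward.0.backward"
assert replace_last_occurrence("no_match_here", "forward", "backward") == "no_match_here"
assert replace_last_occurrence(None, "forward", "backward") is None
```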
```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,7 +26,7 @@ class DebuggerConfig:
         self.task = task or common_config.task or Const.STATISTICS
         self.rank = common_config.rank if common_config.rank else []
         self.step = common_config.step if common_config.step else []
-        self.level = level or common_config.level or
+        self.level = level or common_config.level or Const.LEVEL_L1
         self.enable_dataloader = common_config.enable_dataloader
         self.scope = task_config.scope if task_config.scope else []
         self.list = task_config.list if task_config.list else []
@@ -36,10 +36,6 @@ class DebuggerConfig:
         self.framework = Const.PT_FRAMEWORK
         self.async_dump = common_config.async_dump if common_config.async_dump else False
 
-        if self.level == Const.LEVEL_L2:
-            self.is_backward_kernel_dump = False
-            self._check_and_adjust_config_with_l2()
-
         if self.task == Const.FREE_BENCHMARK:
             self.fuzz_device = task_config.fuzz_device
             self.handler_type = task_config.handler_type
@@ -65,6 +61,10 @@ class DebuggerConfig:
 
         self.check()
 
+        if self.level == Const.LEVEL_L2:
+            self.is_backward_kernel_dump = False
+            self._check_and_adjust_config_with_l2()
+
     def check_kwargs(self):
         if self.task and self.task not in Const.TASK_LIST:
             raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
@@ -78,6 +78,16 @@ class DebuggerConfig:
         if not isinstance(self.async_dump, bool):
             raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                    f"The parameters async_dump should be bool.")
+        if self.async_dump and self.task == Const.TENSOR and not self.list:
+            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+                                   f"The parameters async_dump is true in tensor task, the parameters list cannot be "
+                                   f"empty.")
+        if self.task == Const.STRUCTURE and self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
+            logger.warning_on_rank_0(
+                f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
+                f"If not, the default level is {Const.LEVEL_MIX}."
+            )
+            self.level = Const.LEVEL_MIX
 
     def check(self):
         self.check_kwargs()
@@ -93,10 +103,10 @@ class DebuggerConfig:
             logger.error_on_rank_0(
                 f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' parameter.")
             raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
-
+
         instance.model = start_model if start_model is not None else instance.model
         if isinstance(instance.model, torch.nn.Module):
-            return
+            return
 
         error_model = None
         if isinstance(instance.model, (list, tuple)):
@@ -108,7 +118,7 @@ class DebuggerConfig:
             error_model = instance.model
 
         if error_model is not None:
-            error_info = (f"The 'model' parameter must be a torch.nn.
+            error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
                           f"type, currently there is a {type(error_model)} type.")
             raise MsprobeException(
                 MsprobeException.INVALID_PARAM_ERROR, error_info)
```