mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
- mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +16 -21
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +185 -11
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +4 -0
- msprobe/core/common/utils.py +42 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +381 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +199 -69
- msprobe/core/data_dump/data_collector.py +100 -25
- msprobe/core/data_dump/data_processor/base.py +130 -28
- msprobe/core/data_dump/data_processor/factory.py +8 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
- msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
- msprobe/core/data_dump/json_writer.py +54 -8
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +121 -17
- msprobe/docs/02.config_introduction.md +18 -16
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +107 -58
- msprobe/docs/06.data_dump_MindSpore.md +95 -34
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +310 -220
- msprobe/docs/21.visualization_PyTorch.md +125 -35
- msprobe/docs/22.visualization_MindSpore.md +149 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +525 -0
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +11 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +68 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +28 -2
- msprobe/mindspore/debugger/precision_debugger.py +100 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +13 -3
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +347 -107
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -7
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +116 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +42 -17
- msprobe/pytorch/debugger/precision_debugger.py +56 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +98 -28
- msprobe/pytorch/monitor/csv2tb.py +164 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +543 -318
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +76 -56
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +84 -48
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +264 -115
- msprobe/visualization/builder/graph_builder.py +93 -10
- msprobe/visualization/builder/msprobe_adapter.py +30 -6
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +15 -19
- msprobe/visualization/graph/distributed_analyzer.py +395 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +100 -27
- msprobe/visualization/utils.py +24 -31
- mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
- msprobe/pytorch/functional/module_dump.py +0 -84
- msprobe/pytorch/module_processer.py +0 -150
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def softmax_func(x, axis=None):
    """Numerically stable softmax computed in float32.

    Args:
        x: input tensor; it is upcast with ``x.float()`` before any math.
        axis: dimension to normalize over. ``None`` normalizes over all
            elements of ``x`` (``Tensor.max`` itself does not accept
            ``dim=None``, so the reduction dims are built explicitly).

    Returns:
        A float32 tensor with the same shape as ``x``. If any normalizer is
        zero, a zero tensor of the same shape is returned (rather than the
        integer ``0``) so callers can keep using tensor methods such as
        ``.to(dtype)`` on the result.
    """
    x = x.float()
    dims = tuple(range(x.dim())) if axis is None else axis
    # Subtract the max along the reduced dims so exp() cannot overflow.
    x_max = torch.amax(x, dim=dims, keepdim=True)
    y = torch.exp(x - x_max)
    x_sum = y.sum(dim=dims, keepdim=True)
    if (x_sum == 0).any():
        return torch.zeros_like(y)
    return y / x_sum
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def npu_moe_gating_top_k_softmax(x, finished_optional, k):
    """Golden implementation of the MoE gating top-k softmax operator.

    Args:
        x: gating logits. The ``[:, :k]`` slicing and the ``shape[0]`` /
            ``shape[1]`` usage below assume a 2-D tensor (rows, num_expert).
        finished_optional: optional boolean tensor with one entry per row;
            rows flagged True get their expert indices replaced by
            ``num_expert`` (the last dim size of ``x``).
        k: number of experts selected per row.

    Returns:
        (y, expert_idx, row_idx):
            y: the top-k softmax probabilities, in the input dtype.
            expert_idx: their expert indices.
            row_idx: an index table with ``row_idx[i, j] == j * rows + i``.
    """
    input_dtype = x.dtype
    num_expert = x.shape[-1]
    # Softmax is computed in float32, then cast back so the ranking below is
    # done on values rounded to the input dtype.
    softmax = softmax_func(x, -1).to(input_dtype)
    # Stable sort of the negated scores: descending order, and on ties the
    # lower expert index comes first.
    expert_idx = torch.argsort(-softmax, dim=-1, stable=True)[:, :k]
    y = torch.gather(softmax, index=expert_idx, dim=-1)
    if finished_optional is not None:
        # reshape (not view) so non-contiguous masks are accepted as well.
        finished = finished_optional.reshape(finished_optional.shape[0], 1).expand(-1, k)
        expert_idx = torch.where(finished, num_expert, expert_idx)
    rows, cols = y.shape[0], y.shape[1]
    # Create the index table on y's device so all outputs share one device.
    row_idx = torch.arange(rows * cols, device=y.device).reshape(cols, rows).t()
    return y, expert_idx, row_idx
|
|
@@ -30,6 +30,7 @@
|
|
|
30
30
|
numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
|
|
31
31
|
"""
|
|
32
32
|
|
|
33
|
+
from collections import namedtuple
|
|
33
34
|
import torch
|
|
34
35
|
import numpy as np
|
|
35
36
|
from einops import rearrange
|
|
@@ -54,6 +55,14 @@ GTYPE = torch.float64 # arm host必须选择float64,x86环境选择float32即
|
|
|
54
55
|
SOFTMAX_BUILD_MODE = "QKV" # "MAX_SUM"
|
|
55
56
|
|
|
56
57
|
|
|
58
|
+
# Argument bundles for the golden flash-attention helpers. Grouping the many
# tensors/scalars into namedtuples keeps the helper signatures short; field
# order is positional, so construct them in exactly this order.

# Inputs of fusion_attention_forward.
FaForwardParams = namedtuple(
    "FaForwardParams",
    "q k v drop_mask atten_mask pse scale keep_prob",
)

# Inputs of fusion_attention_backward; dx is presumably the gradient flowing
# into the attention output (it is multiplied by v^T there) - TODO confirm.
FaBackwardParams = namedtuple(
    "FaBackwardParams",
    "dx q k v softmax_res drop_mask pse scale keep_prob",
)

# Inputs of rebuild_softmax_by_max_sum (softmax rebuilt from saved max/sum).
RebuildSoftmaxParams = namedtuple(
    "RebuildSoftmaxParams",
    "q k atten_mask pse scale softmax_max softmax_sum",
)
|
|
64
|
+
|
|
65
|
+
|
|
57
66
|
def softmax_forward(x):
|
|
58
67
|
x_max = torch.max(x, dim=-1, keepdims=True)[0]
|
|
59
68
|
x_sub = x.sub(x_max)
|
|
@@ -99,7 +108,15 @@ def calculate_qk(q, k, atten_mask, pse, scale):
|
|
|
99
108
|
return qk
|
|
100
109
|
|
|
101
110
|
|
|
102
|
-
def fusion_attention_forward(
|
|
111
|
+
def fusion_attention_forward(forward_params):
|
|
112
|
+
q = forward_params.q
|
|
113
|
+
k = forward_params.k
|
|
114
|
+
v = forward_params.v
|
|
115
|
+
drop_mask = forward_params.drop_mask
|
|
116
|
+
atten_mask = forward_params.atten_mask
|
|
117
|
+
pse = forward_params.pse
|
|
118
|
+
scale = forward_params.scale
|
|
119
|
+
keep_prob = forward_params.keep_prob
|
|
103
120
|
qk = calculate_qk(q, k, atten_mask, pse, scale)
|
|
104
121
|
softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
|
|
105
122
|
if drop_mask is None or len(drop_mask.shape) == 0:
|
|
@@ -110,7 +127,16 @@ def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_pr
|
|
|
110
127
|
return y, softmax_max, softmax_sum
|
|
111
128
|
|
|
112
129
|
|
|
113
|
-
def fusion_attention_backward(
|
|
130
|
+
def fusion_attention_backward(backward_params):
|
|
131
|
+
dx = backward_params.dx
|
|
132
|
+
q = backward_params.q
|
|
133
|
+
k = backward_params.k
|
|
134
|
+
v = backward_params.v
|
|
135
|
+
softmax_res = backward_params.softmax_res
|
|
136
|
+
drop_mask = backward_params.drop_mask
|
|
137
|
+
pse = backward_params.pse
|
|
138
|
+
scale = backward_params.scale
|
|
139
|
+
keep_prob = backward_params.keep_prob
|
|
114
140
|
dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
|
|
115
141
|
if drop_mask is None or len(drop_mask.shape) == 0:
|
|
116
142
|
drop_res = softmax_res.permute(0, 1, 3, 2)
|
|
@@ -368,11 +394,18 @@ def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
|
|
|
368
394
|
return softmax_res
|
|
369
395
|
|
|
370
396
|
|
|
371
|
-
def rebuild_softmax_by_max_sum(
|
|
397
|
+
def rebuild_softmax_by_max_sum(softmax_params):
|
|
372
398
|
"""
|
|
373
399
|
attention = softmax(QK^T/sqrt(d))V
|
|
374
400
|
softmax(x_i) = e^(x_i - x_max_i) / x_sum_i)
|
|
375
401
|
"""
|
|
402
|
+
q = softmax_params.q
|
|
403
|
+
k = softmax_params.k
|
|
404
|
+
atten_mask = softmax_params.atten_mask
|
|
405
|
+
pse = softmax_params.pse
|
|
406
|
+
scale = softmax_params.scale
|
|
407
|
+
softmax_max = softmax_params.softmax_max
|
|
408
|
+
softmax_sum = softmax_params.softmax_sum
|
|
376
409
|
logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
|
|
377
410
|
qk = calculate_qk(q, k, atten_mask, pse, scale)
|
|
378
411
|
if softmax_max.shape[-1] == 0:
|
|
@@ -502,10 +535,8 @@ def npu_fusion_attention(*args, **kwargs):
|
|
|
502
535
|
key = convert_to_bnsd(key, n2, input_layout)
|
|
503
536
|
value = convert_to_bnsd(value, n2, input_layout)
|
|
504
537
|
k_new, v_new = generate_kv(key, value, n1, n2)
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
pse=pse, scale=scale,
|
|
508
|
-
keep_prob=keep_prob)
|
|
538
|
+
forward_params = FaForwardParams(query, k_new, v_new, None, atten_mask, pse, scale, keep_prob)
|
|
539
|
+
out_golden, softmax_max, softmax_sum = fusion_attention_forward(forward_params)
|
|
509
540
|
if out_golden.dim() == 5:
|
|
510
541
|
out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3),
|
|
511
542
|
out_golden.size(4))
|
|
@@ -546,9 +577,10 @@ def npu_fusion_attention_grad(*args, **kwargs):
|
|
|
546
577
|
if SOFTMAX_BUILD_MODE == "QKV":
|
|
547
578
|
softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
|
|
548
579
|
else:
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
580
|
+
softmax_params = RebuildSoftmaxParams(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
|
|
581
|
+
softmax_res = rebuild_softmax_by_max_sum(softmax_params)
|
|
582
|
+
backward_params = FaBackwardParams(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
|
|
583
|
+
dq, dk, dv = fusion_attention_backward(backward_params)
|
|
552
584
|
|
|
553
585
|
# N不等长适配by cdy
|
|
554
586
|
if not (n1 == n2):
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def npu_sort_v2(x, dim=-1, descending=False, out=None):
    """Golden implementation of ``npu_sort_v2``: return only the sorted values.

    Args:
        x: tensor to sort.
        dim: dimension to sort along (default: last).
        descending: sort from largest to smallest when True.
        out: accepted for signature compatibility with the NPU operator; it is
            not written to.

    Returns:
        A new tensor holding the values of ``x`` sorted along ``dim``.
    """
    sorted_result = torch.sort(x, dim=dim, descending=descending)
    return sorted_result.values
|
|
@@ -24,7 +24,8 @@ def parse_json_info_forward_backward(json_path):
|
|
|
24
24
|
real_data_path = dump_json.get("dump_data_dir")
|
|
25
25
|
dump_data = dump_json.get("data")
|
|
26
26
|
if dump_data is None:
|
|
27
|
-
raise ParseJsonException(ParseJsonException.InvalidDumpJson,
|
|
27
|
+
raise ParseJsonException(ParseJsonException.InvalidDumpJson,
|
|
28
|
+
"something wrong with dump, no data found in dump.json")
|
|
28
29
|
if not dump_data:
|
|
29
30
|
logger.warning("data field is empty, no overflow data found.")
|
|
30
31
|
|
msprobe/pytorch/common/utils.py
CHANGED
|
@@ -18,6 +18,7 @@ import os
|
|
|
18
18
|
import pickle
|
|
19
19
|
import random
|
|
20
20
|
import stat
|
|
21
|
+
import inspect
|
|
21
22
|
from functools import wraps
|
|
22
23
|
|
|
23
24
|
import numpy as np
|
|
@@ -105,8 +106,49 @@ def get_rank_if_initialized():
|
|
|
105
106
|
raise DistributedNotInitializedError("torch distributed environment is not initialized")
|
|
106
107
|
|
|
107
108
|
|
|
108
|
-
def
|
|
109
|
-
|
|
109
|
+
def remove_dropout():
|
|
110
|
+
if torch.__version__ > "1.8":
|
|
111
|
+
logger.info_on_rank_0("For precision comparison, the probability p in the dropout method is set to 0.")
|
|
112
|
+
import torch.nn.functional as F
|
|
113
|
+
from torch import _VF
|
|
114
|
+
from torch.overrides import has_torch_function_unary, handle_torch_function
|
|
115
|
+
|
|
116
|
+
def function_dropout(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
117
|
+
inplace: bool = False) -> torch.Tensor:
|
|
118
|
+
if has_torch_function_unary(input_tensor):
|
|
119
|
+
return handle_torch_function(
|
|
120
|
+
function_dropout, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
|
|
121
|
+
if p < 0.0 or p > 1.0:
|
|
122
|
+
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
123
|
+
return _VF.dropout_(input_tensor, 0., training) if inplace else _VF.dropout(input_tensor, 0., training)
|
|
124
|
+
|
|
125
|
+
def function_dropout2d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
126
|
+
inplace: bool = False) -> torch.Tensor:
|
|
127
|
+
if has_torch_function_unary(input_tensor):
|
|
128
|
+
return handle_torch_function(
|
|
129
|
+
function_dropout2d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
|
|
130
|
+
if p < 0.0 or p > 1.0:
|
|
131
|
+
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
132
|
+
return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
|
|
133
|
+
0., training)
|
|
134
|
+
|
|
135
|
+
def function_dropout3d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
136
|
+
inplace: bool = False) -> torch.Tensor:
|
|
137
|
+
if has_torch_function_unary(input_tensor):
|
|
138
|
+
return handle_torch_function(
|
|
139
|
+
function_dropout3d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
|
|
140
|
+
if p < 0.0 or p > 1.0:
|
|
141
|
+
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
142
|
+
return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
|
|
143
|
+
0., training)
|
|
144
|
+
|
|
145
|
+
F.dropout = function_dropout
|
|
146
|
+
F.dropout2d = function_dropout2d
|
|
147
|
+
F.dropout3d = function_dropout3d
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def seed_all(seed=1234, mode=False, rm_dropout=True):
|
|
151
|
+
check_seed_all(seed, mode, rm_dropout)
|
|
110
152
|
try:
|
|
111
153
|
random.seed(seed)
|
|
112
154
|
os.environ['PYTHONHASHSEED'] = str(seed)
|
|
@@ -126,6 +168,8 @@ def seed_all(seed=1234, mode=False):
|
|
|
126
168
|
else:
|
|
127
169
|
torch_npu.npu.manual_seed_all(seed)
|
|
128
170
|
torch_npu.npu.manual_seed(seed)
|
|
171
|
+
if rm_dropout:
|
|
172
|
+
remove_dropout()
|
|
129
173
|
except Exception as e:
|
|
130
174
|
logger.error(f"There is an unexpected error while determinating randomness. {e}")
|
|
131
175
|
|
|
@@ -359,3 +403,73 @@ def load_api_data(api_data_bytes):
|
|
|
359
403
|
except Exception as e:
|
|
360
404
|
raise RuntimeError(f"load api_data from bytes failed") from e
|
|
361
405
|
return buffer
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def is_recomputation():
    """Check if the current operation is in the re-computation phase.

    This function inspects the current call stack to indicate whether the current operation is in the
    re-computation phase. We use a blacklist mechanism, now supported megatron and mindspeed framework.
    megatron: The 'backward' function is called by the 'torch/autograd/function.py' file.
    mindspeed: The 'checkpoint_function_backward' function is called by the 'torch/autograd/function.py'
    file or the custom module(use CheckpointWithoutOutput) with the 'recompute_fn' function is executed within the
    'torch/utils/checkpoint.py' file.

    Returns:
        bool: True if in the re-computation phase, False otherwise.
    """
    # context=0: only each frame's function name and file name are needed, so
    # skip reading source-code context lines for every frame on the stack.
    call_stack = inspect.stack(0)
    try:
        for idx, frame_info in enumerate(call_stack):
            # Custom modules (CheckpointWithoutOutput) recompute through
            # 'recompute_fn' executed inside torch/utils/checkpoint.py.
            if frame_info.function == "recompute_fn" and frame_info.filename.endswith('torch/utils/checkpoint.py'):
                return True
            if frame_info.function not in (Const.BACKWARD, 'checkpoint_function_backward'):
                continue
            # Offset 1: Megatron and MindSpeed L0&L1 scenes.
            # Offset 2: the latest MindSpeed L2 and ModelLink scenes.
            for offset in (1, 2):
                caller_idx = idx + offset
                if caller_idx < len(call_stack) and \
                        call_stack[caller_idx].filename.endswith('torch/autograd/function.py'):
                    return True
        return False
    finally:
        # FrameInfo entries reference live frames; drop them promptly so they
        # do not keep frames (and their locals) alive via reference cycles.
        del call_stack
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def check_save_param(variable, name, save_backward):
    """Validate the arguments passed to ``PrecisionDebugger.save``.

    Args:
        variable: object to save; must be a list, dict, torch.Tensor, int,
            float or str.
        name: name under which the variable is saved; must be a str.
        save_backward: flag that must be a bool.

    Raises:
        ValueError: if any argument has an unsupported type. A warning is
            logged first and the same text is attached to the exception, so
            callers that catch it (to skip an invalid save call) can still
            report why.
    """
    if not isinstance(variable, (list, dict, torch.Tensor, int, float, str)):
        message = ("PrecisionDebugger.save variable type not valid, "
                   "should be one of list, dict, torch.Tensor, int, float or string. "
                   "Skip current save process.")
        logger.warning(message)
        raise ValueError(message)
    if not isinstance(name, str):
        message = ("PrecisionDebugger.save name not valid, "
                   "should be string. "
                   "skip current save process.")
        logger.warning(message)
        raise ValueError(message)
    if not isinstance(save_backward, bool):
        message = ("PrecisionDebugger.save_backward name not valid, "
                   "should be bool. "
                   "Skip current save process.")
        logger.warning(message)
        raise ValueError(message)
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
def replace_last_occurrence(text, old, new):
    """Return ``text`` with only its last occurrence of ``old`` replaced by ``new``.

    ``None`` is passed through unchanged, as is a string that does not
    contain ``old``. An empty ``old`` matches at the end, so ``new`` is appended.
    """
    if text is None:
        return text
    pos = text.rfind(old)
    if pos < 0:
        return text
    return text[:pos] + new + text[pos + len(old):]
|
|
@@ -14,52 +14,40 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
-
|
|
18
|
-
check_configuration_param, set_dump_path, get_dump_mode
|
|
19
|
-
from msprobe.core.common.file_utils import create_directory
|
|
17
|
+
|
|
20
18
|
from msprobe.core.common.exceptions import FileCheckException
|
|
19
|
+
from msprobe.core.common.file_utils import create_directory
|
|
20
|
+
from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
|
|
21
|
+
set_dump_path
|
|
22
|
+
from msprobe.core.compare.acc_compare import ModeConfig
|
|
23
|
+
from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json, set_stack_json_path
|
|
21
24
|
from msprobe.pytorch.common.log import logger
|
|
22
|
-
from msprobe.pytorch.compare.pt_compare import PTComparator
|
|
23
|
-
from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
|
|
25
|
+
from msprobe.pytorch.compare.pt_compare import PTComparator, compare
|
|
24
26
|
|
|
25
27
|
|
|
26
28
|
def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
|
|
27
|
-
if kwargs.get(
|
|
29
|
+
if kwargs.get("suffix"):
|
|
28
30
|
logger.error("Argument 'suffix' is not supported for compare_distributed.")
|
|
29
31
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
30
|
-
|
|
31
|
-
auto_analyze = kwargs.get('auto_analyze', True)
|
|
32
|
-
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
33
|
-
is_print_compare_log = kwargs.get('is_print_compare_log', True)
|
|
32
|
+
is_print_compare_log = kwargs.get("is_print_compare_log", True)
|
|
34
33
|
# get the ranks and match by order
|
|
35
34
|
npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
|
|
36
35
|
bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
|
|
37
36
|
if len(npu_ranks) != len(bench_ranks):
|
|
38
|
-
logger.error(
|
|
39
|
-
|
|
40
|
-
|
|
37
|
+
logger.error(
|
|
38
|
+
"The number of ranks in the two runs are different. "
|
|
39
|
+
"Unable to match the ranks. "
|
|
40
|
+
"Please use another folder to compare or use compare() api and manually match the ranks.")
|
|
41
41
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
42
42
|
for nr, br in zip(npu_ranks, bench_ranks):
|
|
43
43
|
npu_data_dir = os.path.join(npu_dump_dir, nr)
|
|
44
44
|
bench_data_dir = os.path.join(bench_dump_dir, br)
|
|
45
45
|
npu_path = extract_json(npu_data_dir, stack_json=False)
|
|
46
46
|
bench_path = extract_json(bench_data_dir, stack_json=False)
|
|
47
|
-
stack_path = extract_json(npu_data_dir, stack_json=True)
|
|
48
47
|
|
|
49
48
|
dump_result_param = {
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
'is_print_compare_log': is_print_compare_log
|
|
49
|
+
"npu_json_path": npu_path,
|
|
50
|
+
"bench_json_path": bench_path,
|
|
51
|
+
"is_print_compare_log": is_print_compare_log
|
|
54
52
|
}
|
|
55
|
-
|
|
56
|
-
set_dump_path(dump_result_param)
|
|
57
|
-
dump_mode = get_dump_mode(dump_result_param)
|
|
58
|
-
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, is_print_compare_log)
|
|
59
|
-
create_directory(output_path)
|
|
60
|
-
check_compare_param(dump_result_param, output_path, dump_mode)
|
|
61
|
-
except (CompareException, FileCheckException) as error:
|
|
62
|
-
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
63
|
-
raise CompareException(error.code) from error
|
|
64
|
-
pt_comparator = PTComparator()
|
|
65
|
-
pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', dump_mode=dump_mode, **kwargs)
|
|
53
|
+
compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs)
|
|
@@ -14,19 +14,29 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os.path
|
|
17
|
+
|
|
17
18
|
import torch
|
|
19
|
+
|
|
18
20
|
from msprobe.core.common.const import FileCheckConst
|
|
19
|
-
from msprobe.pytorch.common.log import logger
|
|
20
21
|
from msprobe.core.common.exceptions import FileCheckException
|
|
21
|
-
from msprobe.core.compare.acc_compare import Comparator
|
|
22
|
-
from msprobe.core.common.utils import check_configuration_param, check_compare_param, \
|
|
23
|
-
CompareException, set_dump_path, get_dump_mode
|
|
24
22
|
from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
|
|
23
|
+
from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
|
|
24
|
+
set_dump_path
|
|
25
|
+
from msprobe.core.compare.acc_compare import Comparator, ModeConfig
|
|
26
|
+
from msprobe.core.compare.utils import set_stack_json_path
|
|
27
|
+
from msprobe.pytorch.common.log import logger
|
|
25
28
|
from msprobe.pytorch.common.utils import load_pt
|
|
26
29
|
|
|
27
30
|
|
|
28
|
-
class PTComparator
|
|
29
|
-
def __init__(self, data_mapping=None):
|
|
31
|
+
class PTComparator(Comparator):
|
|
32
|
+
def __init__(self, mode_config, data_mapping=None):
|
|
33
|
+
super().__init__(mode_config)
|
|
34
|
+
|
|
35
|
+
self.stack_mode = mode_config.stack_mode
|
|
36
|
+
self.auto_analyze = mode_config.auto_analyze
|
|
37
|
+
self.fuzzy_match = mode_config.fuzzy_match
|
|
38
|
+
self.dump_mode = mode_config.dump_mode
|
|
39
|
+
|
|
30
40
|
self.frame_name = PTComparator.__name__
|
|
31
41
|
self.data_mapping = data_mapping
|
|
32
42
|
if isinstance(self.data_mapping, str) or self.data_mapping is None:
|
|
@@ -37,23 +47,24 @@ class PTComparator (Comparator):
|
|
|
37
47
|
raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
|
|
38
48
|
f"{type(self.data_mapping)}")
|
|
39
49
|
|
|
40
|
-
|
|
50
|
+
@staticmethod
|
|
51
|
+
def load_mapping_file(mapping_file):
|
|
41
52
|
if isinstance(mapping_file, str):
|
|
42
53
|
mapping_dict = load_yaml(mapping_file)
|
|
43
54
|
else:
|
|
44
55
|
mapping_dict = {}
|
|
45
56
|
return mapping_dict
|
|
46
|
-
|
|
57
|
+
|
|
47
58
|
def read_npy_data(self, dir_path, file_name):
|
|
48
59
|
if not file_name:
|
|
49
60
|
return None
|
|
50
61
|
data_path = os.path.join(dir_path, file_name)
|
|
51
62
|
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
52
|
-
|
|
63
|
+
FileCheckConst.PT_SUFFIX, False)
|
|
53
64
|
data_path = path_checker.common_check()
|
|
54
65
|
try:
|
|
55
|
-
|
|
56
|
-
|
|
66
|
+
# detach because numpy can not process gradient information
|
|
67
|
+
data_value = load_pt(data_path, to_cpu=True).detach()
|
|
57
68
|
except RuntimeError as e:
|
|
58
69
|
# 这里捕获 load_pt 中抛出的异常
|
|
59
70
|
logger.error(f"Failed to load the .pt file at {data_path}.")
|
|
@@ -65,20 +76,29 @@ class PTComparator (Comparator):
|
|
|
65
76
|
if data_value.dtype == torch.bfloat16:
|
|
66
77
|
data_value = data_value.to(torch.float32)
|
|
67
78
|
data_value = data_value.numpy()
|
|
68
|
-
return data_value
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
def compare(input_param, output_path,
|
|
79
|
+
return data_value
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def compare(input_param, output_path, **kwargs):
|
|
72
83
|
try:
|
|
84
|
+
auto_analyze = kwargs.get('auto_analyze', True)
|
|
85
|
+
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
86
|
+
data_mapping = kwargs.get('data_mapping', None)
|
|
87
|
+
suffix = kwargs.get('suffix', '')
|
|
88
|
+
|
|
73
89
|
set_dump_path(input_param)
|
|
74
90
|
dump_mode = get_dump_mode(input_param)
|
|
91
|
+
if "stack_json_path" in input_param:
|
|
92
|
+
stack_mode = kwargs.get('stack_mode', False)
|
|
93
|
+
else:
|
|
94
|
+
stack_mode = set_stack_json_path(input_param) # set stack_mode and set "stack_json_path" in input_param
|
|
75
95
|
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
|
|
76
96
|
create_directory(output_path)
|
|
77
|
-
check_compare_param(input_param, output_path, dump_mode)
|
|
78
|
-
data_mapping = kwargs.get('data_mapping', None)
|
|
97
|
+
check_compare_param(input_param, output_path, dump_mode, stack_mode)
|
|
79
98
|
except (CompareException, FileCheckException) as error:
|
|
80
99
|
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
81
100
|
raise CompareException(error.code) from error
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
101
|
+
|
|
102
|
+
mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode)
|
|
103
|
+
pt_comparator = PTComparator(mode_config, data_mapping)
|
|
104
|
+
pt_comparator.compare_core(input_param, output_path, suffix=suffix)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -26,7 +26,7 @@ class DebuggerConfig:
|
|
|
26
26
|
self.task = task or common_config.task or Const.STATISTICS
|
|
27
27
|
self.rank = common_config.rank if common_config.rank else []
|
|
28
28
|
self.step = common_config.step if common_config.step else []
|
|
29
|
-
self.level = level or common_config.level or
|
|
29
|
+
self.level = level or common_config.level or Const.LEVEL_L1
|
|
30
30
|
self.enable_dataloader = common_config.enable_dataloader
|
|
31
31
|
self.scope = task_config.scope if task_config.scope else []
|
|
32
32
|
self.list = task_config.list if task_config.list else []
|
|
@@ -34,10 +34,7 @@ class DebuggerConfig:
|
|
|
34
34
|
self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
|
|
35
35
|
self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
|
|
36
36
|
self.framework = Const.PT_FRAMEWORK
|
|
37
|
-
|
|
38
|
-
if self.level == Const.LEVEL_L2:
|
|
39
|
-
self.is_backward_kernel_dump = False
|
|
40
|
-
self._check_and_adjust_config_with_l2()
|
|
37
|
+
self.async_dump = common_config.async_dump if common_config.async_dump else False
|
|
41
38
|
|
|
42
39
|
if self.task == Const.FREE_BENCHMARK:
|
|
43
40
|
self.fuzz_device = task_config.fuzz_device
|
|
@@ -64,6 +61,10 @@ class DebuggerConfig:
|
|
|
64
61
|
|
|
65
62
|
self.check()
|
|
66
63
|
|
|
64
|
+
if self.level == Const.LEVEL_L2:
|
|
65
|
+
self.is_backward_kernel_dump = False
|
|
66
|
+
self._check_and_adjust_config_with_l2()
|
|
67
|
+
|
|
67
68
|
def check_kwargs(self):
|
|
68
69
|
if self.task and self.task not in Const.TASK_LIST:
|
|
69
70
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
@@ -74,29 +75,53 @@ class DebuggerConfig:
|
|
|
74
75
|
if not self.dump_path:
|
|
75
76
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
76
77
|
f"The dump_path not found.")
|
|
78
|
+
if not isinstance(self.async_dump, bool):
|
|
79
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
80
|
+
f"The parameters async_dump should be bool.")
|
|
81
|
+
if self.async_dump and self.task == Const.TENSOR and not self.list:
|
|
82
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
83
|
+
f"The parameters async_dump is true in tensor task, the parameters list cannot be "
|
|
84
|
+
f"empty.")
|
|
85
|
+
if self.task == Const.STRUCTURE and self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
|
|
86
|
+
logger.warning_on_rank_0(
|
|
87
|
+
f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
|
|
88
|
+
f"If not, the default level is {Const.LEVEL_MIX}."
|
|
89
|
+
)
|
|
90
|
+
self.level = Const.LEVEL_MIX
|
|
77
91
|
|
|
78
92
|
def check(self):
|
|
79
93
|
self.check_kwargs()
|
|
80
94
|
return True
|
|
81
95
|
|
|
82
96
|
def check_model(self, instance, start_model):
|
|
83
|
-
if self.level not in [
|
|
97
|
+
if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
|
|
84
98
|
if instance.model is not None or start_model is not None:
|
|
85
|
-
logger.
|
|
99
|
+
logger.info_on_rank_0(
|
|
86
100
|
f"The current level is not L0 or mix level, so the model parameters will not be used.")
|
|
87
101
|
return
|
|
88
|
-
if start_model is None:
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
102
|
+
if start_model is None and instance.model is None:
|
|
103
|
+
logger.error_on_rank_0(
|
|
104
|
+
f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' parameter.")
|
|
105
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
|
|
106
|
+
|
|
107
|
+
instance.model = start_model if start_model is not None else instance.model
|
|
108
|
+
if isinstance(instance.model, torch.nn.Module):
|
|
93
109
|
return
|
|
94
|
-
|
|
95
|
-
|
|
110
|
+
|
|
111
|
+
error_model = None
|
|
112
|
+
if isinstance(instance.model, (list, tuple)):
|
|
113
|
+
for model in instance.model:
|
|
114
|
+
if not isinstance(model, torch.nn.Module):
|
|
115
|
+
error_model = model
|
|
116
|
+
break
|
|
96
117
|
else:
|
|
97
|
-
|
|
118
|
+
error_model = instance.model
|
|
119
|
+
|
|
120
|
+
if error_model is not None:
|
|
121
|
+
error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
|
|
122
|
+
f"type, currently there is a {type(error_model)} type.")
|
|
98
123
|
raise MsprobeException(
|
|
99
|
-
MsprobeException.INVALID_PARAM_ERROR,
|
|
124
|
+
MsprobeException.INVALID_PARAM_ERROR, error_info)
|
|
100
125
|
|
|
101
126
|
def _check_and_adjust_config_with_l2(self):
|
|
102
127
|
if self.scope:
|