mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from msprobe.core.common.runtime import Runtime
|
|
18
|
+
from msprobe.core.common.utils import Const
|
|
19
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
|
|
20
|
+
from msprobe.pytorch.common.log import logger
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ATTLManager:
|
|
24
|
+
def __init__(self, config):
|
|
25
|
+
self.config = config
|
|
26
|
+
self.attl = None
|
|
27
|
+
|
|
28
|
+
def attl_init(self):
|
|
29
|
+
if self.config.online_run_ut:
|
|
30
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL
|
|
31
|
+
attl_config = ATTLConfig(is_benchmark_device=False,
|
|
32
|
+
connect_ip=self.config.host,
|
|
33
|
+
connect_port=self.config.port,
|
|
34
|
+
nfs_path=self.config.nfs_path,
|
|
35
|
+
tls_path=self.config.tls_path)
|
|
36
|
+
need_dump = len(self.config.rank) == 0 or Runtime.current_rank in self.config.rank
|
|
37
|
+
self.attl = ATTL('npu', attl_config, need_dump=need_dump)
|
|
38
|
+
if self.config.nfs_path:
|
|
39
|
+
self.attl.upload("start")
|
|
40
|
+
|
|
41
|
+
def attl_send(self, name, args, kwargs, output):
|
|
42
|
+
api_data = ApiData(
|
|
43
|
+
name[:-len(Const.FORWARD_NAME_SUFFIX)],
|
|
44
|
+
args,
|
|
45
|
+
kwargs,
|
|
46
|
+
output,
|
|
47
|
+
Runtime.current_iter,
|
|
48
|
+
Runtime.current_rank
|
|
49
|
+
)
|
|
50
|
+
logger.info(f"tools is dumping api: {api_data.name}, rank: {Runtime.current_rank}")
|
|
51
|
+
api_type, _, _ = api_data.name.split(Const.SEP)
|
|
52
|
+
if api_type in [Const.DISTRIBUTED]:
|
|
53
|
+
logger.info(f"api {api_data.name} is not supported, skip")
|
|
54
|
+
return
|
|
55
|
+
if self.config.nfs_path:
|
|
56
|
+
self.attl.upload(api_data)
|
|
57
|
+
else:
|
|
58
|
+
self.attl.send(api_data)
|
|
59
|
+
|
|
60
|
+
def attl_stop(self):
|
|
61
|
+
if self.config.nfs_path:
|
|
62
|
+
self.attl.upload("end")
|
|
63
|
+
elif self.attl.socket_manager is not None:
|
|
64
|
+
logger.info(f"pid: {os.getpid()} finished, start sends STOP signal.")
|
|
65
|
+
self.attl.socket_manager.send_stop_signal()
|
|
@@ -29,6 +29,8 @@ def softmax_func(x, axis=None):
|
|
|
29
29
|
|
|
30
30
|
def npu_moe_gating_top_k_softmax(x, finished_optional, k):
|
|
31
31
|
input_dtype = x.dtype
|
|
32
|
+
if x.dim() < 1:
|
|
33
|
+
raise ValueError("Input x must have at least 1 dimensions.")
|
|
32
34
|
num_expert = x.shape[-1]
|
|
33
35
|
softmax = softmax_func(x, -1)
|
|
34
36
|
softmax = softmax.to(input_dtype)
|
|
@@ -36,9 +38,13 @@ def npu_moe_gating_top_k_softmax(x, finished_optional, k):
|
|
|
36
38
|
expert_idx = expert_idx[:, :k]
|
|
37
39
|
y = torch.gather(softmax, index=expert_idx, dim=-1)
|
|
38
40
|
if finished_optional is not None:
|
|
41
|
+
if finished_optional.dim() < 1:
|
|
42
|
+
raise ValueError("Finished_optional must have at least 1 dimensions.")
|
|
39
43
|
finished_optional = finished_optional.view(finished_optional.shape[0], 1)
|
|
40
44
|
finished_optional = finished_optional.expand(-1, k)
|
|
41
45
|
expert_idx = torch.where(finished_optional, num_expert, expert_idx)
|
|
46
|
+
if y.dim() < 2:
|
|
47
|
+
raise ValueError("Variable y must have at least 2 dimensions.")
|
|
42
48
|
row_idx = torch.arange(y.shape[0] * y.shape[1]).reshape(y.shape[1], y.shape[0]).t()
|
|
43
49
|
|
|
44
50
|
return y, expert_idx, row_idx
|
|
@@ -117,6 +117,12 @@ def fusion_attention_forward(forward_params):
|
|
|
117
117
|
pse = forward_params.pse
|
|
118
118
|
scale = forward_params.scale
|
|
119
119
|
keep_prob = forward_params.keep_prob
|
|
120
|
+
|
|
121
|
+
# 除零风险拦截:keep_prob 为 0 时会导致除零错误
|
|
122
|
+
if keep_prob == 0:
|
|
123
|
+
raise ValueError("fusion_attention_forward: keep_prob cannot be zero to avoid division by zero.")
|
|
124
|
+
|
|
125
|
+
|
|
120
126
|
qk = calculate_qk(q, k, atten_mask, pse, scale)
|
|
121
127
|
softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
|
|
122
128
|
if drop_mask is None or len(drop_mask.shape) == 0:
|
|
@@ -137,6 +143,11 @@ def fusion_attention_backward(backward_params):
|
|
|
137
143
|
pse = backward_params.pse
|
|
138
144
|
scale = backward_params.scale
|
|
139
145
|
keep_prob = backward_params.keep_prob
|
|
146
|
+
|
|
147
|
+
# 除零风险拦截:keep_prob 为 0 时会导致除零错误
|
|
148
|
+
if keep_prob == 0:
|
|
149
|
+
raise ValueError("fusion_attention_backward: keep_prob cannot be zero to avoid division by zero.")
|
|
150
|
+
|
|
140
151
|
dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
|
|
141
152
|
if drop_mask is None or len(drop_mask.shape) == 0:
|
|
142
153
|
drop_res = softmax_res.permute(0, 1, 3, 2)
|
|
@@ -164,23 +175,35 @@ def parse_bsnd_args(query, key, head_num, input_layout):
|
|
|
164
175
|
if input_layout == "BSH":
|
|
165
176
|
b, s1, h1 = query.shape
|
|
166
177
|
_, s2, h2 = key.shape
|
|
178
|
+
if n1 == 0:
|
|
179
|
+
raise ValueError("parse_bsnd_args: head_num (n1) cannot be zero to avoid division by zero.")
|
|
167
180
|
d = h1 // n1
|
|
181
|
+
if d == 0:
|
|
182
|
+
raise ValueError("parse_bsnd_args: computed head dimension (d) is zero, division by zero risk.")
|
|
168
183
|
n2 = h2 // d
|
|
169
184
|
elif input_layout == "SBH":
|
|
170
185
|
s1, b, h1 = query.shape
|
|
171
186
|
s2, _, h2 = key.shape
|
|
187
|
+
if n1 == 0:
|
|
188
|
+
raise ValueError("parse_bsnd_args: head_num (n1) cannot be zero to avoid division by zero.")
|
|
172
189
|
d = h1 // n1
|
|
190
|
+
if d == 0:
|
|
191
|
+
raise ValueError("parse_bsnd_args: computed head dimension (d) is zero, division by zero risk.")
|
|
173
192
|
n2 = h2 // d
|
|
174
193
|
elif input_layout == "BSND":
|
|
175
194
|
b, s1, n1, d = query.shape
|
|
176
195
|
_, s2, n2, _ = key.shape
|
|
177
196
|
h1 = n1 * d
|
|
178
197
|
h2 = n2 * d
|
|
198
|
+
if d == 0:
|
|
199
|
+
raise ValueError("parse_bsnd_args: head dimension (d) is zero, division by zero risk.")
|
|
179
200
|
elif input_layout == "BNSD":
|
|
180
201
|
b, n1, s1, d = query.shape
|
|
181
202
|
_, n2, s2, _ = key.shape
|
|
182
203
|
h1 = n1 * d
|
|
183
204
|
h2 = n2 * d
|
|
205
|
+
if d == 0:
|
|
206
|
+
raise ValueError("parse_bsnd_args: head dimension (d) is zero, division by zero risk.")
|
|
184
207
|
except Exception as e:
|
|
185
208
|
raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
|
|
186
209
|
|
|
@@ -446,6 +469,8 @@ def npu_fusion_attention_forward_patch(*args, **kwargs):
|
|
|
446
469
|
input_layout = get_input_layout(*args, **kwargs)
|
|
447
470
|
|
|
448
471
|
b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout)
|
|
472
|
+
if d == 0:
|
|
473
|
+
raise ValueError("npu_fusion_attention_forward_patch: head dimension (d) is zero, division by zero risk.")
|
|
449
474
|
if n1 == n2 and s1 == s2:
|
|
450
475
|
logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
|
|
451
476
|
else:
|
|
@@ -478,6 +503,8 @@ def npu_fusion_attention_backward_patch(*args, **kwargs):
|
|
|
478
503
|
raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
|
|
479
504
|
|
|
480
505
|
b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
|
|
506
|
+
if d == 0:
|
|
507
|
+
raise ValueError("npu_fusion_attention_backward_patch: head dimension (d) is zero, division by zero risk.")
|
|
481
508
|
if n1 == n2 and s1 == s2:
|
|
482
509
|
logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
|
|
483
510
|
else:
|
msprobe/pytorch/common/utils.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -24,11 +24,12 @@ from functools import wraps
|
|
|
24
24
|
import numpy as np
|
|
25
25
|
import torch
|
|
26
26
|
import torch.distributed as dist
|
|
27
|
+
|
|
27
28
|
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
28
29
|
from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
|
|
29
30
|
check_file_or_directory_path, check_path_before_create, FileOpen)
|
|
30
31
|
from msprobe.core.common.log import logger
|
|
31
|
-
from msprobe.core.common.utils import check_seed_all
|
|
32
|
+
from msprobe.core.common.utils import check_seed_all, is_save_variable_valid
|
|
32
33
|
from packaging import version
|
|
33
34
|
|
|
34
35
|
try:
|
|
@@ -38,7 +39,9 @@ except ImportError:
|
|
|
38
39
|
else:
|
|
39
40
|
is_gpu = False
|
|
40
41
|
|
|
42
|
+
|
|
41
43
|
torch_without_guard_version = torch.__version__ >= '2.1'
|
|
44
|
+
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
42
45
|
|
|
43
46
|
if not is_gpu and not torch_without_guard_version:
|
|
44
47
|
from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard
|
|
@@ -57,7 +60,7 @@ def parameter_adapter(func):
|
|
|
57
60
|
|
|
58
61
|
@wraps(func)
|
|
59
62
|
def inner(self, *args, **kwargs):
|
|
60
|
-
if self.
|
|
63
|
+
if self.api_name == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor):
|
|
61
64
|
input_tensor = args[0]
|
|
62
65
|
indices = args[1]
|
|
63
66
|
if indices.dtype == torch.uint8:
|
|
@@ -77,7 +80,7 @@ def parameter_adapter(func):
|
|
|
77
80
|
else:
|
|
78
81
|
res = [input_tensor[tensor_index] for tensor_index in indices]
|
|
79
82
|
return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0)
|
|
80
|
-
if self.
|
|
83
|
+
if self.api_name == "__eq__" and len(args) > 1 and args[1] is None:
|
|
81
84
|
return False
|
|
82
85
|
return func(self, *args, **kwargs)
|
|
83
86
|
|
|
@@ -261,6 +264,10 @@ class Const:
|
|
|
261
264
|
NPU = 'NPU'
|
|
262
265
|
DISTRIBUTED = 'Distributed'
|
|
263
266
|
|
|
267
|
+
HIFLOAT8_TYPE = "torch_npu.HiFloat8Tensor"
|
|
268
|
+
FLOAT8_E5M2_TYPE = "torch.float8_e5m2"
|
|
269
|
+
FLOAT8_E4M3FN_TYPE = "torch.float8_e4m3fn"
|
|
270
|
+
|
|
264
271
|
RAISE_PRECISION = {
|
|
265
272
|
torch.float16: torch.float32,
|
|
266
273
|
torch.bfloat16: torch.float32,
|
|
@@ -309,14 +316,14 @@ def print_rank_0(message):
|
|
|
309
316
|
logger.info(message)
|
|
310
317
|
|
|
311
318
|
|
|
312
|
-
def load_pt(pt_path, to_cpu=False):
|
|
319
|
+
def load_pt(pt_path, to_cpu=False, weights_only=True):
|
|
313
320
|
pt_path = os.path.realpath(pt_path)
|
|
314
321
|
check_file_or_directory_path(pt_path)
|
|
315
322
|
try:
|
|
316
323
|
if to_cpu:
|
|
317
|
-
pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=
|
|
324
|
+
pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=weights_only)
|
|
318
325
|
else:
|
|
319
|
-
pt = torch.load(pt_path, weights_only=
|
|
326
|
+
pt = torch.load(pt_path, weights_only=weights_only)
|
|
320
327
|
except Exception as e:
|
|
321
328
|
raise RuntimeError(f"load pt file {pt_path} failed") from e
|
|
322
329
|
return pt
|
|
@@ -391,7 +398,7 @@ def save_api_data(api_data):
|
|
|
391
398
|
io_buff = io.BytesIO()
|
|
392
399
|
torch.save(api_data, io_buff)
|
|
393
400
|
except Exception as e:
|
|
394
|
-
raise RuntimeError(
|
|
401
|
+
raise RuntimeError("save api_data to io_buff failed") from e
|
|
395
402
|
return io_buff
|
|
396
403
|
|
|
397
404
|
|
|
@@ -401,7 +408,7 @@ def load_api_data(api_data_bytes):
|
|
|
401
408
|
buffer = io.BytesIO(api_data_bytes)
|
|
402
409
|
buffer = torch.load(buffer, map_location="cpu")
|
|
403
410
|
except Exception as e:
|
|
404
|
-
raise RuntimeError(
|
|
411
|
+
raise RuntimeError("load api_data from bytes failed") from e
|
|
405
412
|
return buffer
|
|
406
413
|
|
|
407
414
|
|
|
@@ -419,7 +426,11 @@ def is_recomputation():
|
|
|
419
426
|
bool: True if in the re-computation phase, False otherwise.
|
|
420
427
|
"""
|
|
421
428
|
backward_function_indices = []
|
|
422
|
-
|
|
429
|
+
try:
|
|
430
|
+
call_stack = inspect.stack()
|
|
431
|
+
except Exception as e:
|
|
432
|
+
logger.warning(f"Failed to capture stack trace, recomputation validation may be incorrect, error info: {e}.")
|
|
433
|
+
return False
|
|
423
434
|
|
|
424
435
|
# Identify the function 'backward' is being executed within the 'torch/_tensor.py' file.
|
|
425
436
|
for frame_info in call_stack:
|
|
@@ -449,9 +460,11 @@ def is_recomputation():
|
|
|
449
460
|
|
|
450
461
|
def check_save_param(variable, name, save_backward):
|
|
451
462
|
# try catch this api to skip invalid call
|
|
452
|
-
|
|
463
|
+
valid_data_types = (torch.Tensor, int, float, str)
|
|
464
|
+
if not is_save_variable_valid(variable, valid_data_types):
|
|
465
|
+
valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
|
|
453
466
|
logger.warning("PrecisionDebugger.save variable type not valid, "
|
|
454
|
-
"should be one of
|
|
467
|
+
f"should be one of {valid_data_types_with_nested_types}"
|
|
455
468
|
"Skip current save process.")
|
|
456
469
|
raise ValueError
|
|
457
470
|
if not isinstance(name, str):
|
|
@@ -466,10 +479,31 @@ def check_save_param(variable, name, save_backward):
|
|
|
466
479
|
raise ValueError
|
|
467
480
|
|
|
468
481
|
|
|
469
|
-
def
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
482
|
+
def is_torch_nn_module(variable):
|
|
483
|
+
return isinstance(variable, torch.nn.Module) and not isinstance(variable, torch.jit.ScriptModule)
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
def is_hifloat8_tensor(tensor):
|
|
487
|
+
if not is_gpu and hasattr(torch_npu, "HiFloat8Tensor") and isinstance(tensor, torch_npu.HiFloat8Tensor):
|
|
488
|
+
return True
|
|
489
|
+
return False
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
def is_float8_tensor(tensor):
|
|
493
|
+
if str(tensor.dtype) in [Const.FLOAT8_E5M2_TYPE, Const.FLOAT8_E4M3FN_TYPE]:
|
|
494
|
+
return True
|
|
495
|
+
return is_hifloat8_tensor(tensor)
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
def register_forward_pre_hook(module, forward_pre_hook):
|
|
499
|
+
if torch_version_above_or_equal_2:
|
|
500
|
+
module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
|
|
501
|
+
else:
|
|
502
|
+
module.register_forward_pre_hook(forward_pre_hook)
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
def register_forward_hook(module, forward_hook):
|
|
506
|
+
if torch_version_above_or_equal_2:
|
|
507
|
+
module.register_forward_hook(forward_hook, with_kwargs=True)
|
|
508
|
+
else:
|
|
509
|
+
module.register_forward_hook(forward_hook)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c)
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,41 +13,9 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import
|
|
17
|
-
|
|
18
|
-
from msprobe.core.common.exceptions import FileCheckException
|
|
19
|
-
from msprobe.core.common.file_utils import create_directory
|
|
20
|
-
from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
|
|
21
|
-
set_dump_path
|
|
22
|
-
from msprobe.core.compare.acc_compare import ModeConfig
|
|
23
|
-
from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json, set_stack_json_path
|
|
24
|
-
from msprobe.pytorch.common.log import logger
|
|
25
|
-
from msprobe.pytorch.compare.pt_compare import PTComparator, compare
|
|
16
|
+
from msprobe.core.compare.utils import compare_distributed_inner
|
|
17
|
+
from msprobe.pytorch.compare.pt_compare import compare
|
|
26
18
|
|
|
27
19
|
|
|
28
20
|
def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
|
|
29
|
-
|
|
30
|
-
logger.error("Argument 'suffix' is not supported for compare_distributed.")
|
|
31
|
-
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
32
|
-
is_print_compare_log = kwargs.get("is_print_compare_log", True)
|
|
33
|
-
# get the ranks and match by order
|
|
34
|
-
npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
|
|
35
|
-
bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
|
|
36
|
-
if len(npu_ranks) != len(bench_ranks):
|
|
37
|
-
logger.error(
|
|
38
|
-
"The number of ranks in the two runs are different. "
|
|
39
|
-
"Unable to match the ranks. "
|
|
40
|
-
"Please use another folder to compare or use compare() api and manually match the ranks.")
|
|
41
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
42
|
-
for nr, br in zip(npu_ranks, bench_ranks):
|
|
43
|
-
npu_data_dir = os.path.join(npu_dump_dir, nr)
|
|
44
|
-
bench_data_dir = os.path.join(bench_dump_dir, br)
|
|
45
|
-
npu_path = extract_json(npu_data_dir, stack_json=False)
|
|
46
|
-
bench_path = extract_json(bench_data_dir, stack_json=False)
|
|
47
|
-
|
|
48
|
-
dump_result_param = {
|
|
49
|
-
"npu_json_path": npu_path,
|
|
50
|
-
"bench_json_path": bench_path,
|
|
51
|
-
"is_print_compare_log": is_print_compare_log
|
|
52
|
-
}
|
|
53
|
-
compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs)
|
|
21
|
+
compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, compare, **kwargs)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,92 +13,21 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import
|
|
16
|
+
from msprobe.core.compare.acc_compare import Comparator, ModeConfig, MappingConfig, setup_comparison
|
|
17
|
+
from msprobe.pytorch.compare.utils import read_pt_data
|
|
17
18
|
|
|
18
|
-
import torch
|
|
19
19
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
set_dump_path
|
|
25
|
-
from msprobe.core.compare.acc_compare import Comparator, ModeConfig
|
|
26
|
-
from msprobe.core.compare.utils import set_stack_json_path
|
|
27
|
-
from msprobe.pytorch.common.log import logger
|
|
28
|
-
from msprobe.pytorch.common.utils import load_pt
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
class PTComparator(Comparator):
|
|
32
|
-
def __init__(self, mode_config, data_mapping=None):
|
|
33
|
-
super().__init__(mode_config)
|
|
34
|
-
|
|
35
|
-
self.stack_mode = mode_config.stack_mode
|
|
36
|
-
self.auto_analyze = mode_config.auto_analyze
|
|
37
|
-
self.fuzzy_match = mode_config.fuzzy_match
|
|
38
|
-
self.dump_mode = mode_config.dump_mode
|
|
39
|
-
|
|
40
|
-
self.frame_name = PTComparator.__name__
|
|
41
|
-
self.data_mapping = data_mapping
|
|
42
|
-
if isinstance(self.data_mapping, str) or self.data_mapping is None:
|
|
43
|
-
self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
|
|
44
|
-
elif isinstance(self.data_mapping, dict):
|
|
45
|
-
self.data_mapping_dict = self.data_mapping
|
|
46
|
-
else:
|
|
47
|
-
raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
|
|
48
|
-
f"{type(self.data_mapping)}")
|
|
49
|
-
|
|
50
|
-
@staticmethod
|
|
51
|
-
def load_mapping_file(mapping_file):
|
|
52
|
-
if isinstance(mapping_file, str):
|
|
53
|
-
mapping_dict = load_yaml(mapping_file)
|
|
54
|
-
else:
|
|
55
|
-
mapping_dict = {}
|
|
56
|
-
return mapping_dict
|
|
57
|
-
|
|
58
|
-
def read_npy_data(self, dir_path, file_name):
|
|
59
|
-
if not file_name:
|
|
60
|
-
return None
|
|
61
|
-
data_path = os.path.join(dir_path, file_name)
|
|
62
|
-
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
63
|
-
FileCheckConst.PT_SUFFIX, False)
|
|
64
|
-
data_path = path_checker.common_check()
|
|
65
|
-
try:
|
|
66
|
-
# detach because numpy can not process gradient information
|
|
67
|
-
data_value = load_pt(data_path, to_cpu=True).detach()
|
|
68
|
-
except RuntimeError as e:
|
|
69
|
-
# 这里捕获 load_pt 中抛出的异常
|
|
70
|
-
logger.error(f"Failed to load the .pt file at {data_path}.")
|
|
71
|
-
raise CompareException(CompareException.INVALID_FILE_ERROR) from e
|
|
72
|
-
except AttributeError as e:
|
|
73
|
-
# 这里捕获 detach 方法抛出的异常
|
|
74
|
-
logger.error(f"Failed to detach the loaded tensor.")
|
|
75
|
-
raise CompareException(CompareException.DETACH_ERROR) from e
|
|
76
|
-
if data_value.dtype == torch.bfloat16:
|
|
77
|
-
data_value = data_value.to(torch.float32)
|
|
78
|
-
data_value = data_value.numpy()
|
|
79
|
-
return data_value
|
|
20
|
+
def read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name, _) -> tuple:
|
|
21
|
+
n_value = read_pt_data(npu_dir, npu_data_name)
|
|
22
|
+
b_value = read_pt_data(bench_dir, bench_data_name)
|
|
23
|
+
return n_value, b_value
|
|
80
24
|
|
|
81
25
|
|
|
82
26
|
def compare(input_param, output_path, **kwargs):
|
|
83
|
-
|
|
84
|
-
auto_analyze = kwargs.get('auto_analyze', True)
|
|
85
|
-
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
86
|
-
data_mapping = kwargs.get('data_mapping', None)
|
|
87
|
-
suffix = kwargs.get('suffix', '')
|
|
88
|
-
|
|
89
|
-
set_dump_path(input_param)
|
|
90
|
-
dump_mode = get_dump_mode(input_param)
|
|
91
|
-
if "stack_json_path" in input_param:
|
|
92
|
-
stack_mode = kwargs.get('stack_mode', False)
|
|
93
|
-
else:
|
|
94
|
-
stack_mode = set_stack_json_path(input_param) # set stack_mode and set "stack_json_path" in input_param
|
|
95
|
-
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
|
|
96
|
-
create_directory(output_path)
|
|
97
|
-
check_compare_param(input_param, output_path, dump_mode, stack_mode)
|
|
98
|
-
except (CompareException, FileCheckException) as error:
|
|
99
|
-
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
100
|
-
raise CompareException(error.code) from error
|
|
27
|
+
config = setup_comparison(input_param, output_path, **kwargs)
|
|
101
28
|
|
|
102
|
-
mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match,
|
|
103
|
-
|
|
104
|
-
|
|
29
|
+
mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match,
|
|
30
|
+
config.dump_mode, config.compared_file_type)
|
|
31
|
+
mapping_config = MappingConfig(data_mapping=config.data_mapping)
|
|
32
|
+
pt_comparator = Comparator(read_real_data, mode_config, mapping_config)
|
|
33
|
+
pt_comparator.compare_core(input_param, output_path, suffix=config.suffix)
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
|
|
20
|
+
from msprobe.core.common.utils import logger, CompareException
|
|
21
|
+
from msprobe.core.common.file_utils import FileChecker, FileCheckConst
|
|
22
|
+
from msprobe.pytorch.common.utils import load_pt
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def read_pt_data(dir_path, file_name):
|
|
26
|
+
if not file_name:
|
|
27
|
+
return None
|
|
28
|
+
|
|
29
|
+
data_path = os.path.join(dir_path, file_name)
|
|
30
|
+
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
31
|
+
FileCheckConst.PT_SUFFIX, False)
|
|
32
|
+
data_path = path_checker.common_check()
|
|
33
|
+
try:
|
|
34
|
+
# detach because numpy can not process gradient information
|
|
35
|
+
data_value = load_pt(data_path, to_cpu=True).detach()
|
|
36
|
+
except RuntimeError as e:
|
|
37
|
+
# 这里捕获 load_pt 中抛出的异常
|
|
38
|
+
logger.error(f"Failed to load the .pt file at {data_path}.")
|
|
39
|
+
raise CompareException(CompareException.INVALID_FILE_ERROR) from e
|
|
40
|
+
except AttributeError as e:
|
|
41
|
+
# 这里捕获 detach 方法抛出的异常
|
|
42
|
+
logger.error(f"Failed to detach the loaded tensor.")
|
|
43
|
+
raise CompareException(CompareException.DETACH_ERROR) from e
|
|
44
|
+
if data_value.dtype == torch.bfloat16:
|
|
45
|
+
data_value = data_value.to(torch.float32)
|
|
46
|
+
data_value = data_value.numpy()
|
|
47
|
+
return data_value
|
|
@@ -13,11 +13,10 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import torch
|
|
17
|
-
|
|
18
16
|
from msprobe.core.common.const import Const
|
|
19
17
|
from msprobe.core.common.exceptions import MsprobeException
|
|
20
18
|
from msprobe.pytorch.common.log import logger
|
|
19
|
+
from msprobe.pytorch.common.utils import is_torch_nn_module
|
|
21
20
|
|
|
22
21
|
|
|
23
22
|
class DebuggerConfig:
|
|
@@ -60,6 +59,7 @@ class DebuggerConfig:
|
|
|
60
59
|
if isinstance(task_config.online_run_ut_recompute, bool) else False
|
|
61
60
|
|
|
62
61
|
self.check()
|
|
62
|
+
self._check_statistics_config(task_config)
|
|
63
63
|
|
|
64
64
|
if self.level == Const.LEVEL_L2:
|
|
65
65
|
self.is_backward_kernel_dump = False
|
|
@@ -78,10 +78,13 @@ class DebuggerConfig:
|
|
|
78
78
|
if not isinstance(self.async_dump, bool):
|
|
79
79
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
80
80
|
f"The parameters async_dump should be bool.")
|
|
81
|
-
if self.async_dump and self.task == Const.TENSOR
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
81
|
+
if self.async_dump and self.task == Const.TENSOR:
|
|
82
|
+
if self.level == Const.LEVEL_DEBUG:
|
|
83
|
+
self.list = [] # async_dump + debug level case ignore list
|
|
84
|
+
if not self.list and self.level != Const.LEVEL_DEBUG:
|
|
85
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
86
|
+
f"The parameters async_dump is true in tensor task, the parameters list cannot be "
|
|
87
|
+
f"empty.")
|
|
85
88
|
if self.task == Const.STRUCTURE and self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
|
|
86
89
|
logger.warning_on_rank_0(
|
|
87
90
|
f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
|
|
@@ -93,25 +96,24 @@ class DebuggerConfig:
|
|
|
93
96
|
self.check_kwargs()
|
|
94
97
|
return True
|
|
95
98
|
|
|
96
|
-
def check_model(self, instance, start_model):
|
|
97
|
-
if
|
|
98
|
-
|
|
99
|
-
logger.info_on_rank_0(
|
|
100
|
-
f"The current level is not L0 or mix level, so the model parameters will not be used.")
|
|
99
|
+
def check_model(self, instance, start_model, token_range=None):
|
|
100
|
+
instance.model = start_model if start_model is not None else instance.model
|
|
101
|
+
if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX] and token_range is None:
|
|
101
102
|
return
|
|
102
|
-
|
|
103
|
+
|
|
104
|
+
if instance.model is None:
|
|
103
105
|
logger.error_on_rank_0(
|
|
104
|
-
f"For level {self.level}
|
|
106
|
+
f"For level {self.level} or non-empty token_range, "
|
|
107
|
+
f"PrecisionDebugger or start interface must receive a 'model' parameter.")
|
|
105
108
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
|
|
106
109
|
|
|
107
|
-
|
|
108
|
-
if isinstance(instance.model, torch.nn.Module):
|
|
110
|
+
if is_torch_nn_module(instance.model):
|
|
109
111
|
return
|
|
110
112
|
|
|
111
113
|
error_model = None
|
|
112
114
|
if isinstance(instance.model, (list, tuple)):
|
|
113
115
|
for model in instance.model:
|
|
114
|
-
if not
|
|
116
|
+
if not is_torch_nn_module(model):
|
|
115
117
|
error_model = model
|
|
116
118
|
break
|
|
117
119
|
else:
|
|
@@ -119,7 +121,7 @@ class DebuggerConfig:
|
|
|
119
121
|
|
|
120
122
|
if error_model is not None:
|
|
121
123
|
error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
|
|
122
|
-
f"type, currently there is
|
|
124
|
+
f"type, currently there is an unsupported {type(error_model)} type.")
|
|
123
125
|
raise MsprobeException(
|
|
124
126
|
MsprobeException.INVALID_PARAM_ERROR, error_info)
|
|
125
127
|
|
|
@@ -130,8 +132,23 @@ class DebuggerConfig:
|
|
|
130
132
|
if not self.list or len(self.list) != 1:
|
|
131
133
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
132
134
|
f"When level is set to L2, the list must be configured as a list with one api name.")
|
|
135
|
+
if self.task != Const.TENSOR:
|
|
136
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
137
|
+
f"When level is set to L2, the task must be set to tensor.")
|
|
138
|
+
|
|
133
139
|
api_name = self.list[0]
|
|
134
140
|
if api_name.endswith(Const.BACKWARD):
|
|
135
141
|
self.is_backward_kernel_dump = True
|
|
136
142
|
api_forward_name = api_name[:-len(Const.BACKWARD)] + Const.FORWARD
|
|
137
143
|
self.list.append(api_forward_name)
|
|
144
|
+
|
|
145
|
+
def _check_statistics_config(self, task_config):
|
|
146
|
+
if self.task != Const.STATISTICS:
|
|
147
|
+
return
|
|
148
|
+
self.tensor_list = []
|
|
149
|
+
if not hasattr(task_config, "tensor_list"):
|
|
150
|
+
return
|
|
151
|
+
if self.level == Const.LEVEL_DEBUG and task_config.tensor_list:
|
|
152
|
+
logger.warning_on_rank_0("When level is set to debug, the tensor_list will be invalid.")
|
|
153
|
+
return
|
|
154
|
+
self.tensor_list = task_config.tensor_list
|