mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in the supported public registries. It is provided for informational purposes only.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
The per-file diffs below, reconstructed from the rendered side-by-side view, cover the msprobe/mindspore/dump/hook_cell package. Removed lines whose content the rendering did not preserve are marked […].

--- a/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -1,4 +1,5 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ============================================================================
 
 from mindspore import Tensor, ops, mint
 from mindspore.mint.nn import functional
@@ -20,8 +20,21 @@ from mindspore.communication import comm_func
 
 from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP,
                                                        HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP,
-[… removed line not preserved in this rendering …]
+                                                       HOOKTorchOP, HOOKTorchTensor, HOOKTorchFunctionalOP,
+                                                       HOOKTorchDistributedOP, HOOKTorchNpuOP,
+                                                       get_wrap_api_list, get_wrap_torch_api_list, setup_hooks)
 from msprobe.core.common.utils import Const
+from msprobe.mindspore.common.utils import is_mindtorch
+
+if is_mindtorch():
+    import torch
+    import torch_npu
+
+
+def stub_method(method):
+    def wrapped_method(*args, **kwargs):
+        return method(*args, **kwargs)
+    return wrapped_method
 
 
 class ApiRegistry:
@@ -34,6 +47,12 @@ class ApiRegistry:
         self.distributed_ori_attr = {}
         self.norm_inner_ops_ori_attr = {}
 
+        self.torch_ori_attr = {}
+        self.torch_tensor_ori_attr = {}
+        self.torch_functional_ori_attr = {}
+        self.torch_distributed_ori_attr = {}
+        self.torch_npu_ori_attr = {}
+
         self.tensor_hook_attr = {}
         self.stub_tensor_hook_attr = {}
         self.functional_hook_attr = {}
@@ -42,6 +61,12 @@ class ApiRegistry:
         self.distibuted_hook_attr = {}
         self.norm_inner_ops_hook_attr = {}
 
+        self.torch_hook_attr = {}
+        self.torch_tensor_hook_attr = {}
+        self.torch_functional_hook_attr = {}
+        self.torch_distributed_hook_attr = {}
+        self.torch_npu_hook_attr = {}
+
         self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
 
     @staticmethod
@@ -50,9 +75,13 @@ class ApiRegistry:
             if Const.SEP in api:
                 sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
                 sub_module = getattr(ori_api_group, sub_module_name)
-[… removed line not preserved in this rendering …]
+                ori_api_func = getattr(sub_module, sub_op)
             else:
-[… removed line not preserved in this rendering …]
+                ori_api_func = getattr(ori_api_group, api)
+            if ori_api_group == StubTensor:
+                api_ori_attr[api] = stub_method(ori_api_func)
+                continue
+            api_ori_attr[api] = ori_api_func
 
     @staticmethod
     def set_api_attr(api_group, attr_dict):
@@ -72,22 +101,71 @@ class ApiRegistry:
         self.set_api_attr(ops, self.norm_inner_ops_ori_attr)
 
     def api_set_hook_func(self):
-[… 6 removed lines not preserved in this rendering …]
+        if is_mindtorch():
+            self.set_api_attr(torch, self.torch_hook_attr)
+            self.set_api_attr(torch.Tensor, self.torch_tensor_hook_attr)
+            self.set_api_attr(torch.nn.functional, self.torch_functional_hook_attr)
+            self.set_api_attr(torch.distributed, self.torch_distributed_hook_attr)
+            self.set_api_attr(torch_npu, self.torch_npu_hook_attr)
+        else:
+            self.set_api_attr(Tensor, self.tensor_hook_attr)
+            self.set_api_attr(StubTensor, self.stub_tensor_hook_attr)
+            self.set_api_attr(ops, self.functional_hook_attr)
+            self.set_api_attr(mint, self.mint_ops_hook_attr)
+            self.set_api_attr(functional, self.mint_func_ops_hook_attr)
+            self.set_api_attr(comm_func, self.distibuted_hook_attr)
 
     def api_set_ori_func(self):
-[… 6 removed lines not preserved in this rendering …]
+        if is_mindtorch():
+            self.set_api_attr(torch, self.torch_ori_attr)
+            self.set_api_attr(torch.Tensor, self.torch_tensor_ori_attr)
+            self.set_api_attr(torch.nn.functional, self.torch_functional_ori_attr)
+            self.set_api_attr(torch.distributed, self.torch_distributed_ori_attr)
+            self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
+        else:
+            self.set_api_attr(Tensor, self.tensor_ori_attr)
+            self.set_api_attr(StubTensor, self.stub_tensor_ori_attr)
+            self.set_api_attr(ops, self.functional_ori_attr)
+            self.set_api_attr(mint, self.mint_ops_ori_attr)
+            self.set_api_attr(functional, self.mint_func_ops_ori_attr)
+            self.set_api_attr(comm_func, self.distributed_ori_attr)
 
     def initialize_hook(self, hook):
+        setup_hooks(hook)
+        if is_mindtorch():
+            wrap_torch_api_name = get_wrap_torch_api_list()
+            self.store_ori_attr(torch,
+                                wrap_torch_api_name.torch_api_names, self.torch_ori_attr)
+            self.store_ori_attr(torch.Tensor,
+                                wrap_torch_api_name.tensor_api_names, self.torch_tensor_ori_attr)
+            self.store_ori_attr(torch.nn.functional,
+                                wrap_torch_api_name.functional_api_names, self.torch_functional_ori_attr)
+            self.store_ori_attr(torch.distributed,
+                                wrap_torch_api_name.distributed_api_names, self.torch_distributed_ori_attr)
+            self.store_ori_attr(torch_npu,
+                                wrap_torch_api_name.npu_api_names, self.torch_npu_ori_attr)
+            for attr_name in dir(HOOKTorchOP):
+                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
+                    self.torch_hook_attr[api_name] = getattr(HOOKTorchOP, attr_name)
+            for attr_name in dir(HOOKTorchTensor):
+                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
+                    self.torch_tensor_hook_attr[api_name] = getattr(HOOKTorchTensor, attr_name)
+            for attr_name in dir(HOOKTorchFunctionalOP):
+                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
+                    self.torch_functional_hook_attr[api_name] = getattr(HOOKTorchFunctionalOP, attr_name)
+            for attr_name in dir(HOOKTorchDistributedOP):
+                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
+                    self.torch_distributed_hook_attr[api_name] = getattr(HOOKTorchDistributedOP, attr_name)
+            for attr_name in dir(HOOKTorchNpuOP):
+                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
+                    self.torch_npu_hook_attr[api_name] = getattr(HOOKTorchNpuOP, attr_name)
+            return
+
         wrap_api_name = get_wrap_api_list()
         self.store_ori_attr(Tensor, wrap_api_name.tensor_api_names, self.tensor_ori_attr)
         self.store_ori_attr(StubTensor, wrap_api_name.stub_tensor_api_names, self.stub_tensor_ori_attr)
@@ -96,7 +174,6 @@ class ApiRegistry:
         self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr)
         self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr)
         self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
-        setup_hooks(hook)
         for attr_name in dir(HOOKTensor):
             if attr_name.startswith(Const.ATTR_NAME_PREFIX):
                 api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
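The mechanism underneath this file is a save-and-patch registry: `store_ori_attr` records the original callables of a module (wrapping `StubTensor` methods via `stub_method` so they can be rebound safely), and `set_api_attr` swaps hooked replacements in, or the saved originals back. Below is a minimal editor's sketch of that pattern, using Python's `math` module and made-up helper names rather than anything shipped in the package:

```python
import math

# Illustrative stand-ins for ApiRegistry.store_ori_attr / set_api_attr.
_originals = {}

def store_ori_attr(module, api_names, saved):
    # Remember the original callables so they can be restored later.
    for name in api_names:
        saved[name] = getattr(module, name)

def set_api_attr(module, attr_dict):
    # Install replacements (hooked wrappers), or put the originals back.
    for name, func in attr_dict.items():
        setattr(module, name, func)

def hooked(func):
    def wrapper(*args, **kwargs):
        print(f"hook fired for {func.__name__}")
        return func(*args, **kwargs)
    return wrapper

store_ori_attr(math, ["sqrt"], _originals)
set_api_attr(math, {"sqrt": hooked(_originals["sqrt"])})  # patch in
assert math.sqrt(4.0) == 2.0         # prints "hook fired for sqrt"
set_api_attr(math, _originals)       # restore, as api_set_ori_func does
```

The 1.2.1 change is that both directions now branch on `is_mindtorch()`, so one registry can patch either the MindSpore API surface or the torch-compatible one.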
--- a/msprobe/mindspore/dump/hook_cell/hook_cell.py
+++ b/msprobe/mindspore/dump/hook_cell/hook_cell.py
@@ -1,4 +1,5 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,45 +12,66 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ============================================================================
 
 from collections import defaultdict
 
 from mindspore import nn
 
-from msprobe. […]
-[… 3 removed lines not preserved in this rendering …]
-cell_count […]
-[… 31 removed lines not preserved in this rendering …]
+from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions
+
+
+def add_cell_count(name):
+    HOOKCell.cell_count[name] += 1
+
+
+def get_cell_count(name):
+    return HOOKCell.cell_count[name]
+
+
+def __init__(self, build_hook) -> None:
+    super(HOOKCell, self).__init__()
+    self.changed_status = False
+    self.input_kwargs = {}
+    self.prefix = ""
+    if not HOOKCell.g_stop_hook:
+        HOOKCell.g_stop_hook = True
+        self.changed_status = True
+        if hasattr(self, "prefix_api_name"):
+            self.prefix = self.prefix_api_name
+
+        self.forward_data_collected = False
+        forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = build_hook(self.prefix)
+        self.register_forward_pre_hook(forward_pre_hook)
+        self.register_forward_hook(forward_hook)
+        register_backward_hook_functions["full"](self, backward_hook)
+        register_backward_hook_functions["pre"](self, backward_pre_hook)
+
+
+# Overload __call__ and add a global flag.
+def __call__(self, *args, **kwargs):
+    try:
+        self.input_kwargs = kwargs
+        out = super(HOOKCell, self).__call__(*args, **kwargs)
+    except Exception as e:
+        raise e
+    finally:
+        if self.changed_status:
+            self.changed_status = False
+            HOOKCell.g_stop_hook = False
+    return out
+
+
+hook_cell_dict = {
+    "cell_count": defaultdict(int),
+    "g_stop_hook": False,
+    "add_cell_count": staticmethod(add_cell_count),
+    "get_cell_count": staticmethod(get_cell_count),
+    "__init__": __init__,
+    "__call__": __call__
+}
+
+if is_mindtorch():
+    import torch
+    HOOKCell = type("HOOKCell", (torch.nn.Module,), hook_cell_dict)
+else:
+    HOOKCell = type("HOOKCell", (nn.Cell,), hook_cell_dict)
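The notable pattern here: instead of declaring the hook cell statically, 1.2.1 assembles one shared body dict and picks the base class at import time with `type()`, so the same class derives from `torch.nn.Module` under MindTorch and from `mindspore.nn.Cell` otherwise. A self-contained editor's sketch of the mechanism with stand-in base classes (nothing below comes from the package):

```python
from collections import defaultdict

class CellBase:                       # stand-in for mindspore.nn.Cell
    def __call__(self, *args, **kwargs):
        return self.construct(*args, **kwargs)

class ModuleBase:                     # stand-in for torch.nn.Module
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

def _run(self, x):
    type(self).call_count[type(self).__name__] += 1
    return x

body = {
    "call_count": defaultdict(int),   # shared class attribute, like cell_count
    "construct": _run,                # MindSpore entry point
    "forward": _run,                  # torch entry point
}

use_mindtorch = False                 # stand-in for is_mindtorch()
base = ModuleBase if use_mindtorch else CellBase
HookCell = type("HookCell", (base,), body)

cell = HookCell()
print(cell(42), HookCell.call_count)  # 42 defaultdict(<class 'int'>, {'HookCell': 1})
```

Building the class once from a dict keeps the two framework variants from drifting apart, at the cost of module-level `__init__`/`__call__` functions that only make sense inside `hook_cell_dict`.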
--- a/msprobe/mindspore/dump/hook_cell/primitive_hooks.py
+++ b/msprobe/mindspore/dump/hook_cell/primitive_hooks.py
@@ -1,4 +1,5 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,18 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ============================================================================
 
 import os
 
-import mindspore as ms
-from mindspore.common.tensor import Tensor
 from mindspore import ops
+from mindspore.common.tensor import Tensor
 
-from msprobe.mindspore.common.log import logger
 from msprobe.core.common.utils import Const, DumpException
-from msprobe.core.data_dump.data_processor.base import […]
-[… removed line not preserved in this rendering …]
+from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, ModuleBackwardOutputs,
+                                                        ModuleForwardInputsOutputs)
+from msprobe.mindspore.common.log import logger
 
 
 class PrimitiveHookService:
@@ -41,6 +40,7 @@ class PrimitiveHookService:
         Returns:
            callable: the wrapped primitive function.
        """
+
        def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
            """
            Create a backward hook function to capture gradients.
@@ -54,26 +54,24 @@ class PrimitiveHookService:
            Returns:
                callable: the backward hook function.
            """
-            def backward_hook(grad):
 
-[… removed line not preserved in this rendering …]
+            def backward_hook(grad):
+                captured_grads.extend(grad)
                backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}"
 
                try:
-                    if […]
+                    if hook_type == Const.INPUT:
                        self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
                        new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
                        self.service_instance.data_collector.backward_output_data_collect(
                            backward_primitive_name, self, os.getpid(), new_module_input_output
                        )
-
-                    elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
+                    elif hook_type == Const.OUTPUT:
                        self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
                        new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
                        self.service_instance.data_collector.backward_input_data_collect(
                            backward_primitive_name, self, os.getpid(), new_module_input_output
                        )
-                    captured_grads.clear()
 
                except Exception as exception:
                    logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, "
@@ -104,7 +102,7 @@ class PrimitiveHookService:
                    hooked_inputs.append(arg_hooked)
                else:
                    hooked_inputs.append(arg)
-            return hooked_inputs
+            return tuple(hooked_inputs)
 
        def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
            """
@@ -137,6 +135,34 @@ class PrimitiveHookService:
                return tuple(hooked_outputs)
            return out
 
+        def pre_forward_hook(primitive_name, primitive_instance, args, kwargs):
+            module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
+            try:
+                self.service_instance.data_collector.forward_input_data_collect(
+                    primitive_name,
+                    primitive_instance,
+                    os.getpid(),
+                    module_input_output
+                )
+            except Exception as exception:
+                logger.error(f"This is a primitive op dump error during forward input data collection: {exception}, "
+                             f"primitive_name: {primitive_name}")
+                raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
+
+        def post_forward_hook(primitive_name, primitive_instance, args, kwargs, output):
+            module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
+            try:
+                self.service_instance.data_collector.forward_output_data_collect(
+                    primitive_name,
+                    primitive_instance,
+                    os.getpid(),
+                    module_input_output
+                )
+            except Exception as exception:
+                logger.error(f"This is a primitive op dump error during forward output data collection: {exception}, "
+                             f"primitive_name: {primitive_name}")
+                raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
+
        def wrapped_primitive_call(instance_self, *args, **kwargs):
            """
            The wrapped primitive call, which adds hooks for inputs and outputs.
@@ -165,27 +191,17 @@ class PrimitiveHookService:
                             f"primitive_name: {primitive_name}")
                raise DumpException(DumpException.INPUT_HOOK_ERROR) from exception
 
+            forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
+            self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
+
+            pre_forward_hook(forward_primitive_name, instance_self, hooked_inputs, kwargs)
            try:
                out = origin_func(*hooked_inputs, **kwargs)
            except Exception as exception:
                logger.error(f"This is a primitive op dump error during function call: {exception}, "
                             f"primitive_name: {primitive_name}")
                raise DumpException(DumpException.FUNCTION_CALL_ERROR) from exception
-
-            forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
-            self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
-            if self.service_instance.data_collector:
-                module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
-                try:
-                    self.service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
-                                                                              os.getpid(), module_input_output)
-                except Exception as exception:
-                    logger.error(f"This is a primitive op dump error during forward data collection: {exception}, "
-                                 f"primitive_name: {primitive_name}")
-                    raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
-
-                if self.service_instance.data_collector.if_return_forward_new_output():
-                    out = self.service_instance.data_collector.get_forward_new_output()
+            post_forward_hook(forward_primitive_name, instance_self, hooked_inputs, kwargs, out)
 
            try:
                out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
@@ -203,4 +219,3 @@ class PrimitiveHookService:
            self.primitive_counters[primitive_name] = 0
        else:
            self.primitive_counters[primitive_name] += 1
-
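The functional change worth calling out: the old code collected forward data once, after `origin_func` had returned; the refactor splits this into `pre_forward_hook` (inputs, before the call) and `post_forward_hook` (outputs, after), so the inputs are already dumped even when the primitive itself raises. A toy illustration of that control flow (editor's sketch, not package code):

```python
def wrapped_call(func, pre_hook, post_hook, *args, **kwargs):
    pre_hook(args, kwargs)        # inputs are recorded before the op runs
    out = func(*args, **kwargs)   # if this raises, inputs were still captured
    post_hook(args, kwargs, out)  # outputs recorded only on success
    return out

inputs_log, outputs_log = [], []
result = wrapped_call(
    lambda a, b: a + b,
    lambda args, kwargs: inputs_log.append(args),
    lambda args, kwargs, out: outputs_log.append(out),
    3, 4,
)
print(result, inputs_log, outputs_log)   # 7 [(3, 4)] [7]
```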
--- a/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml
+++ b/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml
@@ -15,7 +15,7 @@
 
 # List of ops that register hooks
 
-
+
 ops:
   - adaptive_avg_pool1d
   - adaptive_avg_pool2d
@@ -85,6 +85,7 @@ ops:
   - relu6
   - celu
   - rrelu
+  - rms_norm
   - selu
   - sigmoid
   - silu
@@ -490,6 +491,31 @@ ops:
   - scatter_update
   - derivative
   - jet
+  - row_stack
+  - gather
+  - arange
+  - cond
+  - slice_scatter
+  - clip_by_norm
+  - eps
+  - layer_norm
+  - cast
+  - numel
+  - permute
+  - select_scatter
+  - group_norm
+  - eq
+  - embedding
+  - ones_like
+  - zeros
+  - nanmean
+  - shape
+  - zeros_like
+  - ones
+  - diagonal_scatter
+  - vander
+  - is_nonzero
+  - rotary_position_embedding
 
 tensor:
   - __abs__
@@ -528,6 +554,7 @@ tensor:
   - acos
   - acosh
   - add
+  - add_
   - addbmm
   - addcdiv
   - addcmul
@@ -582,6 +609,7 @@ tensor:
   - diff
   - digamma
   - div
+  - div_
   - divide
   - equal
   - erf
@@ -714,6 +742,8 @@ tensor:
   - square
   - squeeze
   - std
+  - sub
+  - sub_
   - subtract
   - subtract
   - svd
@@ -958,6 +988,7 @@ mint.nn.functional:
   - one_hot_ext
   - pad
   - relu
+  - relu_
   - sigmoid
   - silu
   - softmax
@@ -992,3 +1023,7 @@ communication.comm_func:
   - broadcast
   - gather_into_tensor
   - scatter_tensor
+  - send
+  - recv
+  - isend
+  - irecv
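These YAML sections (`ops`, `tensor`, `mint.nn.functional`, `communication.comm_func`, and so on) are allow-lists of API names to wrap. Downstream, `wrap_api.py` intersects each list with `dir()` of the target module, so an entry the installed framework version does not expose is silently skipped. A small sketch of that filtering (editor's illustration; PyYAML is used here for brevity, while the package loads through its own `load_yaml`):

```python
import yaml

text = """
ops:
  - rms_norm
  - not_a_real_op
"""
api_list = yaml.safe_load(text)

class FakeOps:                      # stand-in for the mindspore.ops namespace
    @staticmethod
    def rms_norm(x):
        return x

# Same intersection trick as get_wrap_torch_api_list in wrap_api.py below.
available = set(api_list["ops"]) & set(dir(FakeOps))
print(available)                    # {'rms_norm'}; unknown names drop out
```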
--- a/msprobe/mindspore/dump/hook_cell/wrap_api.py
+++ b/msprobe/mindspore/dump/hook_cell/wrap_api.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024- […]
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,10 +23,16 @@ from mindspore.mint.nn import functional
 from msprobe.core.common.const import Const
 from msprobe.core.common.file_utils import load_yaml
 from msprobe.mindspore.common.const import Const as MsConst
+from msprobe.mindspore.common.utils import is_mindtorch
 from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
 
+if is_mindtorch():
+    import torch
+    import torch_npu
+
 cur_path = os.path.dirname(os.path.realpath(__file__))
 yaml_path = os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE)
+torch_yaml_path = os.path.join(cur_path, "../../../pytorch/hook_module", MsConst.SUPPORTED_API_LIST_FILE)
 
 
 class HOOKTensor(object):
@@ -53,6 +59,26 @@ class HOOKDistributedOP(object):
     pass
 
 
+class HOOKTorchOP(object):
+    pass
+
+
+class HOOKTorchTensor(object):
+    pass
+
+
+class HOOKTorchFunctionalOP(object):
+    pass
+
+
+class HOOKTorchDistributedOP(object):
+    pass
+
+
+class HOOKTorchNpuOP(object):
+    pass
+
+
 class ApiTemplate(HOOKCell):
     def __init__(self, api_name, api_dict, prefix, hook):
         self.api_name = api_name
@@ -60,7 +86,30 @@ class ApiTemplate(HOOKCell):
         self.prefix_api_name = prefix + str(api_name.split(Const.SEP)[-1]) + Const.SEP
         super().__init__(hook)
 
+    @staticmethod
+    def async_to_sync(output):
+        # Fake handle, used to return after the CommHandle executes the wait method
+        fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()
+        if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"):
+            output[1].wait()
+            output = (output[0], fake_handle)
+        elif hasattr(output, "wait"):
+            output.wait()
+            output = fake_handle
+        return output
+
     def construct(self, *args, **kwargs):
+        if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
+            return args[0] if args else kwargs.get(Const.INPUT)
+
+        output = self.api_func(*args, **kwargs)
+
+        if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX):
+            if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]:
+                output = self.async_to_sync(output)
+        return output
+
+    def forward(self, *args, **kwargs):
         if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
             return args[0] if args else kwargs.get(Const.INPUT)
         return self.api_func(*args, **kwargs)
@@ -77,6 +126,15 @@ class WrapApiName:
         self.distributed_api_names = distributed_api_names
 
 
+class WrapTorchApiName:
+    def __init__(self, torch_api_names, tensor_api_names, functional_api_names, distributed_api_names, npu_api_names):
+        self.torch_api_names = torch_api_names
+        self.tensor_api_names = tensor_api_names
+        self.functional_api_names = functional_api_names
+        self.distributed_api_names = distributed_api_names
+        self.npu_api_names = npu_api_names
+
+
 def get_wrap_api_list():
     api_list = load_yaml(yaml_path)
     tensor_api = api_list.get(MsConst.SUPPORTED_TENSOR_LIST_KEY)
@@ -93,6 +151,21 @@ def get_wrap_api_list():
     return wrap_api_name
 
 
+def get_wrap_torch_api_list():
+    api_list = load_yaml(torch_yaml_path)
+    torch_api = api_list.get("torch")
+    tensor_api = api_list.get("tensor")
+    functional_api = api_list.get("functional")
+    distributed_api = api_list.get("distributed")
+    npu_api = api_list.get("torch_npu")
+    wrap_api_name = WrapTorchApiName(set(torch_api) & set(dir(torch)),
+                                     set(tensor_api) & set(dir(torch.Tensor)),
+                                     set(functional_api) & set(dir(torch.nn.functional)),
+                                     set(distributed_api) & set(dir(torch.distributed)),
+                                     set(npu_api) & set(dir(torch_npu)))
+    return wrap_api_name
+
+
 def wrap_api_func(api_name, api_dict, prefix, hook):
     def api_function(*args, **kwargs):
         return ApiTemplate(api_name, api_dict, prefix, hook)(*args, **kwargs)
@@ -106,6 +179,24 @@ def wrap_api_func_and_bind(api_list, api_dict, prefix, hook, hook_class):
 
 
 def setup_hooks(hook):
+    if is_mindtorch():
+        torch_wrap_api_name = get_wrap_torch_api_list()
+        wrap_api_func_and_bind(torch_wrap_api_name.torch_api_names,
+                               {f: getattr(torch, f) for f in dir(torch)},
+                               MsConst.TORCH_DATA_PREFIX, hook, HOOKTorchOP)
+        wrap_api_func_and_bind(torch_wrap_api_name.tensor_api_names,
+                               {f: getattr(torch.Tensor, f) for f in dir(torch.Tensor)},
+                               MsConst.TENSOR_DATA_PREFIX, hook, HOOKTorchTensor)
+        wrap_api_func_and_bind(torch_wrap_api_name.functional_api_names,
+                               {f: getattr(torch.nn.functional, f) for f in dir(torch.nn.functional)},
+                               MsConst.OPS_DATA_PREFIX, hook, HOOKTorchFunctionalOP)
+        wrap_api_func_and_bind(torch_wrap_api_name.distributed_api_names,
+                               {f: getattr(torch.distributed, f) for f in dir(torch.distributed)},
+                               MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKTorchDistributedOP)
+        wrap_api_func_and_bind(torch_wrap_api_name.npu_api_names, {f: getattr(torch_npu, f) for f in dir(torch_npu)},
+                               MsConst.TORCH_NPU_DATA_PREFIX, hook, HOOKTorchNpuOP)
+        return
+
     wrap_api_name = get_wrap_api_list()
     wrap_api_func_and_bind(wrap_api_name.tensor_api_names, {f: getattr(Tensor, f) for f in dir(Tensor)},
                            MsConst.TENSOR_DATA_PREFIX, hook, HOOKTensor)