mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0

msprobe/pytorch/dump/module_dump/module_processer.py (lines cut off in the extracted view are marked with `…`):

```diff
@@ -13,18 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from …
+from collections import OrderedDict
 
 import torch
+from torch.utils.hooks import BackwardHook, RemovableHandle
+
 from msprobe.core.common.const import Const
 from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import …
-from …
-from torch.utils.checkpoint import set_checkpoint_early_stop
-from torch.utils.hooks import BackwardHook
+from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
+from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_output_hook
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
+if torch_version_above_or_equal_2:
+    from torch.utils.checkpoint import checkpoint as origin_checkpoint, set_checkpoint_early_stop
 
 
 def checkpoint_without_early_stop(*args, **kwargs):
```
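For context on the guarded import above: `set_checkpoint_early_stop` is a torch 2.x context manager controlling whether activation checkpointing stops recomputation early. The body of `checkpoint_without_early_stop` is not visible in this diff, so the following is a hedged sketch of one plausible way the two imports combine:

```python
# Hedged sketch, not the verbatim msprobe body: disable early stopping
# around the original checkpoint so backward hooks fire for every
# recomputed operator inside a checkpointed region.
from torch.utils.checkpoint import checkpoint as origin_checkpoint, set_checkpoint_early_stop

def checkpoint_without_early_stop(*args, **kwargs):
    # set_checkpoint_early_stop(False) is a context manager in torch >= 2.0
    with set_checkpoint_early_stop(False):
        return origin_checkpoint(*args, **kwargs)
```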
```diff
@@ -33,7 +35,18 @@ def checkpoint_without_early_stop(*args, **kwargs):
 
 
 def replace_checkpoint():
-    …
+    if torch_version_above_or_equal_2:
+        torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop
+
+
+def wrap_megatron_deallocate(func):
+    def wrapper_func(out, deallocate_pipeline_outputs=False):
+        if deallocate_pipeline_outputs and isinstance(out, torch.Tensor) and getattr(out, "_base") is not None:
+            out_clone = out.clone()
+            out.data = torch.empty((1,), device=out.device, dtype=out.dtype, )
+            return func(out_clone, deallocate_pipeline_outputs)
+        return func(out, deallocate_pipeline_outputs)
+    return wrapper_func
 
 
 class ModuleProcesser:
```
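The `wrap_megatron_deallocate` patch above clones a view tensor before Megatron releases its storage, so msprobe's dump hooks still see real values. A self-contained sketch of the effect, with a stand-in for `schedules.deallocate_output_tensor` (everything except the wrapper shape is illustrative):

```python
import torch

# Stand-in for megatron.core.pipeline_parallel.schedules.deallocate_output_tensor,
# which frees pipeline activation memory by shrinking .data to one element.
def fake_deallocate(out, deallocate_pipeline_outputs=False):
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)

def wrap_megatron_deallocate(func):  # same shape as the patch in the hunk above
    def wrapper_func(out, deallocate_pipeline_outputs=False):
        if deallocate_pipeline_outputs and isinstance(out, torch.Tensor) and out._base is not None:
            out_clone = out.clone()  # keep a full copy alive for the dump hooks
            out.data = torch.empty((1,), device=out.device, dtype=out.dtype)
            return func(out_clone, deallocate_pipeline_outputs)
        return func(out, deallocate_pipeline_outputs)
    return wrapper_func

patched = wrap_megatron_deallocate(fake_deallocate)
view = torch.randn(4, 4)[0]                      # a view, so view._base is not None
patched(view, deallocate_pipeline_outputs=True)
print(view.numel())                              # 1: the caller's view was released
```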
```diff
@@ -41,37 +54,25 @@ class ModuleProcesser:
     module_stack = []
     api_parent_node = ""
     module_node = {}
+    module_bw_hook_kernels = {}
+    module_with_backward_hook = {}
+    enable_module_dump = False
 
     def __init__(self, scope):
         self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
-
-        BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
+        wrap_setup_input_output_hook()
         replace_checkpoint()
+        try:
+            from megatron.core.pipeline_parallel import schedules
+            schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor)
+            logger.info_on_rank_0("Patch megatron method success.")
+        except ImportError:
+            logger.info_on_rank_0("No megatron find.")
+        except Exception as e:
+            logger.info_on_rank_0(f"Patch megatron method failed, detail:{str(e)}")
 
     @staticmethod
-    def …
-        @wraps(func)
-        def clone_return_value_func(*args, **kwargs):
-            result = func(*args, **kwargs)
-            return ModuleProcesser.clone_if_tensor(result)
-
-        return clone_return_value_func
-
-    @staticmethod
-    def clone_if_tensor(result):
-        if isinstance(result, torch.Tensor):
-            return result.clone()
-        elif type(result) is tuple:
-            return tuple(ModuleProcesser.clone_if_tensor(x) for x in result)
-        elif type(result) is list:
-            return list(ModuleProcesser.clone_if_tensor(x) for x in result)
-        elif type(result) is dict:
-            return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()}
-        else:
-            return result
-
-    @staticmethod
-    def module_count_func(module_name):
+    def set_and_get_calls_number(module_name):
         if module_name not in ModuleProcesser.module_count:
             ModuleProcesser.module_count[module_name] = 0
         else:
```
```diff
@@ -85,13 +86,19 @@ class ModuleProcesser:
             module._is_full_backward_hook is False
 
     @staticmethod
-    def get_modules_and_names(models):
+    def get_modules_and_names(models, recursive, module_names):
         modules_and_names_with_index = {}
         if isinstance(models, (list, tuple)):
+            if not recursive and len(module_names) != len(models):
+                return modules_and_names_with_index
             for index, model in enumerate(models):
-                modules_and_names_with_index[str(index)] = model.named_modules()
+                modules_and_names_with_index[str(index)] = model.named_modules() if recursive else \
+                    [(module_names[index], model)]
         else:
-            …
+            if not recursive and len(module_names) != 1:
+                return modules_and_names_with_index
+            modules_and_names_with_index["-1"] = models.named_modules() if recursive else \
+                [(module_names[0], models)]
         return modules_and_names_with_index
 
     @classmethod
```
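A quick illustration of the new non-recursive mode: with `recursive=False`, each model is paired with exactly one caller-supplied name instead of walking `named_modules()` (toy values, not msprobe code):

```python
import torch.nn as nn

models = [nn.Linear(2, 2), nn.ReLU()]

# recursive=True pairs every submodule with its qualified name, as before.
recursive_view = {str(i): list(m.named_modules()) for i, m in enumerate(models)}

# recursive=False pairs each top-level model with one caller-supplied name,
# mirroring the `[(module_names[index], model)]` branch added above.
module_names = ["fc", "act"]
flat_view = {str(i): [(module_names[i], m)] for i, m in enumerate(models)}
print(flat_view["0"])   # [('fc', Linear(in_features=2, out_features=2, bias=True))]
```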
```diff
@@ -100,105 +107,134 @@ class ModuleProcesser:
         cls.module_stack = []
         cls.api_parent_node = ""
         cls.module_node = {}
+        cls.module_bw_hook_kernels = {}
+        cls.enable_module_dump = False
+
+    def register_module_hook(self, models, build_hook, recursive=True, module_names=None):
+        if module_names is None:
+            module_names = []
 
-        …
-        logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.")
-        modules_and_names_with_index = self.get_modules_and_names(models)
+        modules_and_names_with_index = self.get_modules_and_names(models, recursive, module_names)
         for index, modules_and_names in modules_and_names_with_index.items():
             model = models if index == "-1" else models[int(index)]
             for name, module in modules_and_names:
-                if module == model:
+                if recursive and module == model:
                     continue
+                if not is_torch_nn_module(module):
+                    logger.warning(
+                        f"The module dump does not support {type(module)} type. "
+                        f"The data dump for this module will be skipped."
+                    )
+                    continue
+                if module.__class__.__name__ == "FullyShardedDataParallel":
+                    continue
+                setattr(module, 'msprobe_hook', True)
                 module_index = (index + Const.SEP) if index != "-1" else ""
-                prefix_name = …
-                …
-                …
-                …
-                    prefix_name…
-                )
+                prefix_name = f'{BaseScope.Module_Type_Module}{Const.SEP}{module_index}{name}{Const.SEP}' + \
+                              f'{module.__class__.__name__}{Const.SEP}'
+
+                forward_pre_hook = self.build_module_hook(prefix_name, build_hook)
 
                 if self.has_register_backward_hook(module):
                     logger.warning(
                         f"The {prefix_name[:-1]} has registered deprecated register_backward_hook,"
                         f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
                     )
+                    ModuleProcesser.module_with_backward_hook[prefix_name] = True
+                register_forward_pre_hook(module, forward_pre_hook)
+
+    def build_module_hook(self, module_name, build_data_hook):
+        def forward_pre_hook(module, args, kwargs=None):
+            if kwargs is None:
+                kwargs = {}
+
+            if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump:
+                return (args, kwargs) if torch_version_above_or_equal_2 else args
+
+            index = ModuleProcesser.set_and_get_calls_number(module_name)
+            full_forward_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}'
+            full_backward_name = f'{module_name}{Const.BACKWARD}{Const.SEP}{index}'
+
+            self.set_construct_info_in_pre_hook(full_forward_name)
+
+            if not hasattr(module, 'msprobe_forward_hook'):
+                forward_hooks_dict = getattr(module, '_forward_hooks', OrderedDict())
+                handle = RemovableHandle(forward_hooks_dict)
+                forward_hooks_dict[handle.id] = forward_hook
+                forward_hooks_dict.move_to_end(handle.id, last=False)
+                if torch_version_above_or_equal_2:
+                    forward_hooks_with_kwargs_dict = getattr(module, '_forward_hooks_with_kwargs', OrderedDict())
+                    forward_hooks_with_kwargs_dict[handle.id] = True
+
+                setattr(module, 'msprobe_forward_hook', True)
+
+            hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
+
+            def get_backward_pre_hook(full_backward_name):
+                def backward_pre_hook_fn(module, grad_output):
+                    self.set_construct_info_in_pre_hook(full_backward_name)
+                return backward_pre_hook_fn
+
+            def get_backward_hook(backward_data_hook, full_backward_name):
+                def backward_hook_fn(module, grad_input, grad_output):
+                    new_output = backward_data_hook(module, grad_input, grad_output)
+                    self.set_construct_info_in_hook(full_backward_name, is_forward=False)
+                    return new_output
+                return backward_hook_fn
+
+            if not ModuleProcesser.module_with_backward_hook.get(module_name):
+                backward_pre_hook = get_backward_pre_hook(full_backward_name)
+                backward_hook = get_backward_hook(hook_set.backward_hook, full_backward_name)
                 if torch_version_above_or_equal_2:
-                    module…
+                    bw_hook = BackwardHook(module, [backward_hook], [backward_pre_hook])
                 else:
-                    …
-… [old lines 128–145 truncated in the source view]
-                index = None
-                pass
-            full_name = name_prefix + Const.SEP + str(index)
-            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
-                module.mindstudio_reserved_name = []
-            module.mindstudio_reserved_name.append(full_name)
-            if self.module_stack:
-                ModuleProcesser.module_node[full_name] = self.module_stack[-1]
+                    bw_hook = BackwardHook(module, [backward_hook])
+                ModuleProcesser.module_bw_hook_kernels[full_forward_name] = bw_hook
+                args = bw_hook.setup_input_hook(args)
+            return (args, kwargs) if torch_version_above_or_equal_2 else args
+
+        def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
+            if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump:
+                return output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
+
+            index = ModuleProcesser.module_count.get(module_name)
+            full_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}'
+
+            hook_set = build_data_hook(BaseScope.Module_Type_Module, full_name)
+            hook_result = hook_set.forward_hook(module, args, kwargs_or_output, output_or_kwargs)
+            self.set_construct_info_in_hook(full_name)
+
+            if hook_result is not None:
+                result = hook_result
             else:
-                …
+                result = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
 
-            ModuleProcesser.…
-            if …
-                …
-            if self.scope:
-                self.scope.begin_module(full_name)
+            bw_hook = ModuleProcesser.module_bw_hook_kernels.get(full_name)
+            if bw_hook:
+                result = bw_hook.setup_output_hook(result)
 
-…
+            return result
+
+        return forward_pre_hook
+
+    def set_construct_info_in_pre_hook(self, full_name):
+        if self.module_stack:
+            ModuleProcesser.module_node[full_name] = self.module_stack[-1]
+        else:
+            ModuleProcesser.module_node[full_name] = None
+        ModuleProcesser.module_stack.append(full_name)
+        ModuleProcesser.api_parent_node = full_name
+        if self.scope:
+            self.scope.begin_module(full_name)
+
+    def set_construct_info_in_hook(self, full_name, is_forward=True):
+        if torch_version_above_or_equal_2 or is_forward:
             if self.module_stack:
                 ModuleProcesser.module_stack.pop()
-            if self.module_stack…
-                ModuleProcesser.api_parent_node = self.module_stack[-1]
-            else:
-                ModuleProcesser.api_parent_node = None
-            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
-                raise RuntimeError(f"module reserve name is None when pop")
-            current_name = module.mindstudio_reserved_name.pop()
+            ModuleProcesser.api_parent_node = ModuleProcesser.module_stack[-1] if self.module_stack else None
             if self.scope:
-                self.scope.end_module(…
-…
-        def backward_hook(module, input, output=None):
-            try:
-                index = ModuleProcesser.module_count_func(name_prefix)
-            except IndexError as e:
-                index = None
-                pass
-            full_name = name_prefix + Const.SEP + str(index)
-            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
-                module.mindstudio_reserved_name = []
-            module.mindstudio_reserved_name.append(full_name)
-            forward_full_name = replace_last_occurrence(full_name, Const.BACKWARD, Const.FORWARD)
-            ModuleProcesser.module_node[full_name] = replace_last_occurrence(
-                ModuleProcesser.module_node.get(forward_full_name), Const.FORWARD, Const.BACKWARD)
-            ModuleProcesser.api_parent_node = None
+                self.scope.end_module(full_name)
+        else:
             if self.scope:
                 self.scope.begin_module(full_name)
-…
-        if torch_version_above_or_equal_2:
-            if Const.START in start_or_stop:
-                return pre_hook
-            else:
-                return end_hook
-        else:
-            if Const.FORWARD in name_prefix and Const.START in start_or_stop:
-                return pre_hook
-            elif Const.BACKWARD in name_prefix:
-                return backward_hook
-            else:
-                return end_hook
+            ModuleProcesser.api_parent_node = full_name
```
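The `forward_pre_hook` above registers msprobe's forward hook by allocating a `RemovableHandle` in the module's `_forward_hooks` dict and moving the entry to the front, so it observes outputs before any user hook runs. A standalone sketch of that trick on a plain `nn.Module` (helper name is illustrative, not msprobe API):

```python
import torch
from torch import nn
from torch.utils.hooks import RemovableHandle

def prepend_forward_hook(module, hook):
    # Allocate an id in the module's hook dict, then move the new entry to
    # the front so it fires before any previously registered forward hook.
    handle = RemovableHandle(module._forward_hooks)
    module._forward_hooks[handle.id] = hook
    module._forward_hooks.move_to_end(handle.id, last=False)
    return handle

layer = nn.Linear(4, 4)
layer.register_forward_hook(lambda m, i, o: print("user hook"))
prepend_forward_hook(layer, lambda m, i, o: print("prepended hook, runs first"))
layer(torch.randn(2, 4))   # prints "prepended hook, runs first", then "user hook"
```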
msprobe/pytorch/free_benchmark/common/utils.py (path inferred from the surrounding imports):

```diff
@@ -16,7 +16,7 @@
 
 import torch
 from msprobe.core.common.exceptions import FreeBenchmarkException
-from msprobe.core.common.…
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark.common.enums import DeviceType
 
 
```
msprobe/pytorch/free_benchmark/compare/single_benchmark.py (path inferred from the imports):

```diff
@@ -16,7 +16,7 @@
 import math
 
 import torch
-from msprobe.core.common.…
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 from msprobe.pytorch.free_benchmark.common.utils import TorchC
```
msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py:

```diff
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import torch
-from msprobe.core.common.…
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -95,13 +95,13 @@ class AddNoiseLayer(NpuBaseLayer):
         except Exception:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"when …
+                f"when calculating the maximum value, the tensor is changed to float32."
             )
             max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
         if max_val < abs_tol:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"…
+                f"maximum value is less than the minimum threshold. Cancel adding noise."
            )
             return False
         return True
```
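Both warnings above belong to one guard: before perturbing a tensor, the layer computes max(|x|), falling back to a float32 upcast when reducing the original dtype fails, and cancels the noise when the magnitude sits below `abs_tol`. A toy version of the check (not msprobe code):

```python
import torch

def should_add_noise(tensor_obj, abs_tol=1e-8):
    try:
        max_val = torch.max(torch.abs(tensor_obj)).item()
    except Exception:
        # mirrors the fallback above: upcast to float32, then reduce
        max_val = torch.max(torch.abs(tensor_obj.to(torch.float32))).item()
    return max_val >= abs_tol

print(should_add_noise(torch.tensor([1e-3, -2e-3])))    # True
print(should_add_noise(torch.tensor([1e-12, -3e-12])))  # False -> cancel adding noise
```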
msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py:

```diff
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import torch
-from msprobe.core.common.…
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -100,13 +100,13 @@ class BitNoiseLayer(NpuBaseLayer):
         except Exception:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"when calculate …
+                f"when calculate the maximum value, the tensor is changed to float32."
             )
             max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
         if max_val < abs_tol:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"…
+                f"maximum value is less than the minimum threshold. Cancel adding noise."
             )
             return False
         return True
```
msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py (path inferred from the imports):

```diff
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import torch
-from msprobe.core.common.…
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
 from msprobe.pytorch.free_benchmark.common.params import DataParams
```
msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py (path inferred from the imports):

```diff
@@ -15,7 +15,7 @@
 
 import torch
 from msprobe.core.common.const import Const
-from msprobe.core.common.…
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import CommonField
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
```
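The recurring one-line change across these free-benchmark files points `recursion_depth_decorator` at the new `msprobe.core.common.decorator` module (added as `msprobe/core/common/decorator.py` in the file list; the old import path is cut off in the source view). The decorator's implementation is not shown in this diff; a generic recursion-depth guard of the kind the name suggests might look like this (depth limit and error type are assumptions):

```python
import functools

def recursion_depth_decorator(func_info, max_depth=50):
    """Hypothetical sketch: fail fast when a recursive walk over nested
    args/kwargs structures exceeds max_depth, instead of hitting Python's
    own RecursionError much deeper in the stack."""
    def decorator(func):
        depth = 0
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal depth
            depth += 1
            try:
                if depth > max_depth:
                    raise RecursionError(f"{func_info}: recursion depth exceeded {max_depth}")
                return func(*args, **kwargs)
            finally:
                depth -= 1
        return wrapper
    return decorator

@recursion_depth_decorator("demo.flatten")
def flatten(x):
    return [y for item in x for y in (flatten(item) if isinstance(item, list) else [item])]

print(flatten([1, [2, [3]]]))   # [1, 2, 3]
```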
msprobe/pytorch/free_benchmark/result_handlers/check_handler.py:

```diff
@@ -49,6 +49,6 @@ class CheckerHandler(FuzzHandler):
         except Exception as e:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.params.api_name}, "
-                f"when …
+                f"when comparing the results, an exception is raised: {e}"
             )
             return data_params.original_result
```
msprobe/pytorch/function_factory.py:

```diff
@@ -70,7 +70,7 @@ class Register(dict):
 
         def add_register_item(key, value):
             if key in self._dict:
-                logger.warning(f"{value.__name__} has been registered before, so we will …
+                logger.warning(f"{value.__name__} has been registered before, so we will override it.")
             self[key] = value
             return value
 
```
msprobe/pytorch/grad_probe/grad_monitor.py:

```diff
@@ -46,7 +46,7 @@ class GradientMonitor:
         if not os.path.exists(self._output_path):
             create_directory(self._output_path)
         else:
-            logger.warning(f"the file in {self._output_path} will be …
+            logger.warning(f"the file in {self._output_path} will be deleted")
         self._step = -1
         self._param2name = defaultdict(str)
 
@@ -97,7 +97,7 @@ class GradientMonitor:
             create_directory(output_dirpath)
         output_path = os.path.join(output_dirpath, f"grad_summary_{self._step}.csv")
         if os.path.exists(output_path):
-            logger.warning(f"{output_path} will be …
+            logger.warning(f"{output_path} will be deleted")
             remove_path(output_path)
         header_result = GradStatCsv.generate_csv_header(self._level_adp, self._bounds)
         output_lines.insert(0, header_result)
```
msprobe/pytorch/grad_probe/grad_stat_csv.py (or its MindSpore twin; both show +3 −2 in the file list):

```diff
@@ -17,6 +17,7 @@ from abc import ABC, abstractmethod
 from collections import namedtuple
 import hashlib
 from functools import wraps
+import zlib
 import torch
 from msprobe.core.grad_probe.constant import GradConst
 
@@ -74,8 +75,8 @@ class CsvMd5(CsvItem):
     def generate_csv_content(csv_content_input):
         grad = csv_content_input.grad
         tensor_bytes = grad.cpu().detach().float().numpy().tobytes()
-        md5_hash = …
-        return [md5_hash…
+        md5_hash = f"{zlib.crc32(tensor_bytes):08x}"
+        return [md5_hash]
 
 
 @register_csv_item(GradConst.DISTRIBUTION)
```
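The hunk above swaps the MD5 digest for a CRC32 checksum while keeping the CSV column: `zlib.crc32` is far cheaper on large gradient buffers, and `:08x` zero-pads the result to a stable 8-hex-digit string. A side-by-side illustration:

```python
import hashlib
import zlib

import numpy as np

tensor_bytes = np.arange(1024, dtype=np.float32).tobytes()

crc = f"{zlib.crc32(tensor_bytes):08x}"      # always exactly 8 hex digits
md5 = hashlib.md5(tensor_bytes).hexdigest()  # 32 hex digits, cryptographic, slower
print(len(crc), len(md5))                    # 8 32
```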
msprobe/pytorch/hook_module/api_register.py (new file):

```diff
@@ -0,0 +1,155 @@
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import os
+import inspect
+
+import torch
+import torch.distributed as dist
+
+from msprobe.core.common.const import Const
+from msprobe.core.data_dump.api_registry import ApiRegistry
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.common.utils import (
+    torch_without_guard_version, is_gpu, torch_device_guard, parameter_adapter
+)
+from msprobe.pytorch.function_factory import npu_custom_functions
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
+from msprobe.pytorch.hook_module.utils import dynamic_import_op
+from msprobe.core.common.file_utils import load_yaml
+
+try:
+    import mindspeed.ops
+except ImportError:
+    mindspeed_enable = False
+else:
+    mindspeed_enable = True
+
+
+torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
+
+_inner_used_api = {}
+_supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
+_cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
+
+_api_types = {
+    Const.PT_FRAMEWORK: {
+        Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
+        Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)),
+        Const.PT_API_TYPE_TORCH: (torch, (torch,)),
+        Const.PT_API_TYPE_VF: (torch._C._VariableFunctionsClass, (torch._VF,)),
+        Const.PT_API_TYPE_DIST: (dist, (dist, dist.distributed_c10d))
+    }
+}
+if not is_gpu:
+    import torch_npu
+    if torch_without_guard_version:
+        _api_types.get(Const.PT_FRAMEWORK).update(
+            {
+                Const.PT_API_TYPE_NPU: (torch.ops.npu, (torch_npu, torch.ops.npu))
+            }
+        )
+    else:
+        _api_types.get(Const.PT_FRAMEWORK).update(
+            {Const.PT_API_TYPE_NPU: (torch_npu._C._VariableFunctionsClass, (torch_npu,))}
+        )
+    _api_types.get(Const.PT_FRAMEWORK).update(
+        {
+            Const.PT_API_TYPE_NPU_DIST: (torch_npu.distributed, (torch_npu.distributed,
+                                                                 torch_npu.distributed.distributed_c10d))
+        }
+    )
+if mindspeed_enable:
+    _api_types.get(Const.PT_FRAMEWORK).update({Const.PT_API_TYPE_MINDSPEED: (mindspeed.ops, (mindspeed.ops,))})
+    mindspeed_op_list = load_yaml(_supported_api_list_path[0]).get(Const.PT_API_TYPE_MINDSPEED)
+    mindspeed_op_file_list = [op.split(Const.SEP)[0] + Const.PY_SUFFIX for op in mindspeed_op_list]
+    dynamic_import_op(mindspeed.ops, mindspeed_op_file_list)
+
+
+@parameter_adapter
+def tensor_module_forward(module, *args, **kwargs):
+    return module.api_func(*args, **kwargs)
+
+
+def dist_module_forward(module, *args, **kwargs):
+    handle = module.api_func(*args, **kwargs)
+    try:
+        bound = inspect.signature(module.api_func).bind(*args, **kwargs)
+        bound.apply_defaults()
+        use_asyn_op_flag = bound.arguments.get("asyn_op", False)
+    except Exception as e:
+        use_asyn_op_flag = False
+        logger.warning(f"fail to get dist api's func signature because {e}, no wait")
+
+    if use_asyn_op_flag or module.api_name in ["isend", "irecv"]:
+        if handle and hasattr(handle, 'wait'):
+            handle.wait()
+    if module.api_name == "batch_isend_irecv":
+        if isinstance(handle, list):
+            for req in handle:
+                req.wait()
+    return handle
+
+
+def npu_module_forward(module, *args, **kwargs):
+    if not module.need_hook:
+        if module.api_name not in npu_custom_functions:
+            raise Exception(f'There is not bench function {module.api_name}')
+        if module.device == Const.CUDA_LOWERCASE:
+            module.api_name = _cuda_func_mapping.get(module.api_name, module.api_name)
+        if module.device in [Const.CUDA_LOWERCASE, Const.CPU_LOWERCASE]:
+            return npu_custom_functions[module.api_name](*args, **kwargs)
+    return module.api_func(*args, **kwargs)
+
+
+forward_methods = {
+    "Tensor": tensor_module_forward,
+    "Distributed": dist_module_forward,
+    "NPU": npu_module_forward
+}
+
+
+class ApiTemplate(HOOKModule):
+    def __init__(self, api_name, api_func, prefix, hook_build_func, need_hook=True, device=Const.CPU_LOWERCASE):
+        self.api_name = api_name
+        self.api_func = api_func
+        self.prefix = prefix
+        self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
+        self.need_hook = need_hook
+        self.device = device
+        if self.need_hook:
+            super().__init__(hook_build_func)
+        if prefix == Const.DIST_API_TYPE_PREFIX:
+            self.op_is_distributed = True
+
+    @torch_device_guard
+    def forward(self, *args, **kwargs):
+        exec_func = forward_methods.get(self.prefix)
+        exec_func = functools.partial(exec_func, self) if exec_func else self.api_func
+        return exec_func(*args, **kwargs)
+
+
+api_register = None
+
+
+def get_api_register(return_new=False):
+    if return_new:
+        return ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
+
+    global api_register
+    if api_register is None:
+        api_register = ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
+    return api_register
```
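The module-level `api_register`/`get_api_register()` pair above implements a lazy singleton: callers share one `ApiRegistry` unless they explicitly request a fresh instance. Expected usage, inferred from the code shown:

```python
from msprobe.pytorch.hook_module.api_register import get_api_register

reg = get_api_register()                   # lazily builds the shared ApiRegistry
assert get_api_register() is reg           # later callers get the same instance
fresh = get_api_register(return_new=True)  # independent registry, never cached
assert fresh is not reg
```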