mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
msprobe/pytorch/common/utils.py
CHANGED
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import io
|
|
17
17
|
import os
|
|
18
|
+
import pickle
|
|
18
19
|
import random
|
|
19
20
|
import stat
|
|
20
21
|
from functools import wraps
|
|
@@ -24,7 +25,7 @@ import torch
|
|
|
24
25
|
import torch.distributed as dist
|
|
25
26
|
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
26
27
|
from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
|
|
27
|
-
check_file_or_directory_path, check_path_before_create)
|
|
28
|
+
check_file_or_directory_path, check_path_before_create, FileOpen)
|
|
28
29
|
from msprobe.core.common.log import logger
|
|
29
30
|
from msprobe.core.common.utils import check_seed_all
|
|
30
31
|
from packaging import version
|
|
@@ -75,7 +76,7 @@ def parameter_adapter(func):
|
|
|
75
76
|
else:
|
|
76
77
|
res = [input_tensor[tensor_index] for tensor_index in indices]
|
|
77
78
|
return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0)
|
|
78
|
-
if self.op_name_ == "__eq__" and args[1] is None:
|
|
79
|
+
if self.op_name_ == "__eq__" and len(args) > 1 and args[1] is None:
|
|
79
80
|
return False
|
|
80
81
|
return func(self, *args, **kwargs)
|
|
81
82
|
|
|
@@ -104,8 +105,49 @@ def get_rank_if_initialized():
|
|
|
104
105
|
raise DistributedNotInitializedError("torch distributed environment is not initialized")
|
|
105
106
|
|
|
106
107
|
|
|
107
|
-
def
|
|
108
|
-
|
|
108
|
+
def remove_dropout():
|
|
109
|
+
if torch.__version__ > "1.8":
|
|
110
|
+
logger.info_on_rank_0("For precision comparison, the probability p in the dropout method is set to 0.")
|
|
111
|
+
import torch.nn.functional as F
|
|
112
|
+
from torch import _VF
|
|
113
|
+
from torch.overrides import has_torch_function_unary, handle_torch_function
|
|
114
|
+
|
|
115
|
+
def function_dropout(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
116
|
+
inplace: bool = False) -> torch.Tensor:
|
|
117
|
+
if has_torch_function_unary(input_tensor):
|
|
118
|
+
return handle_torch_function(
|
|
119
|
+
function_dropout, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
|
|
120
|
+
if p < 0.0 or p > 1.0:
|
|
121
|
+
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
122
|
+
return _VF.dropout_(input_tensor, 0., training) if inplace else _VF.dropout(input_tensor, 0., training)
|
|
123
|
+
|
|
124
|
+
def function_dropout2d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
125
|
+
inplace: bool = False) -> torch.Tensor:
|
|
126
|
+
if has_torch_function_unary(input_tensor):
|
|
127
|
+
return handle_torch_function(
|
|
128
|
+
function_dropout2d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
|
|
129
|
+
if p < 0.0 or p > 1.0:
|
|
130
|
+
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
131
|
+
return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
|
|
132
|
+
0., training)
|
|
133
|
+
|
|
134
|
+
def function_dropout3d(input_tensor: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
135
|
+
inplace: bool = False) -> torch.Tensor:
|
|
136
|
+
if has_torch_function_unary(input_tensor):
|
|
137
|
+
return handle_torch_function(
|
|
138
|
+
function_dropout3d, (input_tensor,), input_tensor, p=0., training=training, inplace=inplace)
|
|
139
|
+
if p < 0.0 or p > 1.0:
|
|
140
|
+
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
141
|
+
return _VF.feature_dropout_(input_tensor, 0., training) if inplace else _VF.feature_dropout(input_tensor,
|
|
142
|
+
0., training)
|
|
143
|
+
|
|
144
|
+
F.dropout = function_dropout
|
|
145
|
+
F.dropout2d = function_dropout2d
|
|
146
|
+
F.dropout3d = function_dropout3d
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def seed_all(seed=1234, mode=False, rm_dropout=True):
|
|
150
|
+
check_seed_all(seed, mode, rm_dropout)
|
|
109
151
|
try:
|
|
110
152
|
random.seed(seed)
|
|
111
153
|
os.environ['PYTHONHASHSEED'] = str(seed)
|
|
@@ -125,6 +167,8 @@ def seed_all(seed=1234, mode=False):
|
|
|
125
167
|
else:
|
|
126
168
|
torch_npu.npu.manual_seed_all(seed)
|
|
127
169
|
torch_npu.npu.manual_seed(seed)
|
|
170
|
+
if rm_dropout:
|
|
171
|
+
remove_dropout()
|
|
128
172
|
except Exception as e:
|
|
129
173
|
logger.error(f"There is an unexpected error while determinating randomness. {e}")
|
|
130
174
|
|
|
@@ -269,17 +313,17 @@ def load_pt(pt_path, to_cpu=False):
|
|
|
269
313
|
check_file_or_directory_path(pt_path)
|
|
270
314
|
try:
|
|
271
315
|
if to_cpu:
|
|
272
|
-
pt = torch.load(pt_path, map_location=torch.device("cpu"))
|
|
316
|
+
pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True)
|
|
273
317
|
else:
|
|
274
|
-
pt = torch.load(pt_path)
|
|
318
|
+
pt = torch.load(pt_path, weights_only=True)
|
|
275
319
|
except Exception as e:
|
|
276
320
|
raise RuntimeError(f"load pt file {pt_path} failed") from e
|
|
277
321
|
return pt
|
|
278
322
|
|
|
279
323
|
|
|
280
324
|
def save_pt(tensor, filepath):
|
|
281
|
-
filepath = os.path.realpath(filepath)
|
|
282
325
|
check_path_before_create(filepath)
|
|
326
|
+
filepath = os.path.realpath(filepath)
|
|
283
327
|
try:
|
|
284
328
|
torch.save(tensor, filepath)
|
|
285
329
|
except Exception as e:
|
|
@@ -290,6 +334,56 @@ def save_pt(tensor, filepath):
|
|
|
290
334
|
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
291
335
|
|
|
292
336
|
|
|
337
|
+
class TypeCheckingUnpickler(pickle.Unpickler):
|
|
338
|
+
"""
|
|
339
|
+
This class is a subclass of pickle.Unpickler, which is used to unpickle pickled objects.
|
|
340
|
+
It overrides the find_class method to add type checking functionality.
|
|
341
|
+
"""
|
|
342
|
+
allowed_types = [
|
|
343
|
+
"str",
|
|
344
|
+
"ApiData",
|
|
345
|
+
"OrderedDict",
|
|
346
|
+
"_rebuild_tensor_v2", # from torch.utils
|
|
347
|
+
"_load_from_bytes" # from torch.storage
|
|
348
|
+
]
|
|
349
|
+
|
|
350
|
+
def find_class(self, module, name):
|
|
351
|
+
"""
|
|
352
|
+
Method to find the class of the object to be unpickled.
|
|
353
|
+
Throws pickle.UnpicklingError If the object type is not in the allowed types list.
|
|
354
|
+
"""
|
|
355
|
+
if name in self.allowed_types:
|
|
356
|
+
return super().find_class(module, name)
|
|
357
|
+
raise pickle.UnpicklingError("Unsupported object type: {}.{}".format(module, name))
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def save_pkl(tensor, filepath):
|
|
361
|
+
"""Save ApiData or str objection by pickle"""
|
|
362
|
+
check_path_before_create(filepath)
|
|
363
|
+
filepath = os.path.realpath(filepath)
|
|
364
|
+
try:
|
|
365
|
+
with FileOpen(filepath, 'wb') as f:
|
|
366
|
+
pickle.dump(tensor, f)
|
|
367
|
+
except Exception as e:
|
|
368
|
+
logger.error("Save pt file failed, please check according possible error causes: "
|
|
369
|
+
"1. out of disk space or disk error, "
|
|
370
|
+
"2. no permission to write files, etc.")
|
|
371
|
+
raise RuntimeError(f"save pt file {filepath} failed") from e
|
|
372
|
+
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def load_pkl(pt_path):
|
|
376
|
+
"""Load ApiData or str objection by pickle for accuracy_checker_online"""
|
|
377
|
+
check_file_or_directory_path(pt_path)
|
|
378
|
+
pt_path = os.path.realpath(pt_path)
|
|
379
|
+
try:
|
|
380
|
+
with FileOpen(pt_path, 'rb') as f:
|
|
381
|
+
pt = TypeCheckingUnpickler(f).load()
|
|
382
|
+
except Exception as e:
|
|
383
|
+
raise RuntimeError(f"load pt file {pt_path} failed: {e}") from e
|
|
384
|
+
return pt
|
|
385
|
+
|
|
386
|
+
|
|
293
387
|
def save_api_data(api_data):
|
|
294
388
|
"""Save data to io stream"""
|
|
295
389
|
try:
|
|
@@ -14,53 +14,40 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
-
|
|
18
|
-
check_configuration_param, task_dumppath_get
|
|
19
|
-
from msprobe.core.common.file_utils import create_directory
|
|
17
|
+
|
|
20
18
|
from msprobe.core.common.exceptions import FileCheckException
|
|
19
|
+
from msprobe.core.common.file_utils import create_directory
|
|
20
|
+
from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
|
|
21
|
+
set_dump_path
|
|
22
|
+
from msprobe.core.compare.acc_compare import ModeConfig
|
|
23
|
+
from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json, set_stack_json_path
|
|
21
24
|
from msprobe.pytorch.common.log import logger
|
|
22
|
-
from msprobe.pytorch.compare.pt_compare import PTComparator
|
|
23
|
-
from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
|
|
25
|
+
from msprobe.pytorch.compare.pt_compare import PTComparator, compare
|
|
24
26
|
|
|
25
27
|
|
|
26
28
|
def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
|
|
27
|
-
if kwargs.get(
|
|
29
|
+
if kwargs.get("suffix"):
|
|
28
30
|
logger.error("Argument 'suffix' is not supported for compare_distributed.")
|
|
29
31
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
30
|
-
|
|
31
|
-
auto_analyze = kwargs.get('auto_analyze', True)
|
|
32
|
-
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
32
|
+
is_print_compare_log = kwargs.get("is_print_compare_log", True)
|
|
33
33
|
# get the ranks and match by order
|
|
34
34
|
npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
|
|
35
35
|
bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
|
|
36
36
|
if len(npu_ranks) != len(bench_ranks):
|
|
37
|
-
logger.error(
|
|
38
|
-
|
|
39
|
-
|
|
37
|
+
logger.error(
|
|
38
|
+
"The number of ranks in the two runs are different. "
|
|
39
|
+
"Unable to match the ranks. "
|
|
40
|
+
"Please use another folder to compare or use compare() api and manually match the ranks.")
|
|
40
41
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
41
42
|
for nr, br in zip(npu_ranks, bench_ranks):
|
|
42
43
|
npu_data_dir = os.path.join(npu_dump_dir, nr)
|
|
43
44
|
bench_data_dir = os.path.join(bench_dump_dir, br)
|
|
44
45
|
npu_path = extract_json(npu_data_dir, stack_json=False)
|
|
45
46
|
bench_path = extract_json(bench_data_dir, stack_json=False)
|
|
46
|
-
stack_path = extract_json(npu_data_dir, stack_json=True)
|
|
47
47
|
|
|
48
48
|
dump_result_param = {
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
'is_print_compare_log': True
|
|
49
|
+
"npu_json_path": npu_path,
|
|
50
|
+
"bench_json_path": bench_path,
|
|
51
|
+
"is_print_compare_log": is_print_compare_log
|
|
53
52
|
}
|
|
54
|
-
|
|
55
|
-
summary_compare, md5_compare = task_dumppath_get(dump_result_param)
|
|
56
|
-
check_configuration_param(stack_mode, auto_analyze, fuzzy_match,
|
|
57
|
-
dump_result_param.get('is_print_compare_log', True))
|
|
58
|
-
create_directory(output_path)
|
|
59
|
-
check_compare_param(dump_result_param, output_path,
|
|
60
|
-
summary_compare=summary_compare, md5_compare=md5_compare)
|
|
61
|
-
except (CompareException, FileCheckException) as error:
|
|
62
|
-
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
63
|
-
raise CompareException(error.code) from error
|
|
64
|
-
pt_comparator = PTComparator()
|
|
65
|
-
pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}',
|
|
66
|
-
summary_compare=summary_compare, md5_compare=md5_compare, **kwargs)
|
|
53
|
+
compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs)
|
|
@@ -14,19 +14,29 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os.path
|
|
17
|
+
|
|
17
18
|
import torch
|
|
19
|
+
|
|
18
20
|
from msprobe.core.common.const import FileCheckConst
|
|
19
|
-
from msprobe.pytorch.common.log import logger
|
|
20
21
|
from msprobe.core.common.exceptions import FileCheckException
|
|
21
|
-
from msprobe.core.compare.acc_compare import Comparator
|
|
22
|
-
from msprobe.core.common.utils import check_configuration_param, task_dumppath_get, check_compare_param, \
|
|
23
|
-
CompareException
|
|
24
22
|
from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
|
|
23
|
+
from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
|
|
24
|
+
set_dump_path
|
|
25
|
+
from msprobe.core.compare.acc_compare import Comparator, ModeConfig
|
|
26
|
+
from msprobe.core.compare.utils import set_stack_json_path
|
|
27
|
+
from msprobe.pytorch.common.log import logger
|
|
25
28
|
from msprobe.pytorch.common.utils import load_pt
|
|
26
29
|
|
|
27
30
|
|
|
28
|
-
class PTComparator
|
|
29
|
-
def __init__(self, data_mapping=None):
|
|
31
|
+
class PTComparator(Comparator):
|
|
32
|
+
def __init__(self, mode_config, data_mapping=None):
|
|
33
|
+
super().__init__(mode_config)
|
|
34
|
+
|
|
35
|
+
self.stack_mode = mode_config.stack_mode
|
|
36
|
+
self.auto_analyze = mode_config.auto_analyze
|
|
37
|
+
self.fuzzy_match = mode_config.fuzzy_match
|
|
38
|
+
self.dump_mode = mode_config.dump_mode
|
|
39
|
+
|
|
30
40
|
self.frame_name = PTComparator.__name__
|
|
31
41
|
self.data_mapping = data_mapping
|
|
32
42
|
if isinstance(self.data_mapping, str) or self.data_mapping is None:
|
|
@@ -37,21 +47,24 @@ class PTComparator (Comparator):
|
|
|
37
47
|
raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
|
|
38
48
|
f"{type(self.data_mapping)}")
|
|
39
49
|
|
|
40
|
-
|
|
50
|
+
@staticmethod
|
|
51
|
+
def load_mapping_file(mapping_file):
|
|
41
52
|
if isinstance(mapping_file, str):
|
|
42
53
|
mapping_dict = load_yaml(mapping_file)
|
|
43
54
|
else:
|
|
44
55
|
mapping_dict = {}
|
|
45
56
|
return mapping_dict
|
|
46
|
-
|
|
57
|
+
|
|
47
58
|
def read_npy_data(self, dir_path, file_name):
|
|
59
|
+
if not file_name:
|
|
60
|
+
return None
|
|
48
61
|
data_path = os.path.join(dir_path, file_name)
|
|
49
62
|
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
50
|
-
|
|
63
|
+
FileCheckConst.PT_SUFFIX, False)
|
|
51
64
|
data_path = path_checker.common_check()
|
|
52
65
|
try:
|
|
53
|
-
|
|
54
|
-
|
|
66
|
+
# detach because numpy can not process gradient information
|
|
67
|
+
data_value = load_pt(data_path, to_cpu=True).detach()
|
|
55
68
|
except RuntimeError as e:
|
|
56
69
|
# 这里捕获 load_pt 中抛出的异常
|
|
57
70
|
logger.error(f"Failed to load the .pt file at {data_path}.")
|
|
@@ -63,20 +76,29 @@ class PTComparator (Comparator):
|
|
|
63
76
|
if data_value.dtype == torch.bfloat16:
|
|
64
77
|
data_value = data_value.to(torch.float32)
|
|
65
78
|
data_value = data_value.numpy()
|
|
66
|
-
return data_value
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
def compare(input_param, output_path,
|
|
79
|
+
return data_value
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def compare(input_param, output_path, **kwargs):
|
|
70
83
|
try:
|
|
71
|
-
|
|
84
|
+
auto_analyze = kwargs.get('auto_analyze', True)
|
|
85
|
+
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
86
|
+
data_mapping = kwargs.get('data_mapping', None)
|
|
87
|
+
suffix = kwargs.get('suffix', '')
|
|
88
|
+
|
|
89
|
+
set_dump_path(input_param)
|
|
90
|
+
dump_mode = get_dump_mode(input_param)
|
|
91
|
+
if "stack_json_path" in input_param:
|
|
92
|
+
stack_mode = kwargs.get('stack_mode', False)
|
|
93
|
+
else:
|
|
94
|
+
stack_mode = set_stack_json_path(input_param) # set stack_mode and set "stack_json_path" in input_param
|
|
72
95
|
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
|
|
73
96
|
create_directory(output_path)
|
|
74
|
-
check_compare_param(input_param, output_path,
|
|
75
|
-
data_mapping = kwargs.get('data_mapping', None)
|
|
97
|
+
check_compare_param(input_param, output_path, dump_mode, stack_mode)
|
|
76
98
|
except (CompareException, FileCheckException) as error:
|
|
77
99
|
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
78
100
|
raise CompareException(error.code) from error
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
101
|
+
|
|
102
|
+
mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode)
|
|
103
|
+
pt_comparator = PTComparator(mode_config, data_mapping)
|
|
104
|
+
pt_comparator.compare_core(input_param, output_path, suffix=suffix)
|
|
@@ -31,13 +31,14 @@ class DebuggerConfig:
|
|
|
31
31
|
self.scope = task_config.scope if task_config.scope else []
|
|
32
32
|
self.list = task_config.list if task_config.list else []
|
|
33
33
|
self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
|
|
34
|
-
self.backward_input_list = task_config.backward_input if task_config.backward_input else []
|
|
35
|
-
self.backward_input = {}
|
|
36
|
-
self.acl_config = common_config.acl_config if common_config.acl_config else ""
|
|
37
|
-
self.is_forward_acl_dump = True
|
|
38
34
|
self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
|
|
39
35
|
self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
|
|
40
36
|
self.framework = Const.PT_FRAMEWORK
|
|
37
|
+
self.async_dump = common_config.async_dump if common_config.async_dump else False
|
|
38
|
+
|
|
39
|
+
if self.level == Const.LEVEL_L2:
|
|
40
|
+
self.is_backward_kernel_dump = False
|
|
41
|
+
self._check_and_adjust_config_with_l2()
|
|
41
42
|
|
|
42
43
|
if self.task == Const.FREE_BENCHMARK:
|
|
43
44
|
self.fuzz_device = task_config.fuzz_device
|
|
@@ -59,20 +60,11 @@ class DebuggerConfig:
|
|
|
59
60
|
self.tls_path = task_config.tls_path if task_config.tls_path else ""
|
|
60
61
|
self.host = task_config.host if task_config.host else ""
|
|
61
62
|
self.port = task_config.port if task_config.port else -1
|
|
63
|
+
self.online_run_ut_recompute = task_config.online_run_ut_recompute \
|
|
64
|
+
if isinstance(task_config.online_run_ut_recompute, bool) else False
|
|
62
65
|
|
|
63
66
|
self.check()
|
|
64
67
|
|
|
65
|
-
if self.level == "L2":
|
|
66
|
-
if not self.scope or not isinstance(self.scope, list) or len(self.scope) != 1:
|
|
67
|
-
raise ValueError("scope must be configured as a list with one api name")
|
|
68
|
-
if isinstance(self.scope[0], str) and Const.BACKWARD in self.scope[0] and not self.backward_input_list:
|
|
69
|
-
raise ValueError("backward_input must be configured when scope contains 'backward'")
|
|
70
|
-
if Const.BACKWARD in self.scope[0]:
|
|
71
|
-
self.is_forward_acl_dump = False
|
|
72
|
-
for index, scope_spec in enumerate(self.scope):
|
|
73
|
-
self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD)
|
|
74
|
-
self.backward_input[self.scope[index]] = self.backward_input_list[index]
|
|
75
|
-
|
|
76
68
|
def check_kwargs(self):
|
|
77
69
|
if self.task and self.task not in Const.TASK_LIST:
|
|
78
70
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
@@ -83,26 +75,53 @@ class DebuggerConfig:
|
|
|
83
75
|
if not self.dump_path:
|
|
84
76
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
85
77
|
f"The dump_path not found.")
|
|
78
|
+
if not isinstance(self.async_dump, bool):
|
|
79
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
80
|
+
f"The parameters async_dump should be bool.")
|
|
86
81
|
|
|
87
82
|
def check(self):
|
|
88
83
|
self.check_kwargs()
|
|
89
84
|
return True
|
|
90
85
|
|
|
91
86
|
def check_model(self, instance, start_model):
|
|
92
|
-
if self.level not in [
|
|
87
|
+
if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
|
|
93
88
|
if instance.model is not None or start_model is not None:
|
|
94
|
-
logger.
|
|
89
|
+
logger.info_on_rank_0(
|
|
95
90
|
f"The current level is not L0 or mix level, so the model parameters will not be used.")
|
|
96
91
|
return
|
|
97
|
-
if start_model is None:
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
if isinstance(
|
|
104
|
-
|
|
92
|
+
if start_model is None and instance.model is None:
|
|
93
|
+
logger.error_on_rank_0(
|
|
94
|
+
f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' parameter.")
|
|
95
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
|
|
96
|
+
|
|
97
|
+
instance.model = start_model if start_model is not None else instance.model
|
|
98
|
+
if isinstance(instance.model, torch.nn.Module):
|
|
99
|
+
return
|
|
100
|
+
|
|
101
|
+
error_model = None
|
|
102
|
+
if isinstance(instance.model, (list, tuple)):
|
|
103
|
+
for model in instance.model:
|
|
104
|
+
if not isinstance(model, torch.nn.Module):
|
|
105
|
+
error_model = model
|
|
106
|
+
break
|
|
105
107
|
else:
|
|
106
|
-
|
|
108
|
+
error_model = instance.model
|
|
109
|
+
|
|
110
|
+
if error_model is not None:
|
|
111
|
+
error_info = (f"The 'model' parameter must be a torch.nn.Moudle or list[torch.nn.Moudle] "
|
|
112
|
+
f"type, currently there is a {type(error_model)} type.")
|
|
107
113
|
raise MsprobeException(
|
|
108
|
-
MsprobeException.INVALID_PARAM_ERROR,
|
|
114
|
+
MsprobeException.INVALID_PARAM_ERROR, error_info)
|
|
115
|
+
|
|
116
|
+
def _check_and_adjust_config_with_l2(self):
|
|
117
|
+
if self.scope:
|
|
118
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
119
|
+
f"When level is set to L2, the scope cannot be configured.")
|
|
120
|
+
if not self.list or len(self.list) != 1:
|
|
121
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
122
|
+
f"When level is set to L2, the list must be configured as a list with one api name.")
|
|
123
|
+
api_name = self.list[0]
|
|
124
|
+
if api_name.endswith(Const.BACKWARD):
|
|
125
|
+
self.is_backward_kernel_dump = True
|
|
126
|
+
api_forward_name = api_name[:-len(Const.BACKWARD)] + Const.FORWARD
|
|
127
|
+
self.list.append(api_forward_name)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -22,6 +22,7 @@ from msprobe.core.common.file_utils import FileChecker
|
|
|
22
22
|
from msprobe.core.common.utils import get_real_step_or_rank
|
|
23
23
|
from msprobe.pytorch.common.log import logger
|
|
24
24
|
from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
|
|
25
|
+
from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper
|
|
25
26
|
from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
|
|
26
27
|
from msprobe.pytorch.pt_config import parse_json_config
|
|
27
28
|
from msprobe.pytorch.service import Service
|
|
@@ -49,7 +50,7 @@ class PrecisionDebugger:
|
|
|
49
50
|
dump_path=None,
|
|
50
51
|
level=None,
|
|
51
52
|
model=None,
|
|
52
|
-
step=None
|
|
53
|
+
step=None
|
|
53
54
|
):
|
|
54
55
|
if not hasattr(self, "initialized"):
|
|
55
56
|
config_params = ConfigParameters(config_path,
|
|
@@ -59,7 +60,6 @@ class PrecisionDebugger:
|
|
|
59
60
|
model)
|
|
60
61
|
self.check_input_params(config_params)
|
|
61
62
|
|
|
62
|
-
self.api_origin = False
|
|
63
63
|
self.initialized = True
|
|
64
64
|
self.model = model
|
|
65
65
|
common_config, task_config = parse_json_config(config_path, task)
|
|
@@ -67,12 +67,13 @@ class PrecisionDebugger:
|
|
|
67
67
|
if self.task == Const.GRAD_PROBE:
|
|
68
68
|
self.gm = GradientMonitor(common_config, task_config)
|
|
69
69
|
return
|
|
70
|
-
if step:
|
|
70
|
+
if step is not None:
|
|
71
71
|
common_config.step = get_real_step_or_rank(step, Const.STEP)
|
|
72
72
|
self.config = DebuggerConfig(
|
|
73
73
|
common_config, task_config, task, dump_path, level
|
|
74
74
|
)
|
|
75
75
|
self.service = Service(self.config)
|
|
76
|
+
self.module_dumper = ModuleDumper(self.service)
|
|
76
77
|
self.enable_dataloader = self.config.enable_dataloader
|
|
77
78
|
if self.enable_dataloader:
|
|
78
79
|
logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
|
|
@@ -105,9 +106,11 @@ class PrecisionDebugger:
|
|
|
105
106
|
raise MsprobeException(
|
|
106
107
|
MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
|
|
107
108
|
|
|
108
|
-
if args.model is not None
|
|
109
|
-
|
|
110
|
-
|
|
109
|
+
if args.model is not None:
|
|
110
|
+
logger.warning_on_rank_0(
|
|
111
|
+
"The 'model' parameter in the PrecisionDebugger will be deprecated in the future."
|
|
112
|
+
"It is recommended to pass the 'model' parameter in the start interface instead."
|
|
113
|
+
)
|
|
111
114
|
|
|
112
115
|
@classmethod
|
|
113
116
|
def start(cls, model=None):
|
|
@@ -120,15 +123,12 @@ class PrecisionDebugger:
|
|
|
120
123
|
if instance.enable_dataloader:
|
|
121
124
|
logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
|
|
122
125
|
else:
|
|
123
|
-
instance.service.start(instance.model
|
|
124
|
-
instance.api_origin = False
|
|
126
|
+
instance.service.start(instance.model)
|
|
125
127
|
|
|
126
|
-
# 指定代码段dump前反向结束符,之后的计算过程数据将被忽略,无法被dump
|
|
127
128
|
@classmethod
|
|
128
129
|
def forward_backward_dump_end(cls):
|
|
129
130
|
instance = cls._instance
|
|
130
|
-
instance.
|
|
131
|
-
instance.api_origin = True
|
|
131
|
+
instance.stop()
|
|
132
132
|
|
|
133
133
|
@classmethod
|
|
134
134
|
def stop(cls):
|
|
@@ -159,6 +159,36 @@ class PrecisionDebugger:
|
|
|
159
159
|
cls._instance.gm.monitor(model)
|
|
160
160
|
|
|
161
161
|
|
|
162
|
+
def module_dump(module, dump_name):
|
|
163
|
+
if not isinstance(module, torch.nn.Module):
|
|
164
|
+
raise MsprobeException(
|
|
165
|
+
MsprobeException.INVALID_PARAM_ERROR,
|
|
166
|
+
f"the module argument in module_dump must be a torch.nn.Module subclass"
|
|
167
|
+
)
|
|
168
|
+
if not isinstance(dump_name, str):
|
|
169
|
+
raise MsprobeException(
|
|
170
|
+
MsprobeException.INVALID_PARAM_ERROR,
|
|
171
|
+
f"the dump_name argument in module_dump must be a str type"
|
|
172
|
+
)
|
|
173
|
+
instance = PrecisionDebugger._instance
|
|
174
|
+
if not instance:
|
|
175
|
+
raise MsprobeException(
|
|
176
|
+
MsprobeException.INTERFACE_USAGE_ERROR,
|
|
177
|
+
f"PrecisionDebugger must be instantiated before using module_dump interface"
|
|
178
|
+
)
|
|
179
|
+
instance.module_dumper.start_module_dump(module, dump_name)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def module_dump_end():
|
|
183
|
+
instance = PrecisionDebugger._instance
|
|
184
|
+
if not instance:
|
|
185
|
+
raise MsprobeException(
|
|
186
|
+
MsprobeException.INTERFACE_USAGE_ERROR,
|
|
187
|
+
f"PrecisionDebugger must be instantiated before using module_dump_end interface"
|
|
188
|
+
)
|
|
189
|
+
instance.module_dumper.stop_module_dump()
|
|
190
|
+
|
|
191
|
+
|
|
162
192
|
def iter_tracer(func):
|
|
163
193
|
def func_wrapper(*args, **kwargs):
|
|
164
194
|
debugger_instance = PrecisionDebugger.instance
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.file_utils import save_json
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def create_kernel_config_json(dump_path, cur_rank):
|
|
22
|
+
kernel_config_name = "kernel_config.json" if cur_rank == '' else f"kernel_config_{cur_rank}.json"
|
|
23
|
+
kernel_config_path = os.path.join(dump_path, kernel_config_name)
|
|
24
|
+
config_info = {
|
|
25
|
+
"dump": {
|
|
26
|
+
"dump_list": [],
|
|
27
|
+
"dump_path": dump_path,
|
|
28
|
+
"dump_mode": "all",
|
|
29
|
+
"dump_op_switch": "on"
|
|
30
|
+
}
|
|
31
|
+
}
|
|
32
|
+
save_json(kernel_config_path, config_info, indent=4)
|
|
33
|
+
return kernel_config_path
|