mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -13,36 +13,22 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
from
|
|
16
|
+
from torch.utils.data import dataloader
|
|
17
17
|
|
|
18
|
-
import
|
|
19
|
-
from msprobe.core.common.const import Const, FileCheckConst, MsgConst
|
|
18
|
+
from msprobe.core.common.const import Const, MsgConst
|
|
20
19
|
from msprobe.core.common.exceptions import MsprobeException
|
|
21
|
-
from msprobe.core.common.
|
|
22
|
-
from msprobe.core.
|
|
20
|
+
from msprobe.core.common.utils import check_token_range
|
|
21
|
+
from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger
|
|
23
22
|
from msprobe.pytorch.common.log import logger
|
|
24
|
-
from msprobe.pytorch.common.utils import check_save_param
|
|
23
|
+
from msprobe.pytorch.common.utils import check_save_param, is_torch_nn_module
|
|
25
24
|
from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
|
|
26
25
|
from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper
|
|
27
26
|
from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
|
|
28
|
-
from msprobe.pytorch.
|
|
29
|
-
from msprobe.pytorch.
|
|
30
|
-
from torch.utils.data import dataloader
|
|
31
|
-
|
|
32
|
-
ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task",
|
|
33
|
-
"dump_path", "level", "model"])
|
|
34
|
-
|
|
27
|
+
from msprobe.pytorch.pytorch_service import PytorchService
|
|
28
|
+
from msprobe.pytorch.pt_config import parse_task_config
|
|
35
29
|
|
|
36
|
-
class PrecisionDebugger:
|
|
37
|
-
_instance = None
|
|
38
|
-
tasks_not_need_debugger = [Const.GRAD_PROBE]
|
|
39
30
|
|
|
40
|
-
|
|
41
|
-
if cls._instance is None:
|
|
42
|
-
cls._instance = super(PrecisionDebugger, cls).__new__(cls)
|
|
43
|
-
cls._instance.config = None
|
|
44
|
-
cls._instance.enable_dataloader = False
|
|
45
|
-
return cls._instance
|
|
31
|
+
class PrecisionDebugger(BasePrecisionDebugger):
|
|
46
32
|
|
|
47
33
|
def __init__(
|
|
48
34
|
self,
|
|
@@ -53,90 +39,48 @@ class PrecisionDebugger:
|
|
|
53
39
|
model=None,
|
|
54
40
|
step=None
|
|
55
41
|
):
|
|
56
|
-
if
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
self.
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
if step is not None:
|
|
72
|
-
common_config.step = get_real_step_or_rank(step, Const.STEP)
|
|
73
|
-
self.config = DebuggerConfig(
|
|
74
|
-
common_config, task_config, task, dump_path, level
|
|
75
|
-
)
|
|
76
|
-
self.service = Service(self.config)
|
|
77
|
-
self.module_dumper = ModuleDumper(self.service)
|
|
78
|
-
self.enable_dataloader = self.config.enable_dataloader
|
|
79
|
-
if self.enable_dataloader:
|
|
80
|
-
logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
|
|
81
|
-
dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)
|
|
42
|
+
if self.initialized:
|
|
43
|
+
return
|
|
44
|
+
super().__init__(config_path, task, dump_path, level, step)
|
|
45
|
+
self.model = model
|
|
46
|
+
if self.task == Const.GRAD_PROBE:
|
|
47
|
+
self.gm = GradientMonitor(self.common_config, self.task_config)
|
|
48
|
+
return
|
|
49
|
+
self.config = DebuggerConfig(
|
|
50
|
+
self.common_config, self.task_config, task, dump_path, level
|
|
51
|
+
)
|
|
52
|
+
self.service = PytorchService(self.config)
|
|
53
|
+
self.module_dumper = ModuleDumper(self.service)
|
|
54
|
+
self.ori_customer_func = {}
|
|
55
|
+
self.enable_dataloader = self.config.enable_dataloader
|
|
56
|
+
self.param_warning()
|
|
82
57
|
|
|
83
58
|
@property
|
|
84
59
|
def instance(self):
|
|
85
60
|
return self._instance
|
|
86
61
|
|
|
87
62
|
@staticmethod
|
|
88
|
-
def
|
|
89
|
-
|
|
90
|
-
if not isinstance(args.config_path, str):
|
|
91
|
-
raise MsprobeException(
|
|
92
|
-
MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string")
|
|
93
|
-
file_checker = FileChecker(
|
|
94
|
-
file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
95
|
-
file_checker.common_check()
|
|
96
|
-
|
|
97
|
-
if args.task is not None and args.task not in Const.TASK_LIST:
|
|
98
|
-
raise MsprobeException(
|
|
99
|
-
MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")
|
|
100
|
-
|
|
101
|
-
if args.dump_path is not None:
|
|
102
|
-
if not isinstance(args.dump_path, str):
|
|
103
|
-
raise MsprobeException(
|
|
104
|
-
MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string")
|
|
105
|
-
|
|
106
|
-
if args.level is not None and args.level not in Const.LEVEL_LIST:
|
|
107
|
-
raise MsprobeException(
|
|
108
|
-
MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
|
|
109
|
-
|
|
110
|
-
if args.model is not None:
|
|
111
|
-
logger.warning_on_rank_0(
|
|
112
|
-
"The 'model' parameter in the PrecisionDebugger will be deprecated in the future."
|
|
113
|
-
"It is recommended to pass the 'model' parameter in the start interface instead."
|
|
114
|
-
)
|
|
63
|
+
def get_task_config(task, json_config):
|
|
64
|
+
return parse_task_config(task, json_config)
|
|
115
65
|
|
|
116
66
|
@classmethod
|
|
117
|
-
def start(cls, model=None):
|
|
118
|
-
instance = cls.
|
|
119
|
-
if
|
|
120
|
-
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
121
|
-
if instance.task in PrecisionDebugger.tasks_not_need_debugger:
|
|
67
|
+
def start(cls, model=None, token_range=None):
|
|
68
|
+
instance = cls.get_instance()
|
|
69
|
+
if instance is None:
|
|
122
70
|
return
|
|
123
|
-
|
|
71
|
+
|
|
72
|
+
check_token_range(token_range)
|
|
73
|
+
instance.config.check_model(instance, model, token_range)
|
|
74
|
+
|
|
124
75
|
if instance.enable_dataloader:
|
|
125
76
|
logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
|
|
126
77
|
else:
|
|
127
|
-
instance.service.start(instance.model)
|
|
128
|
-
|
|
129
|
-
@classmethod
|
|
130
|
-
def forward_backward_dump_end(cls):
|
|
131
|
-
instance = cls._instance
|
|
132
|
-
instance.stop()
|
|
78
|
+
instance.service.start(instance.model, token_range)
|
|
133
79
|
|
|
134
80
|
@classmethod
|
|
135
81
|
def stop(cls):
|
|
136
|
-
instance = cls.
|
|
137
|
-
if
|
|
138
|
-
raise Exception(MsgConst.NOT_CREATED_INSTANCE)
|
|
139
|
-
if instance.task in PrecisionDebugger.tasks_not_need_debugger:
|
|
82
|
+
instance = cls.get_instance()
|
|
83
|
+
if instance is None:
|
|
140
84
|
return
|
|
141
85
|
if instance.enable_dataloader:
|
|
142
86
|
logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
|
|
@@ -145,9 +89,8 @@ class PrecisionDebugger:
|
|
|
145
89
|
|
|
146
90
|
@classmethod
|
|
147
91
|
def step(cls):
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger:
|
|
92
|
+
instance = cls.get_instance()
|
|
93
|
+
if instance is None:
|
|
151
94
|
return
|
|
152
95
|
cls._instance.service.step()
|
|
153
96
|
|
|
@@ -172,12 +115,23 @@ class PrecisionDebugger:
|
|
|
172
115
|
return
|
|
173
116
|
instance.service.save(variable, name, save_backward)
|
|
174
117
|
|
|
118
|
+
def param_warning(self):
|
|
119
|
+
if self.model is not None:
|
|
120
|
+
logger.warning_on_rank_0(
|
|
121
|
+
"The 'model' parameter in the PrecisionDebugger will be deprecated in the future."
|
|
122
|
+
"It is recommended to pass the 'model' parameter in the start interface instead."
|
|
123
|
+
)
|
|
124
|
+
if self.enable_dataloader:
|
|
125
|
+
logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
|
|
126
|
+
dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)
|
|
127
|
+
|
|
175
128
|
|
|
176
129
|
def module_dump(module, dump_name):
|
|
177
|
-
if not
|
|
130
|
+
if not is_torch_nn_module(module):
|
|
178
131
|
raise MsprobeException(
|
|
179
132
|
MsprobeException.INVALID_PARAM_ERROR,
|
|
180
|
-
f"the module argument in module_dump must be a torch.nn.Module
|
|
133
|
+
f"the module argument in module_dump must be a torch.nn.Module type, "
|
|
134
|
+
f"but currently there is an unsupported {type(module)} type."
|
|
181
135
|
)
|
|
182
136
|
if not isinstance(dump_name, str):
|
|
183
137
|
raise MsprobeException(
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from functools import wraps
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch.utils.hooks import BackwardHook
|
|
20
|
+
|
|
21
|
+
from msprobe.core.common.const import Const
|
|
22
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
23
|
+
from msprobe.pytorch.common.log import logger
|
|
24
|
+
from msprobe.pytorch.common.utils import is_float8_tensor
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def wrap_setup_backward_hook(func):
|
|
28
|
+
def requires_clone(tensor):
|
|
29
|
+
return isinstance(tensor, torch.Tensor) and not is_float8_tensor(tensor) and \
|
|
30
|
+
tensor.requires_grad and torch.is_grad_enabled()
|
|
31
|
+
|
|
32
|
+
@recursion_depth_decorator("Dump: wrap_setup_backward_hook.parse_tensor", max_depth=Const.DUMP_MAX_DEPTH)
|
|
33
|
+
def parse_tensor(item, tensor_list):
|
|
34
|
+
if requires_clone(item):
|
|
35
|
+
tensor_list.append(item)
|
|
36
|
+
elif isinstance(item, (list, tuple)):
|
|
37
|
+
for value in item:
|
|
38
|
+
parse_tensor(value, tensor_list)
|
|
39
|
+
elif isinstance(item, dict):
|
|
40
|
+
for value in item.values():
|
|
41
|
+
parse_tensor(value, tensor_list)
|
|
42
|
+
|
|
43
|
+
@recursion_depth_decorator("Dump: wrap_setup_backward_hook.rebuild_args", max_depth=Const.DUMP_MAX_DEPTH)
|
|
44
|
+
def rebuild_args(item, tensor_iter):
|
|
45
|
+
if requires_clone(item):
|
|
46
|
+
result = next(tensor_iter)
|
|
47
|
+
if hasattr(result, "_base") and result._base is not None:
|
|
48
|
+
if torch._C._autograd._get_creation_meta(result) != torch._C._autograd.CreationMeta(0):
|
|
49
|
+
torch._C._autograd._set_creation_meta(result, torch._C._autograd.CreationMeta(0))
|
|
50
|
+
return result
|
|
51
|
+
if isinstance(item, list):
|
|
52
|
+
for index, value in enumerate(item):
|
|
53
|
+
item[index] = rebuild_args(value, tensor_iter)
|
|
54
|
+
return item
|
|
55
|
+
if isinstance(item, dict):
|
|
56
|
+
for key, value in item.items():
|
|
57
|
+
item[key] = rebuild_args(value, tensor_iter)
|
|
58
|
+
return item
|
|
59
|
+
if isinstance(item, tuple):
|
|
60
|
+
if hasattr(item, '_fields'):
|
|
61
|
+
return type(item)(*[rebuild_args(i, tensor_iter) for i in item])
|
|
62
|
+
return type(item)([rebuild_args(i, tensor_iter) for i in item])
|
|
63
|
+
return item
|
|
64
|
+
|
|
65
|
+
@wraps(func)
|
|
66
|
+
def wrap_setup_hook_func(*args, **kwargs):
|
|
67
|
+
if len(args) < 2:
|
|
68
|
+
return func(*args, **kwargs)
|
|
69
|
+
|
|
70
|
+
actual_args = args[1]
|
|
71
|
+
|
|
72
|
+
tensor_list = []
|
|
73
|
+
|
|
74
|
+
parse_tensor(actual_args, tensor_list)
|
|
75
|
+
|
|
76
|
+
new_args = args[0], tuple(tensor_list)
|
|
77
|
+
hooked_tensors = func(*new_args, **kwargs)
|
|
78
|
+
|
|
79
|
+
tensor_iter = iter(hooked_tensors)
|
|
80
|
+
try:
|
|
81
|
+
new_data = rebuild_args(actual_args, tensor_iter)
|
|
82
|
+
except Exception as e:
|
|
83
|
+
logger.debug(f"Unsupported data in setup input/output hook. The detail info: {e}")
|
|
84
|
+
new_data = actual_args
|
|
85
|
+
|
|
86
|
+
return new_data
|
|
87
|
+
|
|
88
|
+
return wrap_setup_hook_func
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def wrap_setup_input_output_hook():
|
|
92
|
+
BackwardHook.setup_input_hook = wrap_setup_backward_hook(BackwardHook.setup_input_hook)
|
|
93
|
+
BackwardHook.setup_output_hook = wrap_setup_backward_hook(BackwardHook.setup_output_hook)
|
|
@@ -13,74 +13,28 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import torch
|
|
17
|
-
from msprobe.core.common.const import Const
|
|
18
|
-
from msprobe.core.data_dump.scope import BaseScope
|
|
19
16
|
from msprobe.pytorch.common.log import logger
|
|
20
|
-
from msprobe.pytorch.
|
|
21
|
-
|
|
22
|
-
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
17
|
+
from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
|
|
18
|
+
from msprobe.pytorch.hook_module.api_register import get_api_register
|
|
23
19
|
|
|
24
20
|
|
|
25
21
|
class ModuleDumper:
|
|
26
22
|
def __init__(self, service):
|
|
27
23
|
self.service = service
|
|
28
|
-
self.
|
|
24
|
+
self.api_register = get_api_register()
|
|
29
25
|
|
|
30
26
|
def start_module_dump(self, module, dump_name):
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
def stop_module_dump(self):
|
|
35
|
-
api_register.api_modularity()
|
|
36
|
-
for hook_handle in self.hook_handle_list:
|
|
37
|
-
if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
|
|
38
|
-
hook_handle.remove()
|
|
39
|
-
self.hook_handle_list.clear()
|
|
27
|
+
if hasattr(module, 'msprobe_hook') and not hasattr(module, 'msprobe_module_dump'):
|
|
28
|
+
logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.")
|
|
29
|
+
return
|
|
40
30
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
module_processor = self.service.module_processor
|
|
48
|
-
_, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.service.build_hook(
|
|
49
|
-
BaseScope.Module_Type_Module,
|
|
50
|
-
prefix_name
|
|
51
|
-
)
|
|
31
|
+
ModuleProcesser.enable_module_dump = True
|
|
32
|
+
self.api_register.restore_all_api()
|
|
33
|
+
if not hasattr(module, 'msprobe_module_dump'):
|
|
34
|
+
self.service.module_processor.register_module_hook(module, self.service.build_hook,
|
|
35
|
+
recursive=False, module_names=[dump_name])
|
|
36
|
+
setattr(module, 'msprobe_module_dump', True)
|
|
52
37
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
|
|
57
|
-
)
|
|
58
|
-
if torch_version_above_or_equal_2:
|
|
59
|
-
forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True)
|
|
60
|
-
else:
|
|
61
|
-
if not module_processor.has_register_backward_hook(module):
|
|
62
|
-
backward_hook_handle = module.register_full_backward_hook(
|
|
63
|
-
module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
|
|
64
|
-
)
|
|
65
|
-
self.hook_handle_list.append(backward_hook_handle)
|
|
66
|
-
forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2)
|
|
67
|
-
self.hook_handle_list.append(forward_hook_handle)
|
|
68
|
-
if not module_processor.has_register_backward_hook(module):
|
|
69
|
-
backward_hook_handle = module.register_full_backward_hook(backward_hook)
|
|
70
|
-
self.hook_handle_list.append(backward_hook_handle)
|
|
71
|
-
|
|
72
|
-
forward_pre_hook_handle = module.register_forward_pre_hook(
|
|
73
|
-
module_processor.node_hook(prefix_name + Const.FORWARD, Const.START)
|
|
74
|
-
)
|
|
75
|
-
forward_hook_handle = module.register_forward_hook(
|
|
76
|
-
module_processor.node_hook(prefix_name + Const.FORWARD, Const.STOP)
|
|
77
|
-
)
|
|
78
|
-
self.hook_handle_list.extend([forward_pre_hook_handle, forward_hook_handle])
|
|
79
|
-
if torch_version_above_or_equal_2 and not module_processor.has_register_backward_hook(module):
|
|
80
|
-
backward_pre_hook_handle = module.register_full_backward_pre_hook(
|
|
81
|
-
module_processor.node_hook(prefix_name + Const.BACKWARD, Const.START)
|
|
82
|
-
)
|
|
83
|
-
backward_hook_handle = module.register_full_backward_hook(
|
|
84
|
-
module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP)
|
|
85
|
-
)
|
|
86
|
-
self.hook_handle_list.extend([backward_pre_hook_handle, backward_hook_handle])
|
|
38
|
+
def stop_module_dump(self):
|
|
39
|
+
ModuleProcesser.enable_module_dump = False
|
|
40
|
+
self.api_register.register_all_api()
|