mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -0,0 +1,1076 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
import json
|
|
16
|
+
import os
|
|
17
|
+
import uuid
|
|
18
|
+
from collections import defaultdict
|
|
19
|
+
from datetime import datetime
|
|
20
|
+
from functools import partial
|
|
21
|
+
|
|
22
|
+
import pytz
|
|
23
|
+
import torch
|
|
24
|
+
import torch.distributed as dist
|
|
25
|
+
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
|
|
26
|
+
from torch.utils.hooks import BackwardHook
|
|
27
|
+
|
|
28
|
+
from msprobe.core.common.const import MonitorConst
|
|
29
|
+
from msprobe.core.common.file_utils import load_json, save_json
|
|
30
|
+
from msprobe.pytorch.common.log import logger
|
|
31
|
+
from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter
|
|
32
|
+
from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
|
|
33
|
+
CSVWriterWithAD, BaseWriterWithAD, WriterInput
|
|
34
|
+
from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
|
|
35
|
+
get_process_group
|
|
36
|
+
from msprobe.pytorch.monitor.features import get_sign_matches
|
|
37
|
+
from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
|
|
38
|
+
TensorMetrics, squash_param_name
|
|
39
|
+
from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
|
|
40
|
+
from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory, OptimizerMon
|
|
41
|
+
from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, is_recomputation, \
|
|
42
|
+
get_output_base_dir, get_target_output_dir
|
|
43
|
+
from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
|
|
44
|
+
|
|
45
|
+
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
46
|
+
if not torch_version_above_or_equal_2:
|
|
47
|
+
raise ValueError("monitor require torch>=2.0")
|
|
48
|
+
|
|
49
|
+
FORMAT_MAPPING = {
|
|
50
|
+
MonitorConst.TENSORBOARD: SummaryWriterWithAD,
|
|
51
|
+
MonitorConst.CSV: CSVWriterWithAD,
|
|
52
|
+
MonitorConst.API: BaseWriterWithAD
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def param_is_not_tensor_parallel_duplicate(param, tp_group):
|
|
57
|
+
return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
|
|
58
|
+
torch.distributed.get_rank(group=tp_group) == 0
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def param_is_data_parallel_duplicate(dp_group):
|
|
63
|
+
return torch.distributed.get_rank(group=dp_group) != 0
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class ModuleHookContext:
|
|
67
|
+
def __init__(self, module_name) -> None:
|
|
68
|
+
self.micro_step = 0
|
|
69
|
+
self.actv = defaultdict(dict)
|
|
70
|
+
self.actvgrad = []
|
|
71
|
+
self.module_name = module_name
|
|
72
|
+
self.struct = {}
|
|
73
|
+
self.format_by_arg = {}
|
|
74
|
+
self.verified = False
|
|
75
|
+
self.focused_in_col = 0
|
|
76
|
+
self.focused_out_col = 0
|
|
77
|
+
|
|
78
|
+
def set_format_by_arg(self, key_name: str, target_config: dict):
|
|
79
|
+
""" 按照监控对象配置format_by_arg
|
|
80
|
+
1) module_name 在 target 中配置监控对象
|
|
81
|
+
2) module_name 未在 targets 中配置,且 all_xy 全量监控
|
|
82
|
+
3) module_name 未在 targets 中配置,且 all_xy 未全量监控
|
|
83
|
+
|
|
84
|
+
:param key_name: str, one of [input, output, input_grad, output_grad]
|
|
85
|
+
:param target_config: target obj in config json.
|
|
86
|
+
:return:
|
|
87
|
+
"""
|
|
88
|
+
valid_key = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT, MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
|
|
89
|
+
if key_name not in valid_key:
|
|
90
|
+
raise ValueError(f"key({key_name}) error, valid_key: {valid_key}")
|
|
91
|
+
cared = target_config.get(self.module_name, self.struct)
|
|
92
|
+
if key_name in cared:
|
|
93
|
+
target_module_config = cared[key_name]
|
|
94
|
+
if isinstance(target_module_config, dict):
|
|
95
|
+
# current cared is self.struct, monitor all data for module_name
|
|
96
|
+
self.format_by_arg[key_name] = target_module_config.get('config')
|
|
97
|
+
elif isinstance(target_module_config, str):
|
|
98
|
+
# current cared is target_config[self.module_name]
|
|
99
|
+
self.format_by_arg[key_name] = target_module_config
|
|
100
|
+
else:
|
|
101
|
+
logger.warning_on_rank_0(f"target module config error, result maybe empty."
|
|
102
|
+
f"module_name: {self.module_name}, key_name: {key_name}")
|
|
103
|
+
self.format_by_arg[key_name] = None
|
|
104
|
+
else:
|
|
105
|
+
self.format_by_arg[key_name] = self.struct.get(key_name).get('config')
|
|
106
|
+
|
|
107
|
+
def reset(self):
|
|
108
|
+
self.actv.clear()
|
|
109
|
+
self.actvgrad.clear()
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
start_step = 0
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class OptimizerContext:
|
|
116
|
+
def __init__(self) -> None:
|
|
117
|
+
self.step = start_step
|
|
118
|
+
self.param_mg_direction = defaultdict(float)
|
|
119
|
+
self.param_adam_update = defaultdict()
|
|
120
|
+
self.param_adam_ratio = defaultdict()
|
|
121
|
+
self.param_weight_grad = defaultdict()
|
|
122
|
+
self.param_exp_avg = defaultdict()
|
|
123
|
+
self.exp_avg_metric = {}
|
|
124
|
+
self.param_exp_avg_sq = defaultdict()
|
|
125
|
+
self.exp_avg_sq_metric = {}
|
|
126
|
+
self.metric_dict = {}
|
|
127
|
+
self.param_metric = {}
|
|
128
|
+
|
|
129
|
+
def reset(self):
|
|
130
|
+
self.param_mg_direction.clear()
|
|
131
|
+
self.param_adam_update.clear()
|
|
132
|
+
self.param_adam_ratio.clear()
|
|
133
|
+
self.param_weight_grad.clear()
|
|
134
|
+
self.param_exp_avg.clear()
|
|
135
|
+
self.exp_avg_metric.clear()
|
|
136
|
+
self.param_exp_avg_sq.clear()
|
|
137
|
+
self.exp_avg_sq_metric.clear()
|
|
138
|
+
self.metric_dict.clear()
|
|
139
|
+
self.param_metric.clear()
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class CommunicationContext:
|
|
143
|
+
def __init__(self) -> None:
|
|
144
|
+
self.data = {}
|
|
145
|
+
|
|
146
|
+
@staticmethod
|
|
147
|
+
def _agg(data):
|
|
148
|
+
aggregated_data = {}
|
|
149
|
+
for tag, op2tensorlist in data.items():
|
|
150
|
+
aggregated_data[tag] = {}
|
|
151
|
+
for op, tensorlist in op2tensorlist.items():
|
|
152
|
+
aggregated_data[tag][op] = op_aggregate(op, tensorlist)
|
|
153
|
+
return aggregated_data
|
|
154
|
+
|
|
155
|
+
def reset(self):
|
|
156
|
+
self.data = {}
|
|
157
|
+
|
|
158
|
+
def aggregate(self):
|
|
159
|
+
self.data = self._agg(self.data)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class GradContext:
|
|
163
|
+
def __init__(self) -> None:
|
|
164
|
+
self.pre = {}
|
|
165
|
+
self.post = {}
|
|
166
|
+
self.acc_metric = {}
|
|
167
|
+
self.acc = {}
|
|
168
|
+
self.actv = {}
|
|
169
|
+
|
|
170
|
+
def reset(self):
|
|
171
|
+
self.pre.clear()
|
|
172
|
+
self.post.clear()
|
|
173
|
+
self.acc_metric.clear()
|
|
174
|
+
self.acc.clear()
|
|
175
|
+
self.actv.clear()
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class TrainerMon:
|
|
179
|
+
tensor_metrics = TensorMetrics()
|
|
180
|
+
|
|
181
|
+
def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None:
|
|
182
|
+
"""
|
|
183
|
+
opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer"
|
|
184
|
+
"""
|
|
185
|
+
# TYPE1: 只在这里初始化的变量, 不会随着训练中途config配置改变而重置
|
|
186
|
+
self.config_file_path = config_file_path
|
|
187
|
+
self.process_group = get_process_group(process_group)
|
|
188
|
+
self.params_have_main_grad = params_have_main_grad
|
|
189
|
+
self.opt_ty = opt_ty
|
|
190
|
+
self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
|
|
191
|
+
self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
|
|
192
|
+
self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
|
|
193
|
+
self.origin_step_func = None
|
|
194
|
+
self.config_timestamp = 0 # 后面有校验时间戳, 首次监控无需为了更新config文件时间戳而去改, 可通过switch开关直接打开
|
|
195
|
+
self.config = load_json(config_file_path)
|
|
196
|
+
validate_config(self.config)
|
|
197
|
+
|
|
198
|
+
self.squash_name = self.config.get('squash_name', True) # 不允许修改防止前后名字对不上
|
|
199
|
+
local_tz = pytz.timezone("Asia/Shanghai") # 根据需要调整为目标时区
|
|
200
|
+
cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
|
|
201
|
+
self.unique_id = str(uuid.uuid4())[:8]
|
|
202
|
+
self.output_base_dir = get_output_base_dir()
|
|
203
|
+
time_tags = self.config.get("append_output", [])
|
|
204
|
+
if dist.is_initialized():
|
|
205
|
+
self.rank = dist.get_rank()
|
|
206
|
+
if time_tags:
|
|
207
|
+
output_append_dirs = get_target_output_dir(self.output_base_dir, time_tags[0], time_tags[1])
|
|
208
|
+
if str(self.rank) in output_append_dirs:
|
|
209
|
+
self.tensorboard_dir = output_append_dirs[str(self.rank)]
|
|
210
|
+
logger.info(f"append rank({self.rank}) result to {self.tensorboard_dir}")
|
|
211
|
+
else:
|
|
212
|
+
self.tensorboard_dir = os.path.join(self.output_base_dir,
|
|
213
|
+
f"{cur_time}-rank{self.rank}-{self.unique_id}")
|
|
214
|
+
self.pp_stage = dist.get_group_rank(self.process_group, self.rank)
|
|
215
|
+
self.group_mates = dist.get_process_group_ranks(self.process_group)
|
|
216
|
+
else:
|
|
217
|
+
self.rank = 0
|
|
218
|
+
self.tensorboard_dir = os.path.join(self.output_base_dir, f"{cur_time}-rank{self.rank}-{self.unique_id}")
|
|
219
|
+
self.pp_stage = 0
|
|
220
|
+
self.group_mates = [0]
|
|
221
|
+
|
|
222
|
+
# TYPE2: 只会在monitor_gnorm_with_ad()主调中赋值的变量
|
|
223
|
+
self.model = None
|
|
224
|
+
self.vpp = False
|
|
225
|
+
self.dp_group = None
|
|
226
|
+
self.tp_group = None
|
|
227
|
+
self.enable_megatron = False
|
|
228
|
+
self.micro_batch_number = 1
|
|
229
|
+
|
|
230
|
+
# TYPE3: 会随着训练中途config配置更新或监控状态改变而重置的变量
|
|
231
|
+
self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
|
|
232
|
+
self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
|
|
233
|
+
self.optimizer_context = defaultdict(OptimizerContext)
|
|
234
|
+
self.cc_context = defaultdict(CommunicationContext)
|
|
235
|
+
self.grad_context = GradContext()
|
|
236
|
+
self.handles = defaultdict(list)
|
|
237
|
+
self.param2name = defaultdict(str)
|
|
238
|
+
self.name2index = defaultdict()
|
|
239
|
+
self.name2indices = defaultdict()
|
|
240
|
+
self.name2param = {}
|
|
241
|
+
self.duplicate_param = {}
|
|
242
|
+
self.name2tag = {}
|
|
243
|
+
self.param_name_call_id = {}
|
|
244
|
+
self.call_id = 0
|
|
245
|
+
self.module_struct = defaultdict(dict)
|
|
246
|
+
self.grad_accs = []
|
|
247
|
+
self.weight_hooked = False
|
|
248
|
+
self.optimizer_hooked = False
|
|
249
|
+
self.param_registered = False
|
|
250
|
+
self.struct_printed = False
|
|
251
|
+
|
|
252
|
+
# 动静态区分
|
|
253
|
+
self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
|
|
254
|
+
if self.dynamic_enable:
|
|
255
|
+
logger.warning(f"DYNAMIC_MONITOR is set, "
|
|
256
|
+
f"please make sure you have 'switch' and 'collect_times' item in {self.config_file_path}")
|
|
257
|
+
self.monitoring = False
|
|
258
|
+
else:
|
|
259
|
+
self.set_config()
|
|
260
|
+
# 静态且collect_times>0时在第0步self.monitoring就可以True, 动态默认在下一步开启
|
|
261
|
+
if self.collect_times > 0:
|
|
262
|
+
self.monitoring = True
|
|
263
|
+
|
|
264
|
+
def __del__(self):
|
|
265
|
+
if hasattr(self, "summary_writer"):
|
|
266
|
+
self.summary_writer.close()
|
|
267
|
+
|
|
268
|
+
@property
|
|
269
|
+
def ops(self):
|
|
270
|
+
return self._ops
|
|
271
|
+
|
|
272
|
+
@ops.setter
|
|
273
|
+
def ops(self, value):
|
|
274
|
+
self._ops = validate_ops(value)
|
|
275
|
+
|
|
276
|
+
@staticmethod
|
|
277
|
+
def set_wrapped_optimizer(_wrapped_optimizer):
|
|
278
|
+
OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
|
|
279
|
+
|
|
280
|
+
@staticmethod
|
|
281
|
+
def has_register_backward_hook(module_name, module):
|
|
282
|
+
if hasattr(module, '_backward_hooks') and \
|
|
283
|
+
len(module._backward_hooks) > 0 and \
|
|
284
|
+
module._is_full_backward_hook is False:
|
|
285
|
+
logger.warning(
|
|
286
|
+
f"The {module_name} has registered deprecated register_backward_hook,"
|
|
287
|
+
f"which may cause abnormal data dump. The backward input/output for this module will be skipped."
|
|
288
|
+
)
|
|
289
|
+
return True
|
|
290
|
+
return False
|
|
291
|
+
|
|
292
|
+
@staticmethod
|
|
293
|
+
def generate_cc_metrics(cc_name, cc_tensor):
|
|
294
|
+
metrics = defaultdict(dict)
|
|
295
|
+
rank = dist.get_rank() if dist.is_initialized() else None
|
|
296
|
+
for op, tag2tensor in cc_tensor.data.items():
|
|
297
|
+
for tag, tensor in tag2tensor.items():
|
|
298
|
+
key = get_summary_writer_tag_name(cc_name, tag, rank)
|
|
299
|
+
metrics[op].update({key: tensor})
|
|
300
|
+
cc_tensor.reset()
|
|
301
|
+
return metrics
|
|
302
|
+
|
|
303
|
+
def set_config(self):
|
|
304
|
+
logger.info(f"current config: {self.config}")
|
|
305
|
+
self.start_step = self.config.get("start_step", 0)
|
|
306
|
+
self.collect_times = self.config.get("collect_times", 100000000) # 默认大值, 目的是一直采集
|
|
307
|
+
self.step_interval = self.config.get("step_interval", 1)
|
|
308
|
+
self.has_collect_times = 0 # 重设采集计数器
|
|
309
|
+
self.print_struct = self.config.get("print_struct", False)
|
|
310
|
+
self.module_rank_list = self.config.get("module_ranks", [])
|
|
311
|
+
self.format = self.config.get('format', 'tensorboard')
|
|
312
|
+
self.eps = self.config.get('eps', 1e-8)
|
|
313
|
+
self.ops = self.config.get('ops', [])
|
|
314
|
+
self.ndigits = self.config.get('ndigits', 6)
|
|
315
|
+
self.all_xy = self.config.get('all_xy', False)
|
|
316
|
+
self.xy_distribution = self.config.get('xy_distribution', False)
|
|
317
|
+
self.forward_only = self.config.get('forward_only', False)
|
|
318
|
+
self.backward_only = self.config.get('backward_only', False)
|
|
319
|
+
self.ur_distribution = self.config.get('ur_distribution', False)
|
|
320
|
+
self.mv_distribution = self.config.get("mv_distribution", False)
|
|
321
|
+
self.wg_distribution = self.config.get("wg_distribution", False)
|
|
322
|
+
self.param_distribution = self.config.get("param_distribution", False)
|
|
323
|
+
self.mg_direction = self.config.get('mg_direction', False)
|
|
324
|
+
self.cc_distribution = self.config.get("cc_distribution", {})
|
|
325
|
+
|
|
326
|
+
if not self.cc_distribution.get('enable', False):
|
|
327
|
+
self.cc_log_only = False
|
|
328
|
+
else:
|
|
329
|
+
self.cc_codeline = self.cc_distribution.get('cc_codeline', [])
|
|
330
|
+
self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
|
|
331
|
+
self.cc_logged_stack = defaultdict(set)
|
|
332
|
+
self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
|
|
333
|
+
self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
|
|
334
|
+
api_register.redirect_api()
|
|
335
|
+
|
|
336
|
+
self.common_info()
|
|
337
|
+
|
|
338
|
+
# 初始化AnomalyData工厂
|
|
339
|
+
alert_setting = self.config.get('alert', {"rules": []})
|
|
340
|
+
self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])
|
|
341
|
+
self.anomaly_data_factory = None
|
|
342
|
+
if alert_setting.get('dump', False):
|
|
343
|
+
self.anomaly_data_factory = AnomalyDataFactory(self.rank, self.pp_stage, self.group_mates)
|
|
344
|
+
|
|
345
|
+
# 初始化writer, 创建输出目录
|
|
346
|
+
if self.format not in FORMAT_MAPPING:
|
|
347
|
+
raise ValueError(f"Unsupported format: {self.format}")
|
|
348
|
+
writer = FORMAT_MAPPING[self.format]
|
|
349
|
+
self.step_count_per_record = self.config.get('step_count_per_record', 1)
|
|
350
|
+
|
|
351
|
+
if (self.rank in self.module_rank_list) or len(self.module_rank_list) == 0:
|
|
352
|
+
self.summary_writer = writer(
|
|
353
|
+
WriterInput(
|
|
354
|
+
self.tensorboard_dir,
|
|
355
|
+
self.alert_rules,
|
|
356
|
+
self.unique_id,
|
|
357
|
+
self.anomaly_data_factory,
|
|
358
|
+
self.ndigits,
|
|
359
|
+
self.step_count_per_record
|
|
360
|
+
)
|
|
361
|
+
)
|
|
362
|
+
# 初始化anomaly detected文件目录
|
|
363
|
+
if self.anomaly_data_factory:
|
|
364
|
+
self.anomaly_data_writer = AnomalyDataWriter(os.path.join(self.output_base_dir, "anomaly_detected"),
|
|
365
|
+
self.rank)
|
|
366
|
+
self.anomaly_data_writer.init_detected_json()
|
|
367
|
+
|
|
368
|
+
def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
|
|
369
|
+
rank = None
|
|
370
|
+
if dist.is_initialized():
|
|
371
|
+
rank = dist.get_rank()
|
|
372
|
+
if (rank not in rank_list) and len(rank_list) != 0:
|
|
373
|
+
return
|
|
374
|
+
self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
|
|
375
|
+
|
|
376
|
+
def build_tbtag_tensor_map(self, module_name, tag, tensor):
|
|
377
|
+
key = get_summary_writer_tag_name(module_name, tag, self.rank)
|
|
378
|
+
self._register_param_call_id("_hook_module", key)
|
|
379
|
+
return {key: tensor}
|
|
380
|
+
|
|
381
|
+
def common_info(self):
|
|
382
|
+
if not self.xy_distribution:
|
|
383
|
+
logger.info_on_rank_0("> module input/output input_grad/output_grad is not monitored. ")
|
|
384
|
+
if self.forward_only:
|
|
385
|
+
logger.info_on_rank_0("> only module forward is monitored. ")
|
|
386
|
+
if not self.ur_distribution:
|
|
387
|
+
logger.info_on_rank_0("> update vector and ratio vector of adam is not monitored. ")
|
|
388
|
+
if not self.mv_distribution:
|
|
389
|
+
logger.info_on_rank_0("> momentum and variance of adam is not monitored. ")
|
|
390
|
+
if not self.wg_distribution:
|
|
391
|
+
logger.info_on_rank_0("> weight grad of specified module is not monitored. ")
|
|
392
|
+
if not self.mg_direction:
|
|
393
|
+
logger.info_on_rank_0('> grad and momentum direction will not be compared.')
|
|
394
|
+
if not self.cc_distribution.get('enable', False):
|
|
395
|
+
logger.info_on_rank_0("> cc operator is not monitored.")
|
|
396
|
+
if not self.opt_ty:
|
|
397
|
+
if self.ur_distribution:
|
|
398
|
+
raise Exception("ur_distribution cannot be enabled with unknown optimizer.")
|
|
399
|
+
if self.mv_distribution:
|
|
400
|
+
raise Exception("mv_distribution cannot be enabled with unknown optimizer.")
|
|
401
|
+
|
|
402
|
+
def hook_modules(self):
|
|
403
|
+
if self.module_rank_list and (self.rank not in self.module_rank_list):
|
|
404
|
+
return
|
|
405
|
+
|
|
406
|
+
targets = self.config['targets']
|
|
407
|
+
module_in_all_stage = [key for key in targets.keys() if MonitorConst.VPP_SEP not in key]
|
|
408
|
+
for key in module_in_all_stage:
|
|
409
|
+
struct = targets.pop(key)
|
|
410
|
+
targets.update({f'{vpp_stage}{MonitorConst.VPP_SEP}{key}': struct for vpp_stage in range(len(self.model))})
|
|
411
|
+
|
|
412
|
+
hooked_count = 0
|
|
413
|
+
for vpp_stage, model_chunk in enumerate(self.model):
|
|
414
|
+
vpp_stage = f'{vpp_stage}{MonitorConst.VPP_SEP}'
|
|
415
|
+
targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
|
|
416
|
+
'targets'].keys()
|
|
417
|
+
hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
|
|
418
|
+
|
|
419
|
+
logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
|
|
420
|
+
|
|
421
|
+
def clone_if_tensor(args):
|
|
422
|
+
if isinstance(args, tuple):
|
|
423
|
+
return tuple([clone_if_tensor(arg) for arg in args])
|
|
424
|
+
elif isinstance(args, torch.Tensor):
|
|
425
|
+
return args.clone()
|
|
426
|
+
else:
|
|
427
|
+
return args
|
|
428
|
+
|
|
429
|
+
@torch.no_grad
|
|
430
|
+
def wrap_hook_setup(setup):
|
|
431
|
+
def wrapped_setup(*args, **kwargs):
|
|
432
|
+
args = setup(*args, **kwargs)
|
|
433
|
+
args = clone_if_tensor(args)
|
|
434
|
+
return args
|
|
435
|
+
|
|
436
|
+
return wrapped_setup
|
|
437
|
+
|
|
438
|
+
BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
|
|
439
|
+
|
|
440
|
+
return
|
|
441
|
+
|
|
442
|
+
def generate_param_metrics(self, opt_context):
|
|
443
|
+
if not self.param_distribution:
|
|
444
|
+
return
|
|
445
|
+
get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
|
|
446
|
+
|
|
447
|
+
def generate_mv_metrics(self, opt_context):
|
|
448
|
+
if not self.mv_distribution:
|
|
449
|
+
return
|
|
450
|
+
opt_context.exp_avg_metric = {}
|
|
451
|
+
opt_context.exp_avg_sq_metric = {}
|
|
452
|
+
m_tag_tensor_map = self.generate_param_map('exp_avg', opt_context.param_exp_avg)
|
|
453
|
+
v_tag_tensor_map = self.generate_param_map('efxp_avg_sq', opt_context.param_exp_avg_sq)
|
|
454
|
+
get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
|
|
455
|
+
get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
|
|
456
|
+
|
|
457
|
+
def generate_wgrad_metrics(self):
|
|
458
|
+
if not self.wg_distribution:
|
|
459
|
+
return {}, {}
|
|
460
|
+
|
|
461
|
+
if self.weight_hooked:
|
|
462
|
+
get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
|
|
463
|
+
|
|
464
|
+
grad_dict = {}
|
|
465
|
+
for param, name in self.param2name.items():
|
|
466
|
+
if self.duplicate_param.get(name, False):
|
|
467
|
+
continue
|
|
468
|
+
grad = param.main_grad if self.params_have_main_grad else param.grad
|
|
469
|
+
if grad is None:
|
|
470
|
+
logger.warning(f"grad is None: {name}, maybe something wrong happened.")
|
|
471
|
+
continue
|
|
472
|
+
tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
|
|
473
|
+
self._register_param_call_id("hook_optimizer", tag)
|
|
474
|
+
grad_dict[tag] = grad
|
|
475
|
+
|
|
476
|
+
get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
|
|
477
|
+
return self.grad_context.post, self.grad_context.pre
|
|
478
|
+
|
|
479
|
+
def monitor_gnorm_with_ad(
|
|
480
|
+
self,
|
|
481
|
+
model,
|
|
482
|
+
grad_acc_steps=1,
|
|
483
|
+
optimizer=None,
|
|
484
|
+
tp_group=None,
|
|
485
|
+
dp_group=None,
|
|
486
|
+
start_iteration=0
|
|
487
|
+
):
|
|
488
|
+
"""External interface"""
|
|
489
|
+
global start_step
|
|
490
|
+
start_step = start_iteration
|
|
491
|
+
logger.info(f'grad acc steps {grad_acc_steps}')
|
|
492
|
+
self.micro_batch_number = grad_acc_steps
|
|
493
|
+
self.dp_group = dp_group
|
|
494
|
+
self.tp_group = tp_group
|
|
495
|
+
self.hook_step_final(optimizer)
|
|
496
|
+
if not isinstance(model, list):
|
|
497
|
+
model = [model]
|
|
498
|
+
self.model = model
|
|
499
|
+
if len(model) > 1:
|
|
500
|
+
self.vpp = True
|
|
501
|
+
self._smallest_rank_print('vpp enabled')
|
|
502
|
+
if not self.dynamic_enable:
|
|
503
|
+
self.register_hooks(optimizer)
|
|
504
|
+
|
|
505
|
+
def register_hooks(self, optimizer):
|
|
506
|
+
self._register_param_name()
|
|
507
|
+
self.hook_optimizer(optimizer)
|
|
508
|
+
self._patch_grad_sync()
|
|
509
|
+
self.hook_modules()
|
|
510
|
+
self.monitoring = True
|
|
511
|
+
|
|
512
|
+
def generate_param_map(self, tag, param_tensor):
|
|
513
|
+
metrics = {}
|
|
514
|
+
for name in self.param2name.values():
|
|
515
|
+
key = get_summary_writer_tag_name(name, tag, self.rank)
|
|
516
|
+
self._register_param_call_id("optimizer_pre_step_hook", key)
|
|
517
|
+
if name not in param_tensor or param_tensor[name] is None:
|
|
518
|
+
continue
|
|
519
|
+
metrics[key] = param_tensor[name]
|
|
520
|
+
return metrics
|
|
521
|
+
|
|
522
|
+
def generate_xy_metrics(self):
|
|
523
|
+
actv = {}
|
|
524
|
+
for fwd_context in self.module_fwd_hook_context_by_module.values():
|
|
525
|
+
actv.update(fwd_context.actv)
|
|
526
|
+
|
|
527
|
+
actv_grad = self.grad_context.actv
|
|
528
|
+
|
|
529
|
+
return actv, actv_grad
|
|
530
|
+
|
|
531
|
+
def reload_xy(self, xy_distribution=False):
|
|
532
|
+
self.xy_distribution = xy_distribution
|
|
533
|
+
|
|
534
|
+
for handle in self.handles['xy']:
|
|
535
|
+
handle.remove()
|
|
536
|
+
self.handles['xy'].clear()
|
|
537
|
+
self.hook_modules()
|
|
538
|
+
for _, fwd_context in self.module_fwd_hook_context_by_module.items():
|
|
539
|
+
fwd_context.actv.clear()
|
|
540
|
+
|
|
541
|
+
def write_adhoc_check(self, step):
|
|
542
|
+
self.tensor_metrics.flush(self.summary_writer)
|
|
543
|
+
|
|
544
|
+
def write_xy_tb(self, step):
|
|
545
|
+
if not self.xy_distribution:
|
|
546
|
+
return
|
|
547
|
+
for _, fwd_context in self.module_fwd_hook_context_by_module.items():
|
|
548
|
+
if len(fwd_context.actv) == 0:
|
|
549
|
+
continue
|
|
550
|
+
self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, 'actv')
|
|
551
|
+
fwd_context.actv.clear()
|
|
552
|
+
if self.grad_context.actv:
|
|
553
|
+
self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, 'actv_grad')
|
|
554
|
+
|
|
555
|
+
def write_param_tb(self, opt_context):
|
|
556
|
+
if not self.param_distribution:
|
|
557
|
+
return
|
|
558
|
+
self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, 'param')
|
|
559
|
+
|
|
560
|
+
def write_mv_tb(self, opt_context):
|
|
561
|
+
if not self.mv_distribution:
|
|
562
|
+
return
|
|
563
|
+
self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, 'exp_avg')
|
|
564
|
+
self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq')
|
|
565
|
+
|
|
566
|
+
def write_grad_tb(self, step):
|
|
567
|
+
if not self.wg_distribution:
|
|
568
|
+
return
|
|
569
|
+
|
|
570
|
+
if self.enable_megatron:
|
|
571
|
+
self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
|
|
572
|
+
else:
|
|
573
|
+
self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
|
|
574
|
+
self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
|
|
575
|
+
|
|
576
|
+
def hook_optimizer(self, optimizer=None):
|
|
577
|
+
# in DDP by default use params_have_main_grad
|
|
578
|
+
def optimizer_pre_step_hook(optimizer, args, kwargs):
|
|
579
|
+
context = self.optimizer_context[optimizer]
|
|
580
|
+
|
|
581
|
+
if (self.print_struct and not all(value == {} for value in self.module_struct.values())
|
|
582
|
+
and not self.struct_printed):
|
|
583
|
+
self._save_module_struct()
|
|
584
|
+
if not self.cc_log_only:
|
|
585
|
+
raise Exception("exit after first monitor step when print model struct")
|
|
586
|
+
if self.cc_log_only and context.step > 0:
|
|
587
|
+
self._smallest_rank_print("> Used communication ops and corresponding stack")
|
|
588
|
+
self._smallest_rank_print(
|
|
589
|
+
json.dumps({k: [i.split(';') for i in v] for k, v in self.cc_logged_stack.items()}))
|
|
590
|
+
raise Exception("exit after first step when print cc stack")
|
|
591
|
+
|
|
592
|
+
# skip generate metrics
|
|
593
|
+
if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0:
|
|
594
|
+
return
|
|
595
|
+
if self.opt_ty in MonitorConst.DEEPSPEED_OPT_TY:
|
|
596
|
+
if not self.name2indices:
|
|
597
|
+
self.name2indices = self.mix_precision_optimizer_mon.get_param_index(self.param2name,
|
|
598
|
+
self.name2index)
|
|
599
|
+
mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name,
|
|
600
|
+
self.name2indices)
|
|
601
|
+
self.param2name = mv_result.grad
|
|
602
|
+
else:
|
|
603
|
+
mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name)
|
|
604
|
+
context.param_exp_avg = mv_result.exp_avg
|
|
605
|
+
context.param_exp_avg_sq = mv_result.exp_avg_sq
|
|
606
|
+
context.param_adam_update = mv_result.update
|
|
607
|
+
context.param_adam_ratio = mv_result.ratio
|
|
608
|
+
|
|
609
|
+
self.generate_wgrad_metrics()
|
|
610
|
+
self.generate_mv_metrics(context)
|
|
611
|
+
self.generate_param_metrics(context)
|
|
612
|
+
|
|
613
|
+
tbtag_tensor_map = {}
|
|
614
|
+
if self.mg_direction:
|
|
615
|
+
for param, name in self.param2name.items():
|
|
616
|
+
grad = param.main_grad if self.params_have_main_grad else param.grad
|
|
617
|
+
if grad is None:
|
|
618
|
+
logger.warning(f"grad is None: {name}, maybe something wrong happened.")
|
|
619
|
+
continue
|
|
620
|
+
if context.step == 0:
|
|
621
|
+
same_direction_ratio = torch.tensor(1.)
|
|
622
|
+
else:
|
|
623
|
+
same_direction_ratio = get_sign_matches(grad, context.param_exp_avg[name])
|
|
624
|
+
context.param_mg_direction[name] = same_direction_ratio
|
|
625
|
+
tbtag_tensor_map.update(self.generate_param_map('mg_direction', context.param_mg_direction))
|
|
626
|
+
|
|
627
|
+
metric_dict = {}
|
|
628
|
+
get_metrics(self.ops, tbtag_tensor_map, self.eps, metric_dict)
|
|
629
|
+
for cc in self.cc_context.values():
|
|
630
|
+
cc.aggregate()
|
|
631
|
+
metric_dict.update(cc.data)
|
|
632
|
+
cc.reset()
|
|
633
|
+
|
|
634
|
+
if not metric_dict:
|
|
635
|
+
return
|
|
636
|
+
context.metric_dict = metric_dict
|
|
637
|
+
return
|
|
638
|
+
|
|
639
|
+
def patch_step(func, optimizer):
|
|
640
|
+
def wrapper(*args, **kwargs):
|
|
641
|
+
optimizer_pre_step_hook(optimizer, args, kwargs)
|
|
642
|
+
out = func(*args, **kwargs)
|
|
643
|
+
return out
|
|
644
|
+
|
|
645
|
+
return wrapper
|
|
646
|
+
|
|
647
|
+
if self.optimizer_hooked:
|
|
648
|
+
return
|
|
649
|
+
|
|
650
|
+
if optimizer:
|
|
651
|
+
optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
|
|
652
|
+
self.handles['optimizer'] = []
|
|
653
|
+
else:
|
|
654
|
+
if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list):
|
|
655
|
+
step_pre_hook = register_optimizer_step_pre_hook(optimizer_pre_step_hook)
|
|
656
|
+
self.handles['optimizer'] = [step_pre_hook]
|
|
657
|
+
self.optimizer_hooked = True
|
|
658
|
+
return
|
|
659
|
+
|
|
660
|
+
def dynamic_monitor(self, optimizer):
|
|
661
|
+
"""
|
|
662
|
+
If dynamic monitor enabled and config.json updated,
|
|
663
|
+
remove hooks and register new hooks according to new configuration.
|
|
664
|
+
"""
|
|
665
|
+
context = self.optimizer_context[optimizer]
|
|
666
|
+
if not self.dynamic_enable:
|
|
667
|
+
return
|
|
668
|
+
try:
|
|
669
|
+
# 如果文件时间戳没变, 可以不读取节省时间
|
|
670
|
+
config_timestamp = os.path.getmtime(self.config_file_path)
|
|
671
|
+
if config_timestamp == self.config_timestamp:
|
|
672
|
+
return
|
|
673
|
+
# 更新config文件最新修改时间戳
|
|
674
|
+
self.config_timestamp = config_timestamp
|
|
675
|
+
config = load_json(self.config_file_path)
|
|
676
|
+
except Exception as e:
|
|
677
|
+
logger.error(f"get config.json wrong because {e}, not updated, please check!!!")
|
|
678
|
+
return
|
|
679
|
+
|
|
680
|
+
if config.get("switch", False):
|
|
681
|
+
try:
|
|
682
|
+
validate_config(config)
|
|
683
|
+
self.config = config
|
|
684
|
+
self.set_config()
|
|
685
|
+
logger.warning(f"config is updated at step{context.step - 1}, "
|
|
686
|
+
f"will start new hook at step{context.step}.")
|
|
687
|
+
except Exception as e:
|
|
688
|
+
logger.error(f"set config wrong because {e}, not updated, please check!!!")
|
|
689
|
+
return
|
|
690
|
+
|
|
691
|
+
self._remove_all_hooks(optimizer)
|
|
692
|
+
self.register_hooks(optimizer)
|
|
693
|
+
|
|
694
|
+
def hook_step_final(self, optimizer):
|
|
695
|
+
def step_final_hook(optimizer, args, kwargs):
|
|
696
|
+
context = self.optimizer_context[optimizer]
|
|
697
|
+
rank = dist.get_rank() if dist.is_initialized() else None
|
|
698
|
+
# 静态在第0步就可以保存, 动态在第0步不可以, 因为动态设计的就是重置后下一步开启, 第0步的self.monitoring还是False
|
|
699
|
+
if self.monitoring:
|
|
700
|
+
module_rank_valid = not self.module_rank_list or (
|
|
701
|
+
dist.is_initialized() and dist.get_rank() in self.module_rank_list)
|
|
702
|
+
step_condition = (context.step >= self.start_step and (
|
|
703
|
+
context.step - self.start_step) % self.step_interval == 0)
|
|
704
|
+
if module_rank_valid and step_condition:
|
|
705
|
+
self.has_collect_times += 1
|
|
706
|
+
|
|
707
|
+
if self.anomaly_data_factory:
|
|
708
|
+
self.anomaly_data_factory.set_call_id(self.param_name_call_id)
|
|
709
|
+
self.write_xy_tb(context.step)
|
|
710
|
+
self.write_grad_tb(context.step)
|
|
711
|
+
self.write_mv_tb(context)
|
|
712
|
+
self.write_param_tb(context)
|
|
713
|
+
self.write_adhoc_check(context.step)
|
|
714
|
+
|
|
715
|
+
if self.ur_distribution:
|
|
716
|
+
for param_name, _ in context.param_adam_update.items():
|
|
717
|
+
self.update_heatmap_visualizer[param_name].visualize(
|
|
718
|
+
get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step,
|
|
719
|
+
self.summary_writer)
|
|
720
|
+
for param_name, _ in context.param_adam_ratio.items():
|
|
721
|
+
self.ratio_heatmap_visualizer[param_name].visualize(
|
|
722
|
+
get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step,
|
|
723
|
+
self.summary_writer)
|
|
724
|
+
|
|
725
|
+
if context.metric_dict:
|
|
726
|
+
self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
|
|
727
|
+
context.metric_dict.clear()
|
|
728
|
+
|
|
729
|
+
if self.anomaly_data_factory:
|
|
730
|
+
self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
|
|
731
|
+
self.summary_writer.clear_anomalies()
|
|
732
|
+
self.call_id = 0
|
|
733
|
+
self.param_name_call_id.clear()
|
|
734
|
+
|
|
735
|
+
if self.has_collect_times >= self.collect_times:
|
|
736
|
+
self._remove_all_hooks_final(optimizer)
|
|
737
|
+
|
|
738
|
+
context.step += 1
|
|
739
|
+
self.dynamic_monitor(optimizer)
|
|
740
|
+
|
|
741
|
+
def patch_step(func, optimizer):
|
|
742
|
+
def wrapper(*args, **kwargs):
|
|
743
|
+
out = func(*args, **kwargs)
|
|
744
|
+
step_final_hook(optimizer, args, kwargs)
|
|
745
|
+
return out
|
|
746
|
+
return wrapper
|
|
747
|
+
|
|
748
|
+
if optimizer:
|
|
749
|
+
optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
|
|
750
|
+
self.origin_step_func = optimizer.__class__.step
|
|
751
|
+
else:
|
|
752
|
+
register_optimizer_step_post_hook(step_final_hook)
|
|
753
|
+
return
|
|
754
|
+
|
|
755
|
+
def _remove_all_hooks(self, optimizer):
|
|
756
|
+
# 清空hook handle
|
|
757
|
+
for handle in self.handles['xy']:
|
|
758
|
+
handle.remove()
|
|
759
|
+
self.handles['xy'].clear()
|
|
760
|
+
# 清空对应context缓存
|
|
761
|
+
for _, fwd_context in self.module_fwd_hook_context_by_module.items():
|
|
762
|
+
fwd_context.reset()
|
|
763
|
+
for _, bwd_context in self.module_bwd_hook_context_by_module.items():
|
|
764
|
+
bwd_context.reset()
|
|
765
|
+
self.grad_context.reset() # 权重梯度和激活值梯度都在这
|
|
766
|
+
|
|
767
|
+
for handle in self.handles['wgrads']:
|
|
768
|
+
handle.remove()
|
|
769
|
+
self.handles['wgrads'].clear()
|
|
770
|
+
self.weight_hooked = False
|
|
771
|
+
|
|
772
|
+
if len(self.handles['optimizer']) == 0 and self.optimizer_hooked:
|
|
773
|
+
optimizer.__class__.step = self.origin_step_func
|
|
774
|
+
else:
|
|
775
|
+
for handle in self.handles['optimizer']:
|
|
776
|
+
handle.remove()
|
|
777
|
+
self.handles['optimizer'].clear()
|
|
778
|
+
for _, context in self.optimizer_context.items():
|
|
779
|
+
context.reset()
|
|
780
|
+
self.optimizer_hooked = False
|
|
781
|
+
|
|
782
|
+
for handle in self.handles['cc']:
|
|
783
|
+
handle.remove()
|
|
784
|
+
self.handles['cc'].clear()
|
|
785
|
+
for _, context in self.cc_context.items():
|
|
786
|
+
context.reset()
|
|
787
|
+
|
|
788
|
+
# 清空节点缓存
|
|
789
|
+
self.param2name.clear()
|
|
790
|
+
self.name2index.clear()
|
|
791
|
+
self.name2indices.clear()
|
|
792
|
+
self.name2param.clear()
|
|
793
|
+
self.duplicate_param.clear()
|
|
794
|
+
self.name2tag.clear()
|
|
795
|
+
self.module_struct.clear()
|
|
796
|
+
self.grad_accs.clear()
|
|
797
|
+
|
|
798
|
+
# 关闭采集状态
|
|
799
|
+
self.monitoring = False
|
|
800
|
+
|
|
801
|
+
def _remove_all_hooks_final(self, optimizer):
|
|
802
|
+
if self.dynamic_enable:
|
|
803
|
+
# 结束后自动重置switch为False等待用户手动开启
|
|
804
|
+
try:
|
|
805
|
+
config = load_json(self.config_file_path)
|
|
806
|
+
config['switch'] = False
|
|
807
|
+
save_json(self.config_file_path, config, indent=2)
|
|
808
|
+
config_timestamp = os.path.getmtime(self.config_file_path)
|
|
809
|
+
self.config_timestamp = config_timestamp
|
|
810
|
+
logger.info(
|
|
811
|
+
"Finish monitor, set config'switch=False, will restart by set switch=True and update content")
|
|
812
|
+
except Exception as e:
|
|
813
|
+
logger.warning(f"Finish monitor, set config'switch=False fail because {e}, please check!!!")
|
|
814
|
+
logger.info("Finish monitor")
|
|
815
|
+
self._remove_all_hooks(optimizer)
+
+    def _smallest_rank_print(self, msg):
+        if dist.is_initialized():
+            if self.module_rank_list:
+                if dist.get_rank() == min(self.module_rank_list):
+                    logger.info(msg)
+            else:
+                if dist.get_rank() == 0:
+                    logger.info(msg)
+        else:
+            logger.info(msg)
+
+    def _save_module_struct(self):
+        save_module_struct = (not dist.is_initialized()
+                              or (self.module_rank_list and dist.get_rank() == min(self.module_rank_list))
+                              or (not self.module_rank_list and dist.get_rank() == 0))
+
+        if save_module_struct:
+            module_struct_file = os.path.realpath(os.path.join(get_output_base_dir(), 'module_struct.json'))
+            save_json(module_struct_file, self.module_struct, indent=2)
+            logger.info(f"> save module struct to {module_struct_file}")
+            self.struct_printed = True
+
+    def _is_target_param(self, param_name, param, prefix):
+        name = prefix + param_name
+        squash_name = prefix + squash_param_name(param_name, self.squash_name)
+        for target in self.config['targets'].keys():
+            if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
+                setattr(param, "zero_out_wgrad", True)
+                return True
+
+        return False
+
+    def _register_chunk(self, model_chunk, prefix):
+        index = 0
+        for (param_name, param) in model_chunk.named_parameters():
+            if not param.requires_grad:
+                continue
+            if self._is_target_param(param_name, param, prefix):
+                name = prefix + squash_param_name(param_name, self.squash_name)
+                if name in self.param2name.values():
+                    name = prefix + param_name
+                self.param2name[param] = name
+                self.name2param[name] = param
+                self.name2index[name] = index
+
+                if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group):
+                    self.duplicate_param[name] = True
+                if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
+                    self.duplicate_param[name] = True
+                self.name2tag[name] = {
+                    MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank),
+                    MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank)
+                }
+                index += 1
+
+    def _register_param_name(self):
+        for vpp_stage, model_chunk in enumerate(self.model):
+            prefix = f'{vpp_stage}{MonitorConst.VPP_SEP}'
+            self._register_chunk(model_chunk, prefix)
+
+    def _is_target_module(self, module_name, targets, vpp_stage):
+        if self.all_xy or self.print_struct:
+            return vpp_stage + squash_param_name(module_name, self.squash_name)
+        for pattern in [
+            vpp_stage + squash_param_name(module_name, self.squash_name),
+            vpp_stage + module_name,
+        ]:
+            if pattern in targets:
+                return pattern
+        return ""
+
+    def _hook_module(self, target_names, module: torch.nn.Module, vpp_stage=''):
+        if '_modules' not in module.__dict__:
+            # nothing to hook
+            return 0
+
+        def fwd_hook_fun(module, module_input, module_output, name):
+            if not module.training or is_recomputation():
+                # 1. only monitor the training stage.
+                # 2. when recomputation is enabled, skip the recomputed forward pass.
+                return
+            if module not in self.module_fwd_hook_context_by_module:
+                self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
+            context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
+            if not context.struct:
+                context.struct = {
+                    MonitorConst.ACTV_IN: get_param_struct(module_input),
+                    MonitorConst.ACTV_OUT: get_param_struct(module_output)
+                }
+            if self.print_struct:
+                self.module_struct[context.module_name].update(context.struct)
+                return
+            if not context.format_by_arg:
+                context.set_format_by_arg(MonitorConst.ACTV_IN, self.config['targets'])
+                context.set_format_by_arg(MonitorConst.ACTV_OUT, self.config['targets'])
+            if not context.format_by_arg:
+                return
+            if not context.verified:
+                context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
+                                                              module_input, context.module_name,
+                                                              MonitorConst.ACTV_IN)
+                context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
+                                                               module_output, context.module_name,
+                                                               MonitorConst.ACTV_OUT)
+                context.verified = True
+            # the output is expected to be a tensor
+            tbtag_tensor_map = {}
+            cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
+                                            cared_input))
+            cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT,
+                                            cared_output))
+
+            get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
+            context.micro_step += 1
+            if context.micro_step == self.micro_batch_number:
+                context.micro_step = 0
+            return
+
+        def bwd_hook_fun(module, input_grad, output_grad):
+            context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
+            if not context.struct:
+                context.struct = {
+                    MonitorConst.ACTVGRAD_IN: get_param_struct(input_grad),
+                    MonitorConst.ACTVGRAD_OUT: get_param_struct(output_grad)
+                }
+            if self.print_struct:
+                self.module_struct[context.module_name].update(context.struct)
+                return
+            if not context.format_by_arg:
+                context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.config['targets'])
+                context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.config['targets'])
+            if not context.format_by_arg:
+                return
+            if not context.verified:
+                context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
+                                                              input_grad, context.module_name,
+                                                              MonitorConst.ACTVGRAD_IN)
+                context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
+                                                               output_grad, context.module_name,
+                                                               MonitorConst.ACTVGRAD_OUT)
+                context.verified = True
+
+            tbtag_tensor_map = {}
+            cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN,
+                                            cared_input_grad))
+            cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT,
+                                            cared_output_grad))
+
+            if context.micro_step == 0 and context.actvgrad:
+                logger.warning(f"actvgrad context of {context.module_name} is not empty at the first micro_step; "
+                               f"something may have gone wrong. Clearing it now.")
+                context.actvgrad.clear()
+
+            get_metrics(self.ops, tbtag_tensor_map, self.eps, self.grad_context.actv)
+
+            context.micro_step += 1
+            if context.micro_step == self.micro_batch_number:
+                context.micro_step = 0
+            return
+
+        if self.backward_only and self.forward_only:
+            logger.warning('backward_only and forward_only should not be enabled simultaneously')
+
+        hooked_count = 0
+        if self.xy_distribution or self.print_struct:
+            for module_name, submodule in module.named_modules():
+                name = self._is_target_module(module_name, target_names, vpp_stage)
+                if not name:
+                    continue
+                if not self.backward_only:
+                    handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name))
+                    self.handles['xy'].append(handle)
+                if not self.forward_only and not self.has_register_backward_hook(name, submodule):
+                    handle = submodule.register_full_backward_hook(bwd_hook_fun)
+                    self.handles['xy'].append(handle)
+                    self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
+                logger.info_on_rank_0(f"> {name} is monitored successfully")
+                hooked_count += 1
+        return hooked_count
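For readers unfamiliar with PyTorch module hooks, the sketch below shows the bare pattern that `fwd_hook_fun`/`bwd_hook_fun` above build on: registering a forward hook and a full backward hook on each target submodule and recording a simple statistic per call. It is a standalone illustration, not msprobe code; the `stats` dict and the toy model are invented for this example.

```python
from functools import partial

import torch
import torch.nn as nn

stats = {}


def fwd_hook(module, module_input, module_output, name):
    # record a cheap summary of the forward output
    stats[f"{name}/actv_out"] = module_output.detach().norm().item()


def bwd_hook(module, grad_input, grad_output, name):
    # grad_output is a tuple; take the gradient w.r.t. the module output
    stats[f"{name}/actv_grad_out"] = grad_output[0].norm().item()


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
handles = []
for mod_name, submodule in model.named_modules():
    if mod_name == "":
        continue  # skip the top-level container itself
    handles.append(submodule.register_forward_hook(partial(fwd_hook, name=mod_name)))
    handles.append(submodule.register_full_backward_hook(partial(bwd_hook, name=mod_name)))

model(torch.randn(3, 4)).sum().backward()
print(stats)

for handle in handles:  # removing handles is the counterpart of _remove_all_hooks
    handle.remove()
```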
+
+    def _patch_grad_sync(self):
+        def patch_sync(sync_grad_func):
+            def wrapper(bucket):
+                grad_dict = {}
+                bucket_params_id_list = [id(params) for params in bucket.params_list]
+                for param, name in self.param2name.items():
+                    if id(param) not in bucket_params_id_list:
+                        continue
+                    grad = param.main_grad if self.params_have_main_grad else param.grad
+                    if grad is None:
+                        logger.warning(f"grad is None for {name}; something may have gone wrong.")
+                        continue
+                    tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+                    if tag is None:
+                        continue
+                    grad_dict[tag] = grad
+                    self._register_param_call_id("sync_grad_func", tag)
+                get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
+                out = sync_grad_func(bucket)
+                return out
+
+            return wrapper
+
+        try:
+            from megatron.core.distributed.param_and_grad_buffer import Bucket
+            self.enable_megatron = True
+        except ImportError:
+            self.enable_megatron = False
+
+        if not self.wg_distribution:
+            return
+
+        if self.enable_megatron:
+            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)  # differs across Megatron versions
+        else:
+            self._hook_weights()
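The `_patch_grad_sync` method above wraps Megatron's `Bucket.start_grad_sync` so that per-parameter gradients can be summarized right before they are synchronized. The following standalone sketch shows the same wrapping idea in isolation; `FakeBucket`, `sync_grads`, and the callback are invented for this example and do not reflect Megatron's actual API.

```python
import torch


class FakeBucket:
    def __init__(self, params_list):
        self.params_list = params_list


def sync_grads(bucket):
    # stand-in for the real reduce-scatter / all-reduce of the bucket's grads
    return len(bucket.params_list)


def patch_sync(sync_grad_func, on_pre_grad):
    def wrapper(bucket):
        for param in bucket.params_list:
            if param.grad is not None:
                on_pre_grad(param, param.grad)  # snapshot before the sync runs
        return sync_grad_func(bucket)

    return wrapper


p = torch.nn.Parameter(torch.ones(2))
(p * 3).sum().backward()
sync_grads = patch_sync(sync_grads, lambda param, grad: print("pre-sync grad norm:", grad.norm().item()))
sync_grads(FakeBucket([p]))
```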
+
+    def _hook_weights(self):
+        context = self.grad_context
+
+        @torch.no_grad
+        def param_hook(*args, context_dict, param, key, name):
+            param.micro_step += 1
+            self._register_param_call_id("param_hook", key)
+            if param.micro_step == self.micro_batch_number:
+                param.micro_step = 0
+                if self.params_have_main_grad:
+                    context_dict[key] = param.main_grad.clone()
+                else:
+                    context_dict[key] = param.grad.clone()
+
+        for param, name in self.param2name.items():
+            key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
+            setattr(param, 'micro_step', 0)
+            param_tmp = param.expand_as(param)
+            grad_acc = param_tmp.grad_fn.next_functions[0][0]
+            handle = grad_acc.register_hook(
+                partial(param_hook, context_dict=context.acc, param=param, key=key, name=name))
+            self.grad_accs.append(grad_acc)
+            self.handles['wgrads'].append(handle)
+
+        self.weight_hooked = True
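`_hook_weights` relies on a less obvious autograd trick: `param.expand_as(param).grad_fn.next_functions[0][0]` resolves to the parameter's `AccumulateGrad` node, and a hook registered on that node fires each time the parameter's gradient is accumulated during backward. A minimal standalone sketch of that trick, with names local to this example:

```python
import torch

param = torch.nn.Parameter(torch.ones(3))

# expand_as creates a view whose grad_fn leads back to the parameter's
# AccumulateGrad node; next_functions[0][0] is that node.
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]

grad_accs = []  # keep a reference so the node is not garbage collected
handles = []


def acc_hook(*args):
    # called once the gradient has been accumulated into param.grad
    print("accumulated grad:", param.grad)


handles.append(grad_acc.register_hook(acc_hook))
grad_accs.append(grad_acc)

(param * 2).sum().backward()  # triggers the hook with param.grad == [2., 2., 2.]

for handle in handles:
    handle.remove()
```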
+
+    def _register_param_call_id(self, hook_name: str, key: str):
+        """
+        :param hook_name: str, name of the hook that recorded this call.
+        :param key: str, e.g. '0:relu_0/output_grad'.
+        :return: None
+        """
+        logger.debug(f"{hook_name} {key}: {self.call_id}")
+        self.param_name_call_id[key] = self.call_id
+        self.call_id += 1