mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff shows the content differences between package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
msprobe/pytorch/monitor/module_hook.py (new file)

@@ -0,0 +1,870 @@
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import json
import os
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from functools import partial

import pytz
import torch
import torch.distributed as dist
from msprobe.core.common.const import MonitorConst
from msprobe.core.common.file_utils import load_json
from msprobe.core.common.log import logger
from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter
from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
    CSVWriterWithAD, BaseWriterWithAD, WriterInput
from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
    get_process_group
from msprobe.pytorch.monitor.features import get_sign_matches
from msprobe.pytorch.monitor.module_metric import get_metrics, write_metrics_base, get_summary_writer_tag_name, \
    TensorMetrics, write_metrics_csv, squash_param_name
from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory, OptimizerMon
from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, is_recomputation
from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
from torch.utils.hooks import BackwardHook

try:
    import torch_npu
except ImportError:
    pass

torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
if not torch_version_above_or_equal_2:
    raise ValueError("monitor require torch>=2.0")

output_base_dir = os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)

FORMAT_MAPPING = {
    MonitorConst.TENSORBOARD: (SummaryWriterWithAD, write_metrics_base),
    MonitorConst.CSV: (CSVWriterWithAD, write_metrics_csv),
    MonitorConst.API: (BaseWriterWithAD, write_metrics_base)
}


def param_is_not_tensor_parallel_duplicate(param, tp_group):
    return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
        torch.distributed.get_rank(group=tp_group) == 0
    )


def param_is_data_parallel_duplicate(dp_group):
    return torch.distributed.get_rank(group=dp_group) != 0


class ModuleHookContext:
    def __init__(self, module_name) -> None:
        self.step = 0
        self.micro_step = 0
        self.actv = defaultdict(dict)
        self.actvgrad = []
        self.module_name = module_name
        self.struct = {}
        self.format_by_arg = {}
        self.verified = False
        self.focused_in_col = 0
        self.focused_out_col = 0
        self.ignore_in = False  # no need to care when no key 'input' or 'input_grad' found

    def set_format_by_arg(self, key_name: str, target_config: dict):
        cared = target_config.get(self.module_name, self.struct)
        if key_name in cared:
            if isinstance(cared[key_name], dict):
                # current cared is self.struct
                config = cared[key_name].get('config')
                self.format_by_arg[key_name] = config
            else:
                # current cared is target_config[self.module_name]
                self.format_by_arg[key_name] = cared[key_name]
        elif key_name in ['input', 'input_grad']:
            self.ignore_in = True


class OptimizerContext:
    def __init__(self) -> None:
        self.step = 0
        self.param_effective_rank = defaultdict(float)
        self.param_mg_direction = defaultdict(float)
        self.param_adam_update = defaultdict()
        self.param_adam_ratio = defaultdict()
        self.param_weight_grad = defaultdict()
        self.param_exp_avg = defaultdict()
        self.exp_avg_metric = {}
        self.param_exp_avg_sq = defaultdict()
        self.exp_avg_sq_metric = {}
        self.metric_dict = {}
        self.param_metric = {}


class CommunicationContext:
    def __init__(self) -> None:
        self.data = {}

    @staticmethod
    def _agg(data):
        aggregated_data = {}
        for tag, op2tensorlist in data.items():
            aggregated_data[tag] = {}
            for op, tensorlist in op2tensorlist.items():
                aggregated_data[tag][op] = op_aggregate(op, tensorlist)
        return aggregated_data

    def reset(self):
        self.data = {}

    def aggregate(self):
        self.data = self._agg(self.data)


class GradContext:
    def __init__(self) -> None:
        self.pre = {}
        self.post = {}
        self.acc_metric = {}
        self.acc = {}
        self.actv = {}

    def reset(self):
        self.pre.clear()
        self.post.clear()
        self.acc_metric.clear()
        self.acc.clear()
        self.actv.clear()


class TrainerMon:
    tensor_metrics = TensorMetrics()

    def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None:
        """
        opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer"
        """
        self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
        self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
        self.optimizer_context = defaultdict(OptimizerContext)
        self.cc_context = defaultdict(CommunicationContext)
        self.grad_context = GradContext()
        self.process_group = get_process_group(process_group)
        self.params_have_main_grad = params_have_main_grad
        self.opt_ty = opt_ty
        self.config = load_json(config_file_path)
        validate_config(self.config)

        self.module_rank_list = self.config.get("module_ranks", [])
        self.format = self.config.get('format', 'tensorboard')
        self.eps = self.config.get('eps', 1e-8)
        self.ops = self.config.get('ops', [])
        self.ndigits = self.config.get('ndigits', 6)
        self.all_xy = self.config.get('all_xy', False)
        self.xy_distribution = self.config.get('xy_distribution', False)
        self.forward_only = self.config.get('forward_only', False)
        self.backward_only = self.config.get('backward_only', False)
        self.ur_distribution = self.config.get('ur_distribution', False)
        self.mv_distribution = self.config.get("mv_distribution", False)
        self.wg_distribution = self.config.get("wg_distribution", False)
        self.param_distribution = self.config.get("param_distribution", False)
        self.mg_direction = self.config.get('mg_direction', False)
        self.cc_distribution = self.config.get("cc_distribution", {})
        if not self.cc_distribution.get('enable', False):
            self.cc_log_only = False
        else:
            self.cc_codeline = self.cc_distribution.get('cc_codeline', [])
            self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
            self.cc_logged_stack = defaultdict(set)
            self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
            api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
            api_register.redirect_api()

        self.common_info()

        alert_setting = self.config.get('alert', {"rules": []})
        self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])

        # Set the time zone; 'UTC' is used here as an example
        local_tz = pytz.timezone("Asia/Shanghai")  # adjust to the target time zone as needed

        cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
        unique_id = str(uuid.uuid4())[:8]

        if dist.is_initialized():
            rank = dist.get_rank()
            tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}")
            pp_stage = dist.get_group_rank(self.process_group, rank)
            group_mates = dist.get_process_group_ranks(self.process_group)
        else:
            rank = 0
            tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}")
            pp_stage = 0
            group_mates = [0]
        self.rank = rank

        # Initialize the AnomalyData factory
        self.anomaly_data_factory = None
        if alert_setting.get('dump', False):
            self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates)

        if self.format not in FORMAT_MAPPING:
            raise ValueError(f"Unsupported format: {self.format}")
        writer, self.write_metrics = FORMAT_MAPPING[self.format]
        self.step_count_per_record = self.config.get('step_count_per_record', 1)

        if (rank in self.module_rank_list) or len(self.module_rank_list) == 0:
            self.summary_writer = writer(
                WriterInput(
                    tensorboard_dir,
                    self.alert_rules,
                    unique_id,
                    None,
                    self.anomaly_data_factory,
                    self.ndigits,
                    self.step_count_per_record
                )
            )
            # Initialize the anomaly_detected output directory
            if self.anomaly_data_factory:
                self.anomaly_data_writer = AnomalyDataWriter(os.path.join(output_base_dir, "anomaly_detected"), rank)
                self.anomaly_data_writer.init_detected_json()

        # A HeatmapVisualizer instance is associated with an image
        self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
        self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
        self.micro_batch_number = 1

        self.model = None
        self.weight_hooked = False
        self.optimizer_hooked = False
        self.param_registered = False
        self.vpp = False
        self.dp_group = None
        self.tp_group = None
        self.enable_megatron = False

        self.param2name = defaultdict(str)
        self.name2index = defaultdict()
        self.name2indices = defaultdict()
        self.name2param = {}
        self.param_name_call_id = {}
        self.duplicate_param = {}
        self.name2tag = {}
        self.call_id = 0
        self.grad_accs = []
        self.handles = defaultdict(list)

        self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
        self.print_struct = self.config.get("print_struct", False)
        self.struct_printed = False
        self.module_struct = {}

    def __del__(self):
        if hasattr(self, "summary_writer"):
            self.summary_writer.close()

    @property
    def ops(self):
        return self._ops

    @ops.setter
    def ops(self, value):
        self._ops = validate_ops(value)

    @staticmethod
    def set_wrapped_optimizer(_wrapped_optimizer):
        OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)

    @staticmethod
    def adhoc_check(target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
        rank = None
        if dist.is_initialized():
            rank = dist.get_rank()
            if (rank not in rank_list) and len(rank_list) != 0:
                return
        TrainerMon.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)

    @staticmethod
    def build_tbtag_tensor_map(module_name, tag, tensor):
        metrics = {}
        rank = dist.get_rank() if dist.is_initialized() else None
        key = get_summary_writer_tag_name(module_name, tag, rank)
        if torch.is_tensor(tensor):
            metrics[key] = tensor
        return metrics

    @staticmethod
    def generate_cc_metrics(cc_name, cc_tensor):
        metrics = defaultdict(dict)
        rank = dist.get_rank() if dist.is_initialized() else None
        for op, tag2tensor in cc_tensor.data.items():
            for tag, tensor in tag2tensor.items():
                key = get_summary_writer_tag_name(cc_name, tag, rank)
                metrics[op].update({key: tensor})
        cc_tensor.reset()
        return metrics

    def common_info(self):
        if not self.xy_distribution:
            logger.info_on_rank_0("> module input/output input_grad/output_grad is not monitored. ")
        if self.forward_only:
            logger.info_on_rank_0("> only module forward is monitored. ")
        if not self.ur_distribution:
            logger.info_on_rank_0("> update vector and ratio vector of adam is not monitored. ")
        if not self.mv_distribution:
            logger.info_on_rank_0("> momentum and variance of adam is not monitored. ")
        if not self.wg_distribution:
            logger.info_on_rank_0("> weight grad of specified module is not monitored. ")
        if not self.mg_direction:
            logger.info_on_rank_0('> grad and momentum direction will not be compared.')
        if not self.cc_distribution.get('enable', False):
            logger.info_on_rank_0("> cc operator is not monitored.")
        if not self.opt_ty:
            if self.ur_distribution:
                raise Exception("ur_distribution cannot be enabled with unknown optimizer.")
            if self.mv_distribution:
                raise Exception("mv_distribution cannot be enabled with unknown optimizer.")

    def hook_modules(self, model: torch.nn.Module, grad_acc_steps):
        if self.module_rank_list and (self.rank not in self.module_rank_list):
            return

        if not isinstance(model, list):
            model = [model]
        self.model = model
        self._register_param_name(model)

        self.micro_batch_number = grad_acc_steps

        targets = self.config['targets']
        module_in_all_stage = [key for key in targets.keys() if MonitorConst.VPP_SEP not in key]
        for key in module_in_all_stage:
            struct = targets.pop(key)
            targets.update({f'{vpp_stage}{MonitorConst.VPP_SEP}{key}': struct for vpp_stage in range(len(model))})

        hooked_count = 0
        for vpp_stage, model_chunk in enumerate(model):
            vpp_stage = f'{vpp_stage}{MonitorConst.VPP_SEP}'
            targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
                'targets'].keys()
            hooked_count += self._hook_module(targets, model_chunk, vpp_stage)

        logger.info_on_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.")

        def clone_if_tensor(args):
            if isinstance(args, tuple):
                return tuple([clone_if_tensor(arg) for arg in args])
            elif isinstance(args, torch.Tensor):
                return args.clone()
            else:
                return args

        @torch.no_grad
        def wrap_hook_setup(setup):
            def wrapped_setup(*args, **kwargs):
                args = setup(*args, **kwargs)
                args = clone_if_tensor(args)
                return args

            return wrapped_setup

        BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)

        if not self.optimizer_hooked:
            self.hook_optimizer()
        return

    def generate_param_metrics(self, opt_context):
        get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)

    def generate_mv_metrics(self, opt_context):
        if not self.mv_distribution:
            return
        opt_context.exp_avg_metric = {}
        opt_context.exp_avg_sq_metric = {}
        m_tag_tensor_map = self.generate_param_map('exp_avg', opt_context.param_exp_avg)
        v_tag_tensor_map = self.generate_param_map('efxp_avg_sq', opt_context.param_exp_avg_sq)
        get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
        get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)

    def generate_wgrad_metrics(self):
        if not self.wg_distribution:
            return {}, {}

        if self.weight_hooked:
            get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)

        grad_dict = {}
        for param, name in self.param2name.items():
            if self.duplicate_param.get(name, False):
                continue
            grad = param.main_grad if self.params_have_main_grad else param.grad
            if grad is None:
                logger.warning(f"grad is None: {name}, maybe something wrong happened.")
                continue
            tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
            grad_dict[tag] = grad

        get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
        return self.grad_context.post, self.grad_context.pre

    def monitor_gnorm_with_ad(self, model, grad_acc_steps=1, optimizer=None, tp_group=None, dp_group=None):
        """External interface"""
        logger.info(f'grad acc steps {grad_acc_steps}')
        self.hook_optimizer(optimizer)
        self.micro_batch_number = grad_acc_steps

        self.dp_group = dp_group
        self.tp_group = tp_group

        self._register_param_name(model)
        self._patch_grad_sync()
        self.hook_modules(model, grad_acc_steps)

    def generate_param_map(self, tag, param_tensor):
        metrics = {}
        rank = dist.get_rank() if dist.is_initialized() else None
        for name in self.param2name.values():
            key = get_summary_writer_tag_name(name, tag, rank)
            if name not in param_tensor or param_tensor[name] is None:
                continue
            metrics[key] = param_tensor[name]
        return metrics

    def generate_xy_metrics(self):
        actv = {}
        for fwd_context in self.module_fwd_hook_context_by_module.values():
            actv.update(fwd_context.actv)

        actv_grad = self.grad_context.actv

        return actv, actv_grad

    def reload_xy(self, xy_distribution=False):
        self.xy_distribution = xy_distribution

        for handle in self.handles['xy']:
            handle.remove()
        self.handles['xy'].clear()
        self.hook_modules(self.model, self.micro_batch_number)
        for _, fwd_context in self.module_fwd_hook_context_by_module.items():
            fwd_context.actv.clear()

    def write_adhoc_check(self, step):
        TrainerMon.tensor_metrics.flush(self.summary_writer)

    def write_xy_tb(self, step):
        if not self.xy_distribution:
            return
        for _, fwd_context in self.module_fwd_hook_context_by_module.items():
            if len(fwd_context.actv) == 0:
                continue
            self.write_metrics(self.ops, self.summary_writer, fwd_context.actv, step, 'actv')
            fwd_context.actv.clear()
        if self.grad_context.actv:
            self.write_metrics(self.ops, self.summary_writer, self.grad_context.actv, step, 'actv_grad')

    def write_param_tb(self, opt_context):
        if not self.param_distribution:
            return
        self.write_metrics(self.ops, self.summary_writer, opt_context.param_metric, opt_context.step, 'param')

    def write_mv_tb(self, opt_context):
        if not self.mv_distribution:
            return
        self.write_metrics(self.ops, self.summary_writer, opt_context.exp_avg_metric,
                           opt_context.step, 'exp_avg')
        self.write_metrics(self.ops, self.summary_writer, opt_context.exp_avg_sq_metric,
                           opt_context.step, 'exp_avg_sq')

    def write_grad_tb(self, step):
        if not self.wg_distribution:
            return

        if self.enable_megatron:
            self.write_metrics(self.ops, self.summary_writer, self.grad_context.pre, step, 'grad_unreduced')
        else:
            self.write_metrics(self.ops, self.summary_writer, self.grad_context.acc_metric, step, 'grad_unreduced')
        self.write_metrics(self.ops, self.summary_writer, self.grad_context.post, step, 'grad_reduced')

    def hook_optimizer(self, optimizer=None):
        # in DDP by default use params_have_main_grad
        def optimizer_pre_step_hook(optimizer, args, kwargs):
            context = self.optimizer_context[optimizer]
            if self.opt_ty in MonitorConst.DEEPSPEED_OPT_TY:
                if context.step == 0:
                    self.name2indices = self.mix_precision_optimizer_mon.get_param_index(self.param2name,
                                                                                         self.name2index)
                mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name,
                                                                      self.name2indices)
                self.param2name = mv_result.grad
            else:
                mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name)
            context.param_exp_avg = mv_result.exp_avg
            context.param_exp_avg_sq = mv_result.exp_avg_sq
            context.param_adam_update = mv_result.update
            context.param_adam_ratio = mv_result.ratio

            if (self.print_struct and not all(value == {} for value in self.module_struct.values())
                    and not self.struct_printed):
                self._smallest_rank_print("> module struct:")
                self._smallest_rank_print(json.dumps(self.module_struct))
                self.struct_printed = True
                if not self.cc_log_only:
                    raise Exception("exit after first step when print model struct")
            if self.cc_log_only and context.step > 0:
                self._smallest_rank_print("> Used communication ops and corresponding stack")
                self._smallest_rank_print(
                    json.dumps({k: [i.split(';') for i in v] for k, v in self.cc_logged_stack.items()}))
                raise Exception("exit after first step when print cc stack")

            self.generate_wgrad_metrics()
            self.generate_mv_metrics(context)
            self.generate_param_metrics(context)

            tbtag_tensor_map = {}
            if self.mg_direction:
                for param, name in self.param2name.items():
                    grad = param.main_grad if self.params_have_main_grad else param.grad
                    if grad is None:
                        logger.warning(f"grad is None: {name}, maybe something wrong happened.")
                        continue
                    if context.step == 0:
                        same_direction_ratio = torch.tensor(1.)
                    else:
                        same_direction_ratio = get_sign_matches(grad, context.param_exp_avg[name])
                    context.param_mg_direction[name] = same_direction_ratio
                tbtag_tensor_map.update(self.generate_param_map('mg_direction', context.param_mg_direction))

            metric_dict = {}
            get_metrics(self.ops, tbtag_tensor_map, self.eps, metric_dict)
            for cc in self.cc_context.values():
                cc.aggregate()
                metric_dict.update(cc.data)
                cc.reset()

            if not metric_dict:
                return
            context.metric_dict = metric_dict
            return

        def optimizer_post_step_hook(optimizer, args, kwargs):
            context = self.optimizer_context[optimizer]
            rank = dist.get_rank() if dist.is_initialized() else None

            if self.anomaly_data_factory:
                self.anomaly_data_factory.set_call_id(self.param_name_call_id)
            self.write_xy_tb(context.step)
            self.write_grad_tb(context.step)
            self.write_mv_tb(context)
            self.write_param_tb(context)
            self.write_adhoc_check(context.step)

            if self.ur_distribution:
                for param_name, _ in context.param_adam_update.items():
                    self.update_heatmap_visualizer[param_name].visualize(
                        get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, self.summary_writer)
                for param_name, _ in context.param_adam_ratio.items():
                    self.ratio_heatmap_visualizer[param_name].visualize(
                        get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer)

            if context.metric_dict:
                self.write_metrics(self.ops, self.summary_writer, context.metric_dict, context.step, 'other')
            context.metric_dict.clear()
            context.step += 1
            if self.anomaly_data_factory:
                self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
                self.summary_writer.clear_anomalies()
            self.call_id = 0
            return

        def patch_step(func, optimizer):
            def wrapper(*args, **kwargs):
                optimizer_pre_step_hook(optimizer, args, kwargs)
                out = func(*args, **kwargs)
                optimizer_post_step_hook(optimizer, args, kwargs)
                return out

            return wrapper

        if self.optimizer_hooked:
            return

        if optimizer:
            optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)

        else:
            if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list):
                register_optimizer_step_pre_hook(optimizer_pre_step_hook)
                register_optimizer_step_post_hook(optimizer_post_step_hook)
        self.optimizer_hooked = True
        return

    def _smallest_rank_print(self, msg):
        if dist.is_initialized():
            if self.module_rank_list:
                if dist.get_rank() == min(self.module_rank_list):
                    logger.info(msg)
            else:
                if dist.get_rank() == 0:
                    logger.info(msg)
        else:
            logger.info(msg)

    def _is_target_param(self, param_name, param, prefix):
        squash_name = prefix + squash_param_name(param_name)
        name = prefix + param_name
        for target in self.config['targets'].keys():
            if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
                setattr(param, "zero_out_wgrad", True)
                return True

        return False

    def _register_chunk(self, model_chunk, prefix):
        for index, (param_name, param) in enumerate(model_chunk.named_parameters()):
            if not param.requires_grad:
                continue
            if self._is_target_param(param_name, param, prefix):
                name = prefix + squash_param_name(param_name)
                if name in self.param2name.values():
                    logger.error(f'same name {name} for different param. Current param is {param_name}. \
                        May be error of squash_param_name')
                    raise Exception("param with same name will be overwritten.")
                self.param2name[param] = name
                self.name2param[name] = param
                self.name2index[name] = index

                if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group):
                    self.duplicate_param[name] = True
                if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
                    self.duplicate_param[name] = True
                self.name2tag[name] = {}
                self.name2tag[name][MonitorConst.PRE_GRAD] = get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD,
                                                                                         self.rank)
                self.name2tag[name][MonitorConst.POST_GRAD] = get_summary_writer_tag_name(name, MonitorConst.POST_GRAD,
                                                                                          self.rank)

    def _register_param_name(self, model):
        if self.param_registered:
            return

        if not isinstance(model, list):
            model = [model]

        if len(model) > 1:
            self.vpp = True
            self._smallest_rank_print('vpp enabled')

        for vpp_stage, model_chunk in enumerate(model):
            prefix = f'{vpp_stage}{MonitorConst.VPP_SEP}'
            self._register_chunk(model_chunk, prefix)

        self.param_registered = True

    def _is_target_module(self, module_name, targets, vpp_stage):
        if self.all_xy or self.print_struct:
            return vpp_stage + squash_param_name(module_name)
        for pattern in [
            vpp_stage + squash_param_name(module_name),
            vpp_stage + module_name,
        ]:
            if pattern in targets:
                return pattern
        return ""

    def _hook_module(self, target_names, module: torch.nn.Module, vpp_stage=''):
        if '_modules' not in module.__dict__:
            # nothing to hook
            return 0

        def fwd_hook_fun(module, module_input, module_output, name):
            if is_recomputation():
                return
            if module not in self.module_fwd_hook_context_by_module:
                self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
            context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
            if not context.struct:
                context.struct = {MonitorConst.ACTV_IN: get_param_struct(module_input),
                                  MonitorConst.ACTV_OUT: get_param_struct(module_output)}
            if self.print_struct:
                if context.module_name not in self.module_struct:
                    self.module_struct[context.module_name] = {}
                self.module_struct[context.module_name].update(context.struct)
                return
            if not module.training:
                return
            if not context.format_by_arg:
                context.set_format_by_arg(MonitorConst.ACTV_IN, self.config['targets'])
                context.set_format_by_arg(MonitorConst.ACTV_OUT, self.config['targets'])
            if not context.format_by_arg:
                return
            if not context.verified:
                if not context.ignore_in:
                    context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
                                                                  module_input, context.module_name,
                                                                  MonitorConst.ACTV_IN)
                context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
                                                               module_output, context.module_name,
                                                               MonitorConst.ACTV_OUT)
                context.verified = True
            # expect output be tensor type
            tbtag_tensor_map = {}
            if not context.ignore_in:
                cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
                tbtag_tensor_map.update(
                    self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
                                                cared_input))
            cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
            tbtag_tensor_map.update(
                self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT,
                                            cared_output))

            get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)

            context.micro_step += 1
            if context.micro_step == self.micro_batch_number:
                context.micro_step = 0
                context.step += 1
            return

        def bwd_hook_fun(module, input_grad, output_grad):
            context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
            if not context.struct:
                context.struct = {MonitorConst.ACTVGRAD_IN: get_param_struct(input_grad),
                                  MonitorConst.ACTVGRAD_OUT: get_param_struct(output_grad)}
            if self.print_struct:
                if context.module_name not in self.module_struct:
                    self.module_struct[context.module_name] = {}
                self.module_struct[context.module_name].update(context.struct)
                return
            if not context.format_by_arg:
                context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.config['targets'])
                context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.config['targets'])
            if not context.format_by_arg:
                return
            if not context.verified:
                if not context.ignore_in:
                    context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
                                                                  input_grad, context.module_name,
                                                                  MonitorConst.ACTVGRAD_IN)
                context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
                                                               output_grad, context.module_name,
                                                               MonitorConst.ACTVGRAD_OUT)
                context.verified = True

            tbtag_tensor_map = {}
            if not context.ignore_in:
                cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
                tbtag_tensor_map.update(
                    self.build_tbtag_tensor_map(
                        f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad))
            cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
            tbtag_tensor_map.update(
                self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT,
                                            cared_output_grad))

            if context.micro_step == 0 and context.actvgrad:
                logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
                               f"maybe something wrong happened. Now clear it.")
                context.actvgrad.clear()

            get_metrics(self.ops, tbtag_tensor_map, self.eps, self.grad_context.actv)

            context.micro_step += 1
            if context.micro_step == self.micro_batch_number:
                context.micro_step = 0
                context.step += 1
            return

        if self.backward_only and self.forward_only:
            logger.warning('not enable backward_only and forward_only simultaneously')

        hooked_count = 0
        if self.xy_distribution or self.print_struct:
            for module_name, submodule in module.named_modules():
                name = self._is_target_module(module_name, target_names, vpp_stage)
                if not name:
                    continue
                if not self.backward_only:
                    handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name))
                    self.handles['xy'].append(handle)
                if not self.forward_only:
                    handle = submodule.register_full_backward_hook(bwd_hook_fun)
                    self.handles['xy'].append(handle)
                    self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
                logger.info_on_rank_0(f"> {name} is monitored successfully")
                hooked_count += 1
        return hooked_count

    def _patch_grad_sync(self):
        def patch_sync(sync_grad_func):
            def wrapper(bucket):
                grad_dict = {}
                for param, name in self.param2name.items():
                    if param not in bucket.params_list:
                        continue
                    grad = param.main_grad if self.params_have_main_grad else param.grad
                    if grad is None:
                        logger.warning(f"grad is None: {name}, maybe something wrong happened.")
                        continue
                    tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
                    if tag is None:
                        continue
                    grad_dict[tag] = grad
                get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
                out = sync_grad_func(bucket)
                return out

            return wrapper

        try:
            from megatron.core.distributed.param_and_grad_buffer import Bucket
            self.enable_megatron = True
        except ImportError:
            self.enable_megatron = False

        if self.enable_megatron:
            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)  # differ in different megatron version
        else:
            self._hook_weights()

    def _hook_weights(self):
        context = self.grad_context

        @torch.no_grad
        def param_hook(*args, context_dict, param, key, name):
            param.micro_step += 1
            self.param_name_call_id[name] = self.call_id
            self.call_id += 1
            if param.micro_step == self.micro_batch_number:
                param.micro_step = 0
                if self.params_have_main_grad:
                    context_dict[key] = param.main_grad.clone()
                else:
                    context_dict[key] = param.grad.clone()

        for param, name in self.param2name.items():
            key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
            setattr(param, 'micro_step', 0)
            param_tmp = param.expand_as(param)
            grad_acc = param_tmp.grad_fn.next_functions[0][0]
            handle = grad_acc.register_hook(
                partial(param_hook, context_dict=context.acc, param=param, key=key, name=name))
            self.grad_accs.append(grad_acc)
            self.handles['wgrads'].append(handle)

        self.weight_hooked = True
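
For orientation, the sketch below shows one plausible way the new `TrainerMon` class above could be driven. It is inferred only from the constructor and `monitor_gnorm_with_ad` in this diff; the config file contents, target names, and the exact keys accepted by `validate_config` are assumptions rather than documented usage (see `msprobe/docs/19.monitor.md` in this release for the authoritative guide).

```python
# Hypothetical usage sketch of TrainerMon; the config schema is guessed from the
# config.get(...) calls in TrainerMon.__init__ and may not match validate_config exactly.
import json
import torch
from msprobe.pytorch.monitor.module_hook import TrainerMon

config = {
    "targets": {"0": {}},          # assumed: monitor the first submodule by (squashed) name
    "format": "csv",               # assumed to be one of the FORMAT_MAPPING keys
    "ops": ["norm"],               # statistics to collect; checked by validate_ops
    "xy_distribution": True,       # record module input/output statistics
    "wg_distribution": True,       # record weight-gradient statistics
    "alert": {"rules": []},
}
with open("monitor_config.json", "w") as f:
    json.dump(config, f)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
optimizer = torch.optim.Adam(model.parameters())

# Without Megatron-style main_grad buffers, plain param.grad is monitored.
monitor = TrainerMon("monitor_config.json", params_have_main_grad=False)
monitor.monitor_gnorm_with_ad(model, grad_acc_steps=1, optimizer=optimizer)

for _ in range(3):
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()       # the patched step runs the pre/post monitoring hooks
    optimizer.zero_grad()
```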