mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/mindspore/monitor/module_hook.py

@@ -20,21 +20,24 @@ from collections import defaultdict
 from datetime import datetime

 import pytz
-import
+import pandas as pd
+import mindspore
 from mindspore import Tensor, mint
 from mindspore import nn, _no_grad
-from mindspore.communication import get_rank

 from msprobe.core.common.log import logger
-from msprobe.core.common.const import MonitorConst
+from msprobe.core.common.const import MonitorConst, Const
 from msprobe.core.common.file_utils import load_json, save_json
+from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
+from msprobe.mindspore.common.utils import is_mindtorch
+from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank
 from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \
-    is_skip_step, get_metrics,
-from msprobe.mindspore.monitor.
-from msprobe.mindspore.monitor.
-
-from msprobe.
-
+    is_skip_step, get_metrics, get_target_output_dir
+from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory
+from msprobe.mindspore.monitor.data_writers import CSVWriterWithAD, BaseWriterWithAD, WriterInput
+from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate
+from msprobe.core.common.file_utils import write_df_to_csv
+from msprobe.core.common.utils import analyze_api_call_stack

 FORMAT_MAPPING = {
     MonitorConst.CSV: CSVWriterWithAD,
@@ -88,24 +91,7 @@ class ModuleHookContext:
         self.actvgrad = []
         self.module_name = module_name
         self.struct = {}
-        self.
-        self.verified = False
-        self.focused_in_col = 0
-        self.focused_out_col = 0
-        self.ignore_in = False  # no need to care when no key 'input' or 'input_grad' found
-
-    def set_format_by_arg(self, key_name: str, target_config: dict):
-        cared = target_config.get(self.module_name, self.struct)
-        if key_name in cared:
-            if isinstance(cared[key_name], dict):
-                # current cared is self.struct
-                config = cared[key_name].get('config')
-                self.format_by_arg[key_name] = config
-            else:
-                # current cared is target_config[self.module_name]
-                self.format_by_arg[key_name] = cared[key_name]
-        elif key_name in ['input', 'input_grad']:
-            self.ignore_in = True
+        self.stack = ""

     def reset(self):
         self.actv.clear()
@@ -186,6 +172,7 @@ class TrainerMon:
         self.config_file_path = config_file_path
         self.process_group = process_group
         self.params_have_main_grad = params_have_main_grad
+        self.is_mindtorch = is_mindtorch()
         self.config_timestamp = 0  # the timestamp is validated later; first-time monitoring need not touch the config file just to refresh it, monitoring can be switched on directly via dynamic_on
         self.config = load_json(config_file_path)
         validate_config(self.config)
@@ -218,6 +205,7 @@ class TrainerMon:
         self.dp_group = None
         self.tp_group = None
         self.micro_batch_number = 1
+        self.optimizer_mon = None

         # TYPE3: variables reset when the config is updated mid-training or the monitoring state changes
         self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
@@ -240,6 +228,8 @@ class TrainerMon:
         self.optimizer_hooked = False
         self.param_registered = False
         self.struct_printed = False
+        self.pre_step_hooks = []
+        self.post_step_hooks = []

         # dynamic vs. static monitoring
         self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
@@ -276,6 +266,9 @@ class TrainerMon:
         self.param_distribution = self.config.get("param_distribution", False)
         self.mg_direction = self.config.get('mg_direction', False)  # main grad direction
         self.cc_distribution = self.config.get("cc_distribution", {})  # communication ops
+        self.stack_info = self.config.get('stack_info', False)
+        self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False)
+
         if not self.cc_distribution.get('enable', False):
             self.cc_log_only = False
         else:
@@ -296,18 +289,25 @@ class TrainerMon:
         if self.format not in FORMAT_MAPPING:
             logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
             self.format = MonitorConst.CSV
-        writer = FORMAT_MAPPING[self.format]
         self.step_count_per_record = self.config.get('step_count_per_record', 1)
-        self.
-
-
-
-
-
-
-
+        if not self.module_rank_list or (self.rank in self.module_rank_list):
+            writer = FORMAT_MAPPING[self.format]
+            self.summary_writer = writer(
+                WriterInput(
+                    self.tensorboard_dir,
+                    self.alert_rules,
+                    self.unique_id,
+                    self.anomaly_data_factory,
+                    self.ndigits,
+                    self.step_count_per_record
+                )
             )
-
+
+            # initialize the anomaly_detected output directory
+            if self.anomaly_data_factory:
+                self.anomaly_data_writer = AnomalyDataWriter(os.path.join(self.output_base_dir, "anomaly_detected"),
+                                                             self.rank)
+                self.anomaly_data_writer.init_detected_json()

     def common_info(self):
         if not self.xy_distribution:
@@ -339,6 +339,7 @@ class TrainerMon:
         self.micro_batch_number = grad_acc_steps
         self.dp_group = dp_group
         self.tp_group = tp_group
+        self.optimizer_mon = OptimizerMonFactory.create_optimizer_mon(optimizer)
         self.hook_step_final(optimizer)
         if not isinstance(model, list):
             model = [model]
@@ -359,16 +360,28 @@ class TrainerMon:
                               context.step - self.start_step) % self.step_interval == 0)
             if module_rank_valid and step_condition:
                 self.has_collect_times += 1
+
+                if self.anomaly_data_factory:
+                    self.anomaly_data_factory.set_call_id(self.param_name_call_id)
                 self.write_xy_tb(context.step)
                 self.write_grad_tb(context.step)
                 self.write_mv_tb(context)
                 self.write_param_tb(context)
+                if self.stack_info:
+                    self.write_stack_info()
+                    self.stack_info = False
+                    for handle in self.handles["stack"]:
+                        handle.remove()
+                    self.handles["stack"].clear()

                 if context.metric_dict:
                     self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
                 context.metric_dict.clear()

+                if self.anomaly_data_factory:
+                    self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
                 self.summary_writer.clear_anomalies()
+
                 self.call_id = 0
                 self.param_name_call_id.clear()

@@ -378,7 +391,23 @@ class TrainerMon:
             context.step += 1
             self.dynamic_monitor(optimizer)

-
+
+        def patch_step(func, optimizer):
+            def wrapper(*args, **kwargs):
+                for hook in self.pre_step_hooks:
+                    hook(optimizer, args, kwargs)
+                out = func(*args, **kwargs)
+                for hook in self.post_step_hooks:
+                    hook(optimizer, args, kwargs)
+                step_final_hook(optimizer, args, kwargs)
+                return out
+            return wrapper
+
+        if self.is_mindtorch:
+            optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+        else:
+            optimizer.__class__.construct = patch_step(optimizer.__class__.construct, optimizer)
+
         return

     def dynamic_monitor(self, optimizer):
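The hunk above drops the old forward-pre-hook on the optimizer and instead patches the optimizer class itself: `step` (MindTorch) or `construct` (native MindSpore) is wrapped so that registered pre-step and post-step hooks run around the real update, followed by `step_final_hook`. A minimal, framework-free sketch of the same wrapping pattern; the `TinyOptimizer` class and the hook bodies are illustrative, not msprobe code:

```python
# Illustrative sketch of the pre/post step-hook wrapping pattern used above.
pre_step_hooks = []
post_step_hooks = []

class TinyOptimizer:
    def step(self):
        print("real optimizer step")

def patch_step(func):
    def wrapper(self, *args, **kwargs):
        for hook in pre_step_hooks:
            hook(self, args, kwargs)          # e.g. collect grads / optimizer state
        out = func(self, *args, **kwargs)     # the original step/construct
        for hook in post_step_hooks:
            hook(self, args, kwargs)          # e.g. collect updated parameters
        return out
    return wrapper

TinyOptimizer.step = patch_step(TinyOptimizer.step)
pre_step_hooks.append(lambda opt, args, kwargs: print("pre-step hook"))
post_step_hooks.append(lambda opt, args, kwargs: print("post-step hook"))
TinyOptimizer().step()   # pre-step hook, real optimizer step, post-step hook
```

Because the patch is applied to `optimizer.__class__`, it is installed once per optimizer type; the hook lists are cleared again by `_remove_all_hooks(optimizer)` in a later hunk.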
@@ -413,7 +442,7 @@ class TrainerMon:
             logger.error(f"set config wrong because {e}, not updated, please check!!!")
             return

-        self._remove_all_hooks()
+        self._remove_all_hooks(optimizer)
         self.register_hooks(optimizer)

     def register_hooks(self, optimizer):
@@ -438,45 +467,36 @@ class TrainerMon:

         hooked_count = 0
         for vpp_stage, model_chunk in enumerate(self.model):
-            if not
+            if not is_valid_instance(model_chunk):
                 logger.info("Target Model is not Cell")
                 continue
             vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
-            targets = [x for x, _ in model_chunk
+            targets = [x for x, _ in get_submodules(model_chunk)] if self.print_struct else self.targets.keys()
             hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
         logger.info(f"> {hooked_count} modules are monitored.")

     def hook_optimizer(self, optimizer):
-        def
+        def optimizer_pre_step_hook(opt, *args, **kwargs):
             context = self.optimizer_context[opt]
             if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
                             self.collect_times):
                 return
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            for param in v_list:
-                name = param.name
-                if is_select and name not in self.targets:
-                    continue
-                get_single_metrics(self.ops, name, param, context.exp_avg_sq_metric)
-            if self.param_distribution:
-                for param in param_list:
-                    get_single_metrics(self.ops, param.name, param, context.param_metric)
-            self.generate_wgrad_metrics()
+
+            grad_dict = {}
+            if self.wg_distribution:
+                grad_dict = self.optimizer_mon.fetch_grad(self, self.param2name)
+
+            if self.mv_distribution or self.ur_distribution or self.mg_direction:
+                if self.is_mindtorch:
+                    context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, \
+                        context.param_adam_ratio = self.optimizer_mon.fetch_mv(self, self.param2name)
+                else:
+                    context.param_exp_avg, context.param_exp_avg_sq = self.get_mv_for_ms(optimizer)
+
+            self.generate_wgrad_metrics(grad_dict)
+            self.generate_mv_metrics(context)
+            self.generate_param_metrics(context, MonitorConst.PRE_PARAM)
+
             metric_dict = {}
             for cc in self.cc_context.values():
                 cc.aggregate()
@@ -488,63 +508,86 @@ class TrainerMon:
             context.metric_dict = metric_dict
             return

-        def
-
-
-
+        def optimizer_post_step_hook(optimizer, args, kwargs):
+            context = self.optimizer_context[optimizer]
+            self.generate_param_metrics(context, MonitorConst.POST_PARAM)
+

         if self.optimizer_hooked or not self.is_target_rank():
             return

-
-
-        param_list = []
-        grad_names = []
-        for param in optimizer.get_parameters():
-            if MonitorConst.EXP_AVG_SQ in param.name:
-                v_list.append(param)
-            elif MonitorConst.EXP_AVG in param.name:
-                m_list.append(param)
-            elif param.name in ['global_step', 'learning_rate']:
-                pass
-            else:
-                param_list.append(param)
-                grad_names.append(param.name)
-
-        handle = optimizer.register_forward_pre_hook(
-            optimizer_pre_hook_wrapper(optimizer_pre_hook_function, grad_names))
-        self.handles['optimizer'].append(handle)
+        self.pre_step_hooks.append(optimizer_pre_step_hook)
+        self.post_step_hooks.append(optimizer_post_step_hook)
         self.optimizer_hooked = True
         return

-    def generate_wgrad_metrics(self):
+    def generate_wgrad_metrics(self, grad_dict):
         if not self.wg_distribution:
-            return
+            return

-
-
-            get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
-        except Exception as e:
-            logger.warning(f"An error occurred while generating wgrad pre metrics")
-            return {}, {}
+        get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
+        get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)

-
-
-
-
-
-
-
+    def generate_param_map(self, tag, param_tensor):
+        metrics = {}
+        if not self.is_mindtorch:
+            return param_tensor
+        for name in self.param2name.values():
+            key = get_summary_writer_tag_name(name, tag, self.rank)
+            self.register_param_call_id("optimizer_pre_step_hook", key)
+            if name not in param_tensor or param_tensor[name] is None:
                 continue
-
-
-
-
-
-
-
+            metrics[key] = param_tensor[name]
+        return metrics
+
+    def generate_param_metrics(self, opt_context, stage=MonitorConst.PRE_PARAM):
+        if not self.param_distribution:
+            return
+        tag2param = {
+            self.name2tag.get(name, {}).get(stage): param
+            for name, param in self.name2param.items()
+            if param.numel() != 0
+        }
+        get_metrics(self.ops, tag2param, self.eps, opt_context.param_metric)
+
+    def get_mv_for_ms(self, opt):
+        if not self.mv_distribution:
             return {}, {}
-
+        common_opt = opt
+        if not is_valid_instance(opt):
+            common_opt = getattr(opt, 'optimizer')
+            if not is_valid_instance(common_opt):
+                logger.warning("Optimizer is not valid, please check usage")
+                return {}, {}
+        m_dict = {}
+        v_dict = {}
+        for name, param in get_parameters(common_opt):
+            if MonitorConst.EXP_AVG_SQ in name:
+                v_dict[name] = param
+            elif MonitorConst.EXP_AVG in name:
+                m_dict[name] = param
+        return m_dict, v_dict
+
+    def generate_mv_metrics(self, opt_context):
+        if not self.mv_distribution:
+            return
+        opt_context.exp_avg_metric = {}
+        opt_context.exp_avg_sq_metric = {}
+        m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
+        v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
+        get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
+        get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
+
+    def write_stack_info(self):
+        stack_data = []
+        header = ["module_name", "stack_info"]
+        stack_data.append(header)
+        for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+            stack_data.append([fwd_context.module_name, fwd_context.stack])
+        filepath = os.path.join(self.tensorboard_dir, f'stack_info.csv')
+        if not os.path.exists(filepath):
+            data_frame = pd.DataFrame(columns=stack_data)
+            write_df_to_csv(data_frame, filepath)

     def write_xy_tb(self, step):
         if not self.xy_distribution:
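Among the methods added above, `get_mv_for_ms` pulls Adam-style first and second moments straight out of the optimizer's parameter list by matching name substrings against `MonitorConst.EXP_AVG` / `MonitorConst.EXP_AVG_SQ`. A standalone sketch of that bucketing, assuming the constants resolve to the plain substrings `"exp_avg"` and `"exp_avg_sq"` (parameter names and values below are illustrative):

```python
# Illustrative sketch of the exp_avg / exp_avg_sq bucketing in get_mv_for_ms.
EXP_AVG, EXP_AVG_SQ = "exp_avg", "exp_avg_sq"

def split_moments(named_params):
    m_dict, v_dict = {}, {}
    for name, param in named_params:
        if EXP_AVG_SQ in name:      # test the longer substring first, as the hunk does
            v_dict[name] = param
        elif EXP_AVG in name:
            m_dict[name] = param
    return m_dict, v_dict

named_params = [
    ("exp_avg.layers.0.weight", [0.01]),
    ("exp_avg_sq.layers.0.weight", [0.001]),
    ("layers.0.weight", [1.0]),             # plain weight, ignored here
]
m, v = split_moments(named_params)
print(list(m), list(v))
# ['exp_avg.layers.0.weight'] ['exp_avg_sq.layers.0.weight']
```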
@@ -552,27 +595,32 @@ class TrainerMon:
         for _, fwd_context in self.module_fwd_hook_context_by_module.items():
             if len(fwd_context.actv) == 0:
                 continue
-            self.summary_writer.write_metrics(self.ops, fwd_context.actv, step,
+            self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, MonitorConst.ACTV)
             fwd_context.actv.clear()
         if self.grad_context.actv:
-            self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step,
+            self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD)

     def write_param_tb(self, opt_context):
         if not self.param_distribution:
             return
-
+        param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.PRE_PARAM in k}
+        updated_param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.POST_PARAM in k}
+        self.summary_writer.write_metrics(self.ops, param_metrics, opt_context.step, MonitorConst.PRE_PARAM)
+        self.summary_writer.write_metrics(self.ops, updated_param_metrics, opt_context.step, MonitorConst.POST_PARAM)

     def write_mv_tb(self, opt_context):
         if not self.mv_distribution:
             return
-        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step,
-        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step,
+        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, MonitorConst.EXP_AVG)
+        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step,
+                                          MonitorConst.EXP_AVG_SQ)

     def write_grad_tb(self, step):
         if not self.wg_distribution:
             return

-        self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced'
+        self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced',
+                                          use_micro_step=self.monitor_mbs_grad)
         self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')

     def is_target_rank(self):
@@ -580,13 +628,38 @@ class TrainerMon:
             return False
         return True

-    def build_tbtag_tensor_map(self, module_name, tag, tensor):
-
-
+    def build_tbtag_tensor_map(self, module_name, suffix, tag, tensor):
+        """
+        :param module_name: str of module name
+        :param suffix:
+        :param tag:
+        :param tensor: torch.tensor or tuple/list of torch.tensor
+        :return: tensor_map
+        """
+        tensor_map = {}
         if isinstance(tensor, Tensor):
-
-
-
+            tensor = [tensor]
+        if isinstance(tensor, tuple) or isinstance(tensor, list):
+            if len(tensor) == 1:
+                key = get_summary_writer_tag_name(module_name + suffix, tag, self.rank)
+                self.register_param_call_id("_hook_module", key)
+                tensor_map[key] = tensor[0]
+            else:
+                for i, tensor_i in enumerate(tensor):
+                    key = get_summary_writer_tag_name(module_name + f"_{i}" + suffix, tag, self.rank)
+                    self.register_param_call_id("_hook_module", key)
+                    tensor_map[key] = tensor_i
+        return tensor_map
+
+    def register_param_call_id(self, hook_name: str, key: str):
+        """
+        :param hook_name:
+        :param key: str, '0:relu_0/output_grad'
+        :return:
+        """
+        logger.debug(f"{hook_name} {key}: {self.call_id}")
+        self.param_name_call_id[key] = self.call_id
+        self.call_id += 1

     def _register_param_name(self):
         for vpp_stage, model_chunk in enumerate(self.model):
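The reworked `build_tbtag_tensor_map` fans a single tensor or a tuple/list of tensors out to one summary-writer tag per element and registers a call id for each key. The sketch below only shows the key layout; `tag_name()` is a simplified stand-in for `get_summary_writer_tag_name`, whose real format the docstring above hints at with '0:relu_0/output_grad':

```python
# Simplified sketch of the tag fan-out in build_tbtag_tensor_map.
def tag_name(name, tag, rank=0):
    return f"{rank}:{name}/{tag}"   # stand-in for get_summary_writer_tag_name

def build_tbtag_tensor_map(module_name, suffix, tag, tensor):
    tensors = [tensor] if not isinstance(tensor, (tuple, list)) else list(tensor)
    if len(tensors) == 1:
        return {tag_name(module_name + suffix, tag): tensors[0]}
    return {tag_name(f"{module_name}_{i}{suffix}", tag): t
            for i, t in enumerate(tensors)}

print(build_tbtag_tensor_map("layers.0.input", "_0", "actv", (1.0, 2.0)))
# {'0:layers.0.input_0_0/actv': 1.0, '0:layers.0.input_1_0/actv': 2.0}
```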
@@ -595,8 +668,7 @@ class TrainerMon:

     def _register_chunk(self, model_chunk, prefix):
         index = 0
-        for param in
-            param_name = param.name
+        for param_name, param in get_parameters(model_chunk):
             if not param.requires_grad:
                 continue
             if self._is_target_param(param_name, param, prefix):
@@ -611,25 +683,37 @@ class TrainerMon:
                 self.duplicate_param[name] = True
             if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
                 self.duplicate_param[name] = True
+            keywords = [
+                MonitorConst.PRE_GRAD,
+                MonitorConst.POST_GRAD,
+                MonitorConst.PRE_PARAM,
+                MonitorConst.POST_PARAM
+            ]
             self.name2tag[name] = {
-
-
+                k: get_summary_writer_tag_name(name, k, self.rank)
+                for k in keywords
             }
             index += 1

     def _hook_module(self, target_names, module, vpp_stage=''):
-        if not
+        if not is_valid_instance(module):
             # nothing to hook
             return 0

-        def fwd_hook_fun(module,
+        def fwd_hook_fun(module, args, kwargs, module_output, name):
+
+            module_input = [tensor for tensor in args if isinstance(tensor, Tensor)]
+            if kwargs:
+                kwargs_tensors = [tensor for tensor in kwargs.values() if isinstance(tensor, Tensor)]
+                module_input.extend(kwargs_tensors)
+
             if module not in self.module_fwd_hook_context_by_module:
                 self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
             context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
             if not context.struct:
                 context.struct = {
-
-
+                    Const.INPUT: get_param_struct(module_input),
+                    Const.OUTPUT: get_param_struct(module_output)
                 }
             if self.print_struct:
                 self.module_struct[context.module_name].update(context.struct)
@@ -640,31 +724,18 @@ class TrainerMon:
                                 self.collect_times):
                 step_accumulates_one(context, self.micro_batch_number)
                 return
-            if not context.format_by_arg:
-                context.set_format_by_arg(MonitorConst.ACTV_IN, self.targets)
-                context.set_format_by_arg(MonitorConst.ACTV_OUT, self.targets)
-            if not context.format_by_arg:
-                return
-            if not context.verified:
-                if not context.ignore_in:
-                    context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
-                                                                  module_input, context.module_name,
-                                                                  MonitorConst.ACTV_IN)
-                context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
-                                                               module_output, context.module_name,
-                                                               MonitorConst.ACTV_OUT)
-                context.verified = True

             tbtag_tensor_map = {}
-            if not context.ignore_in:
-                cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
-                tbtag_tensor_map.update(
-                    self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
-                                                cared_input))
-            cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
             tbtag_tensor_map.update(
-                self.build_tbtag_tensor_map(
-
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, module_input))
+            module_output = [tensor for tensor in module_output if isinstance(tensor, Tensor)] \
+                if isinstance(module_output, tuple) else module_output
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, module_output))
             try:
                 get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
             except Exception as e:
@@ -689,31 +760,17 @@ class TrainerMon:
                 step_accumulates_one(context, self.micro_batch_number)
                 return

-            if
-                context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.targets)
-                context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.targets)
-            if not context.format_by_arg:
-                return
-            if not context.verified:
-                if not context.ignore_in:
-                    context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
-                                                                  input_grad, context.module_name,
-                                                                  MonitorConst.ACTVGRAD_IN)
-                context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
-                                                               output_grad, context.module_name,
-                                                               MonitorConst.ACTVGRAD_OUT)
-                context.verified = True
-
+            valid_input_grad = [tensor for tensor in input_grad if isinstance(tensor, Tensor)]
             tbtag_tensor_map = {}
-            if not context.ignore_in:
-                cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
-                tbtag_tensor_map.update(
-                    self.build_tbtag_tensor_map(
-                        f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad))
-            cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
             tbtag_tensor_map.update(
-                self.build_tbtag_tensor_map(
-
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTVGRAD, valid_input_grad))
+
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTVGRAD, output_grad))

             if context.micro_step == 0 and context.actvgrad:
                 logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
@@ -727,21 +784,39 @@ class TrainerMon:
                 step_accumulates_one(context, self.micro_batch_number)
                 return

-        def
-
-
-
+        def fwd_hook_register(module, fwd_hook_fun, name):
+            if mindspore.__version__ >= '2.6.0':
+                def wrapper(module, args, kwargs, module_output):
+                    return fwd_hook_fun(module, args, kwargs, module_output, name)
+                return module.register_forward_hook(wrapper, with_kwargs=True)
+
+            else:
+                def wrapper(module, args, module_output):
+                    return fwd_hook_fun(module, args, None, module_output, name)
+                return module.register_forward_hook(wrapper)
+
+        def stack_hook(module, args, kwargs, module_output, name):
+            if module not in self.module_fwd_hook_context_by_module:
+                self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
+            context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
+            context.stack = analyze_api_call_stack(name)
+            return

         if self.backward_only and self.forward_only:
             logger.warning('not enable backward_only and forward_only simultaneously')
         hooked_count = 0
-
-
-
-
-
+
+        for module_name, submodule in get_submodules(module):
+            if self.stack_info:
+                name = vpp_stage + squash_param_name(module_name)
+                handle = fwd_hook_register(submodule, stack_hook, name=name)
+                self.handles["stack"].append(handle)
+            name = self._is_target_module(module_name, target_names, vpp_stage)
+            if not name:
+                continue
+            if self.xy_distribution or self.print_struct:
                 if not self.backward_only:
-                    handle = submodule
+                    handle = fwd_hook_register(submodule, fwd_hook_fun, name=name)
                     self.handles['xy'].append(handle)
                 if not self.forward_only:
                     handle = submodule.register_backward_hook(bwd_hook_fun)
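The new `fwd_hook_register` helper gates forward-hook registration on the MindSpore version: from 2.6.0 on, hooks registered with `with_kwargs=True` also receive keyword inputs, while older releases pass only positional args, so the kwargs slot is filled with `None`. A standalone sketch of the same gate; it requires MindSpore, the `nn.Dense` cell and hook body are illustrative, and the version check is the same plain string comparison used in the hunk above:

```python
# Sketch of the version-gated forward-hook registration pattern.
import mindspore
from mindspore import nn

def fwd_hook(cell, args, kwargs, output, name):
    print(f"{name}: {len(args)} positional inputs, kwargs passed: {kwargs is not None}")

def fwd_hook_register(cell, hook, name):
    if mindspore.__version__ >= '2.6.0':
        def wrapper(cell, args, kwargs, output):          # 4-arg hook signature
            return hook(cell, args, kwargs, output, name)
        return cell.register_forward_hook(wrapper, with_kwargs=True)
    def wrapper(cell, args, output):                      # legacy 3-arg signature
        return hook(cell, args, None, output, name)
    return cell.register_forward_hook(wrapper)

handle = fwd_hook_register(nn.Dense(4, 4), fwd_hook, "dense0")
handle.remove()
```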
@@ -760,22 +835,30 @@ class TrainerMon:
         context = self.grad_context

         @_no_grad()
-        def param_hook(grad, context_dict, param,
+        def param_hook(grad, context_dict, param, name):
+            key = name
+            if self.monitor_mbs_grad:
+                key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
+            key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
+            self.register_param_call_id("param_hook", key)
             param.micro_step += 1
-
+
+            if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
+                context_dict[key] = grad
             if param.micro_step == self.micro_batch_number:
                 param.micro_step = 0
-                context_dict[key] = grad

-        def param_hook_wrapper(param_hook, context_dict, param,
+        def param_hook_wrapper(param_hook, context_dict, param, name):
             def wrapper(grad):
-                return param_hook(grad, context_dict, param,
+                return param_hook(grad, context_dict, param, name)
+
             return wrapper

+        logger.info("hooking weights.")
         for param, name in self.param2name.items():
-            key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
             setattr(param, 'micro_step', 0)
-            handle = param.register_hook(
+            handle = param.register_hook(
+                param_hook_wrapper(param_hook, context_dict=context.acc, param=param, name=name))
             self.handles['wgrads'].append(handle)
         self.weight_hooked = True

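The reworked `param_hook` keys accumulated-gradient entries per micro step when `monitor_mbs_grad` is enabled, and otherwise records only the fully accumulated gradient on the last micro step before resetting the counter. A framework-free sketch of that counter logic (the names and micro-batch count are illustrative; the real hook is attached with `param.register_hook`):

```python
# Illustrative sketch of the monitor_mbs_grad keying logic in param_hook.
MICRO_BATCH_NUMBER = 4

def make_param_hook(context_dict, state, name, monitor_mbs_grad):
    def param_hook(grad):
        key = name
        if monitor_mbs_grad:
            key += f"_micro{state['micro_step']}"       # one entry per micro step
        state["micro_step"] += 1
        if monitor_mbs_grad or state["micro_step"] == MICRO_BATCH_NUMBER:
            context_dict[key] = grad                    # record grad (last step only when disabled)
        if state["micro_step"] == MICRO_BATCH_NUMBER:
            state["micro_step"] = 0                     # reset for the next global step
        return grad
    return param_hook

acc = {}
hook = make_param_hook(acc, {"micro_step": 0}, "layers.0.weight", monitor_mbs_grad=True)
for g in range(MICRO_BATCH_NUMBER):   # simulate four micro-batch backward passes
    hook(float(g))
print(sorted(acc))   # ['layers.0.weight_micro0', ..., 'layers.0.weight_micro3']
```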
@@ -801,17 +884,7 @@ class TrainerMon:
                 return pattern
         return ""

-    def
-        """
-        :param hook_name:
-        :param key: str, '0:relu_0/output_grad'
-        :return:
-        """
-        logger.debug(f"{hook_name} {key}: {self.call_id}")
-        self.param_name_call_id[key] = self.call_id
-        self.call_id += 1
-
-    def _remove_all_hooks(self):
+    def _remove_all_hooks(self, optimizer):
         # clear hook handles
         for handle in self.handles['xy']:
             handle.remove()
@@ -829,9 +902,8 @@ class TrainerMon:
         self.weight_hooked = False

         if self.optimizer_hooked:
-
-
-            self.handles['optimizer'].clear()
+            self.pre_step_hooks.clear()
+            self.post_step_hooks.clear()
             for _, context in self.optimizer_context.items():
                 context.reset()
             self.optimizer_hooked = False
@@ -870,4 +942,4 @@ class TrainerMon:
         except Exception as e:
             logger.warning(f"Finish monitor, set config'dynamic_on=False fail because {e}, please check!!!")
         logger.info("Finish monitor")
-        self._remove_all_hooks()
+        self._remove_all_hooks(optimizer)