mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as published to a supported registry; it is provided for informational purposes only.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
- msprobe/README.md +27 -22
- msprobe/core/common/const.py +129 -60
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +43 -33
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +16 -9
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +30 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +94 -10
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
- msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
- msprobe/core/data_dump/json_writer.py +61 -40
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +27 -1
- msprobe/docs/02.config_introduction.md +27 -23
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +103 -16
- msprobe/docs/06.data_dump_MindSpore.md +76 -32
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +332 -273
- msprobe/docs/21.visualization_PyTorch.md +42 -13
- msprobe/docs/22.visualization_MindSpore.md +43 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +301 -27
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
- msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +48 -18
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +31 -6
- msprobe/mindspore/debugger/precision_debugger.py +45 -14
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +21 -15
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +873 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +309 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +114 -34
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +97 -4
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +24 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +8 -2
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +18 -14
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +238 -193
- msprobe/pytorch/monitor/module_metric.py +9 -6
- msprobe/pytorch/monitor/optimizer_collect.py +100 -67
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +76 -44
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +30 -29
- msprobe/pytorch/service.py +114 -32
- msprobe/visualization/builder/graph_builder.py +75 -10
- msprobe/visualization/builder/msprobe_adapter.py +7 -6
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +11 -3
- msprobe/visualization/graph/distributed_analyzer.py +71 -3
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +4 -3
- msprobe/visualization/graph_service.py +4 -5
- msprobe/visualization/utils.py +12 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
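The largest single change in this release is the rework of the PyTorch monitor in `msprobe/pytorch/monitor/module_hook.py` (+238 −193); the hunks reproduced below appear to come from that file (a few removed lines are shown truncated, as rendered by the upstream diff viewer). The key API shift they show: `TrainerMon` no longer builds its optimizer monitor from the `opt_ty` constructor argument; the new `set_monitor()` entry point receives the optimizer directly and `OptimizerMonFactory.create_optimizer_mon()` derives the monitor from it, while `monitor_gnorm_with_ad()` and `set_wrapped_optimizer()` are kept only for compatibility with releases before 1.2.2. The sketch below illustrates the call shape; the model, optimizer, and config path are placeholders, and which optimizer classes are actually supported is not confirmed by the diff.

```python
# Minimal sketch of the 1.3.0 TrainerMon entry point, assuming a plain torch model/optimizer
# and a monitor_config.json that defines the keys read in set_config() ("targets", "format",
# "ops", ...). Everything below is an illustrative stand-in, not taken from the diff itself.
import torch
from msprobe.pytorch.monitor.module_hook import TrainerMon

model = torch.nn.Linear(8, 8)                      # stand-in for the training model (or a list of vpp chunks)
optimizer = torch.optim.AdamW(model.parameters())  # stand-in for the (wrapped) training optimizer

monitor = TrainerMon("./monitor_config.json", params_have_main_grad=False)

# New in 1.3.0: the optimizer is passed in and the optimizer monitor is created from it.
monitor.set_monitor(model, optimizer, grad_acc_steps=1, start_iteration=0)

# Pre-1.2.2 style still accepted: monitor_gnorm_with_ad()/set_wrapped_optimizer() now forward
# to set_monitor(); the opt_ty constructor argument remains but is no longer used.
```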
```diff
@@ -22,12 +22,13 @@ from functools import partial
 import pytz
 import torch
 import torch.distributed as dist
-from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
 from torch.utils.hooks import BackwardHook
 
-from msprobe.core.common.const import MonitorConst
+from msprobe.core.common.const import MonitorConst, Const
 from msprobe.core.common.file_utils import load_json, save_json
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.common.utils import is_recomputation, is_float8_tensor
 from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter
 from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
     CSVWriterWithAD, BaseWriterWithAD, WriterInput
@@ -37,15 +38,16 @@ from msprobe.pytorch.monitor.features import get_sign_matches
 from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
     TensorMetrics, squash_param_name
 from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
-from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
-from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops,
-    get_output_base_dir, get_target_output_dir
+from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
+from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \
+    get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor
 from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if not torch_version_above_or_equal_2:
     raise ValueError("monitor require torch>=2.0")
 
+
 FORMAT_MAPPING = {
     MonitorConst.TENSORBOARD: SummaryWriterWithAD,
     MonitorConst.CSV: CSVWriterWithAD,
@@ -85,9 +87,6 @@ class ModuleHookContext:
         :param target_config: target obj in config json.
         :return:
         """
-        valid_key = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT, MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
-        if key_name not in valid_key:
-            raise ValueError(f"key({key_name}) error, valid_key: {valid_key}")
         cared = target_config.get(self.module_name, self.struct)
         if key_name in cared:
             target_module_config = cared[key_name]
@@ -178,20 +177,17 @@ class GradContext:
 class TrainerMon:
     tensor_metrics = TensorMetrics()
 
+    # keep the original opt_ty parameter for compatibility with msprobe versions earlier than 1.2.2
     def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None:
-        """
-        opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer"
-        """
         # TYPE1: variables initialized only here; not reset when the config changes mid-training
         self.config_file_path = config_file_path
         self.process_group = get_process_group(process_group)
         self.params_have_main_grad = params_have_main_grad
-        self.opt_ty = opt_ty
-        self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
         self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
         self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
         self.origin_step_func = None
-        self.
+        self.origin_start_grad_sync = None
+        self.config_timestamp = 0  # the timestamp is validated later; the first run need not touch the config file just to refresh it, monitoring can be enabled directly via the dynamic_on switch
         self.config = load_json(config_file_path)
         validate_config(self.config)
 
@@ -219,13 +215,16 @@ class TrainerMon:
         self.pp_stage = 0
         self.group_mates = [0]
 
-        # TYPE2: 只会在
+        # TYPE2: variables assigned only in the set_monitor() entry call
         self.model = None
         self.vpp = False
         self.dp_group = None
         self.tp_group = None
         self.enable_megatron = False
         self.micro_batch_number = 1
+        self.optimizer_class = None
+        self.optimizer_mon = None
+        self.optimizer_trans = None
 
         # TYPE3: variables reset when the config is updated mid-training or the monitoring state changes
         self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
@@ -253,7 +252,7 @@ class TrainerMon:
         self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
         if self.dynamic_enable:
             logger.warning(f"DYNAMIC_MONITOR is set, "
-                           f"please make sure you have '
+                           f"please make sure you have 'dynamic_on' and 'collect_times' in {self.config_file_path}")
             self.monitoring = False
         else:
             self.set_config()
@@ -273,10 +272,6 @@ class TrainerMon:
     def ops(self, value):
         self._ops = validate_ops(value)
 
-    @staticmethod
-    def set_wrapped_optimizer(_wrapped_optimizer):
-        OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
-
     @staticmethod
     def has_register_backward_hook(module_name, module):
         if hasattr(module, '_backward_hooks') and \
@@ -308,7 +303,7 @@ class TrainerMon:
         self.has_collect_times = 0  # reset the collection counter
         self.print_struct = self.config.get("print_struct", False)
        self.module_rank_list = self.config.get("module_ranks", [])
-        self.format = self.config.get('format',
+        self.format = self.config.get('format', MonitorConst.CSV)
         self.eps = self.config.get('eps', 1e-8)
         self.ops = self.config.get('ops', [])
         self.ndigits = self.config.get('ndigits', 6)
@@ -330,8 +325,6 @@ class TrainerMon:
             self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
             self.cc_logged_stack = defaultdict(set)
             self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
-            self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
-            api_register.redirect_api()
 
         self.common_info()
 
@@ -344,7 +337,13 @@ class TrainerMon:
 
         # initialize the writer and create the output directory
         if self.format not in FORMAT_MAPPING:
-
+            logger.warning(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
+            self.format = MonitorConst.CSV
+
+        if self.ur_distribution and self.format != 'tensorboard':
+            logger.warning("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution")
+            self.ur_distribution = False
+
         writer = FORMAT_MAPPING[self.format]
         self.step_count_per_record = self.config.get('step_count_per_record', 1)
 
@@ -365,19 +364,6 @@ class TrainerMon:
                                                          self.rank)
             self.anomaly_data_writer.init_detected_json()
 
-    def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
-        rank = None
-        if dist.is_initialized():
-            rank = dist.get_rank()
-        if (rank not in rank_list) and len(rank_list) != 0:
-            return
-        self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
-
-    def build_tbtag_tensor_map(self, module_name, tag, tensor):
-        key = get_summary_writer_tag_name(module_name, tag, self.rank)
-        self._register_param_call_id("_hook_module", key)
-        return {key: tensor}
-
     def common_info(self):
         if not self.xy_distribution:
             logger.info_on_rank_0("> module input/output input_grad/output_grad is not monitored. ")
@@ -393,105 +379,39 @@ class TrainerMon:
             logger.info_on_rank_0('> grad and momentum direction will not be compared.')
         if not self.cc_distribution.get('enable', False):
             logger.info_on_rank_0("> cc operator is not monitored.")
-        if not self.opt_ty:
-            if self.ur_distribution:
-                raise Exception("ur_distribution cannot be enabled with unknown optimizer.")
-            if self.mv_distribution:
-                raise Exception("mv_distribution cannot be enabled with unknown optimizer.")
-
-    def hook_modules(self):
-        if self.module_rank_list and (self.rank not in self.module_rank_list):
-            return
-
-        targets = self.config['targets']
-        module_in_all_stage = [key for key in targets.keys() if MonitorConst.VPP_SEP not in key]
-        for key in module_in_all_stage:
-            struct = targets.pop(key)
-            targets.update({f'{vpp_stage}{MonitorConst.VPP_SEP}{key}': struct for vpp_stage in range(len(self.model))})
-
-        hooked_count = 0
-        for vpp_stage, model_chunk in enumerate(self.model):
-            vpp_stage = f'{vpp_stage}{MonitorConst.VPP_SEP}'
-            targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
-                'targets'].keys()
-            hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
-
-        logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
-
-        def clone_if_tensor(args):
-            if isinstance(args, tuple):
-                return tuple([clone_if_tensor(arg) for arg in args])
-            elif isinstance(args, torch.Tensor):
-                return args.clone()
-            else:
-                return args
-
-        @torch.no_grad
-        def wrap_hook_setup(setup):
-            def wrapped_setup(*args, **kwargs):
-                args = setup(*args, **kwargs)
-                args = clone_if_tensor(args)
-                return args
-
-            return wrapped_setup
-
-        BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
-
-        return
-
-    def generate_param_metrics(self, opt_context):
-        if not self.param_distribution:
-            return
-        get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
-
-    def generate_mv_metrics(self, opt_context):
-        if not self.mv_distribution:
-            return
-        opt_context.exp_avg_metric = {}
-        opt_context.exp_avg_sq_metric = {}
-        m_tag_tensor_map = self.generate_param_map('exp_avg', opt_context.param_exp_avg)
-        v_tag_tensor_map = self.generate_param_map('efxp_avg_sq', opt_context.param_exp_avg_sq)
-        get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
-        get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
-
-    def generate_wgrad_metrics(self):
-        if not self.wg_distribution:
-            return {}, {}
 
-
-
-
-
-
-            if
-
-
-
-                logger.warning(f"grad is None: {name}, maybe something wrong happened.")
-                continue
-            tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
-            self._register_param_call_id("hook_optimizer", tag)
-            grad_dict[tag] = grad
+    # keep the original interface for compatibility with msprobe versions earlier than 1.2.2
+    def monitor_gnorm_with_ad(self, model, optimizer=None, grad_acc_steps=1, tp_group=None, dp_group=None,
+                              start_iteration=0):
+        if optimizer is None:
+            optimizer = getattr(self, "optimizer_trans", None)  # compatible with old versions that may pass None; taken from set_wrapped_optimizer
+        if optimizer is None:
+            logger.error("monitor_gnorm_with_ad: please set_wrapped_optimizer before it or input optimizer!=None")
+            return
+        self.set_monitor(model, optimizer, grad_acc_steps, tp_group, dp_group, start_iteration)
 
-
-
+    # keep the original interface for compatibility with msprobe versions earlier than 1.2.2
+    def set_wrapped_optimizer(self, optimizer):
+        self.optimizer_trans = optimizer
 
-    def
+    def set_monitor(
             self,
             model,
+            optimizer,
             grad_acc_steps=1,
-            optimizer=None,
            tp_group=None,
             dp_group=None,
             start_iteration=0
     ):
         """External interface"""
+        grad_acc_steps, start_iteration = validate_set_monitor(grad_acc_steps, start_iteration)
         global start_step
         start_step = start_iteration
         logger.info(f'grad acc steps {grad_acc_steps}')
         self.micro_batch_number = grad_acc_steps
         self.dp_group = dp_group
         self.tp_group = tp_group
+        self.optimizer_mon, self.optimizer_class = OptimizerMonFactory.create_optimizer_mon(optimizer)
         self.hook_step_final(optimizer)
         if not isinstance(model, list):
             model = [model]
@@ -507,8 +427,24 @@ class TrainerMon:
         self.hook_optimizer(optimizer)
         self._patch_grad_sync()
         self.hook_modules()
+        if self.cc_distribution.get('enable', False):
+            self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
+            api_register.redirect_api()
         self.monitoring = True
 
+    def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
+        rank = None
+        if dist.is_initialized():
+            rank = dist.get_rank()
+        if (rank not in rank_list) and len(rank_list) != 0:
+            return
+        self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
+
+    def build_tbtag_tensor_map(self, module_name, tag, tensor):
+        key = get_summary_writer_tag_name(module_name, tag, self.rank)
+        self._register_param_call_id("_hook_module", key)
+        return {key: tensor}
+
     def generate_param_map(self, tag, param_tensor):
         metrics = {}
         for name in self.param2name.values():
@@ -519,6 +455,44 @@ class TrainerMon:
             metrics[key] = param_tensor[name]
         return metrics
 
+    def generate_param_metrics(self, opt_context):
+        if not self.param_distribution:
+            return
+        get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
+
+    def generate_mv_metrics(self, opt_context):
+        if not self.mv_distribution:
+            return
+        opt_context.exp_avg_metric = {}
+        opt_context.exp_avg_sq_metric = {}
+        m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
+        v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
+        get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
+        get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
+
+    def generate_wgrad_metrics(self):
+        if not self.wg_distribution:
+            return {}, {}
+
+        if self.weight_hooked:
+            get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
+
+        grad_dict = {}
+        for param, name in self.param2name.items():
+            if self.duplicate_param.get(name, False):
+                continue
+            grad = param.main_grad if self.params_have_main_grad else param.grad
+            if grad is None:
+                logger.warning(f"grad is None: {name}, maybe something wrong happened.")
+                continue
+            tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+            self._register_param_call_id("hook_optimizer", tag)
+            grad_dict[tag] = grad
+
+        get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
+        unreduced_grad = self.grad_context.acc_metric if self.weight_hooked else self.grad_context.pre
+        return self.grad_context.post, unreduced_grad
+
     def generate_xy_metrics(self):
         actv = {}
         for fwd_context in self.module_fwd_hook_context_by_module.values():
@@ -529,6 +503,8 @@ class TrainerMon:
         return actv, actv_grad
 
     def reload_xy(self, xy_distribution=False):
+        logger.warning("reload_xy() is deprecated and will be removed in a future version. "
+                       "Use DYNAMIC_MONITOR instead.")
        self.xy_distribution = xy_distribution
 
         for handle in self.handles['xy']:
@@ -547,21 +523,23 @@ class TrainerMon:
         for _, fwd_context in self.module_fwd_hook_context_by_module.items():
             if len(fwd_context.actv) == 0:
                 continue
-            self.summary_writer.write_metrics(self.ops, fwd_context.actv, step,
+            self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, MonitorConst.ACTV)
             fwd_context.actv.clear()
         if self.grad_context.actv:
-            self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step,
+            self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD)
 
     def write_param_tb(self, opt_context):
         if not self.param_distribution:
             return
-        self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step,
+        self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, MonitorConst.PARAM)
 
     def write_mv_tb(self, opt_context):
         if not self.mv_distribution:
             return
-        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric,
-
+        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric,
+                                          opt_context.step, MonitorConst.EXP_AVG)
+        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric,
+                                          opt_context.step, MonitorConst.EXP_AVG_SQ)
 
     def write_grad_tb(self, step):
         if not self.wg_distribution:
@@ -573,7 +551,7 @@ class TrainerMon:
         self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
         self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
 
-    def hook_optimizer(self, optimizer
+    def hook_optimizer(self, optimizer):
         # in DDP by default use params_have_main_grad
         def optimizer_pre_step_hook(optimizer, args, kwargs):
             context = self.optimizer_context[optimizer]
@@ -592,15 +570,13 @@ class TrainerMon:
             # skip generate metrics
             if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0:
                 return
-            if
+            if MonitorConst.DEEPSPEED_ZERO_OPT_FILTER in self.optimizer_class:  # use deepspeed with zero1/2/3
                 if not self.name2indices:
-                    self.name2indices = self.
-
-                mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name,
-                                                                      self.name2indices)
+                    self.name2indices = self.optimizer_mon.get_param_index(self.param2name, self.name2index, optimizer)
+                mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name, self.name2indices)
                 self.param2name = mv_result.grad
             else:
-                mv_result = self.
+                mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name)
             context.param_exp_avg = mv_result.exp_avg
             context.param_exp_avg_sq = mv_result.exp_avg_sq
             context.param_adam_update = mv_result.update
@@ -641,19 +617,13 @@ class TrainerMon:
                 optimizer_pre_step_hook(optimizer, args, kwargs)
                 out = func(*args, **kwargs)
                 return out
-
             return wrapper
 
         if self.optimizer_hooked:
             return
 
-
-
-            self.handles['optimizer'] = []
-        else:
-            if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list):
-                step_pre_hook = register_optimizer_step_pre_hook(optimizer_pre_step_hook)
-                self.handles['optimizer'] = [step_pre_hook]
+        optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+
         self.optimizer_hooked = True
         return
 
@@ -677,11 +647,12 @@ class TrainerMon:
                 logger.error(f"get config.json wrong because {e}, not updated, please check!!!")
                 return
 
-            if config.get("
+            if config.get("dynamic_on", False):
                 try:
                     validate_config(config)
                     self.config = config
                     self.set_config()
+                    self.start_step = context.step  # dynamic start/stop ignores the original start_step and always starts from the next step
                     logger.warning(f"config is updated at step{context.step - 1}, "
                                    f"will start new hook at step{context.step}.")
                 except Exception as e:
@@ -729,6 +700,9 @@ class TrainerMon:
             if self.anomaly_data_factory:
                 self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
             self.summary_writer.clear_anomalies()
+
+            if self.format == MonitorConst.TENSORBOARD:
+                chmod_tensorboard_dir(self.tensorboard_dir)
             self.call_id = 0
             self.param_name_call_id.clear()
 
@@ -745,11 +719,49 @@ class TrainerMon:
                 return out
             return wrapper
 
-
-
-
-
-
+        optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+        self.origin_step_func = optimizer.__class__.step
+        return
+
+    def hook_modules(self):
+        if self.module_rank_list and (self.rank not in self.module_rank_list):
+            return
+
+        targets = self.config['targets']
+        module_in_all_stage = [key for key in targets.keys() if MonitorConst.NAME_SEP not in key]
+        for key in module_in_all_stage:
+            struct = targets.pop(key)
+            targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})
+
+        hooked_count = 0
+        for vpp_stage, model_chunk in enumerate(self.model):
+            vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
+            targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
+                'targets'].keys()
+            hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
+
+        logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
+
+        @recursion_depth_decorator('msprobe.pytorch.monitor.clone_if_tensor')
+        def clone_if_tensor(args):
+            if isinstance(args, tuple):
+                return tuple([clone_if_tensor(arg) for arg in args])
+            elif isinstance(args, torch.Tensor) and not is_float8_tensor(args):
+                return args.clone()
+            else:
+                return args
+
+        @torch.no_grad
+        def wrap_hook_setup(setup):
+            def wrapped_setup(*args, **kwargs):
+                args = setup(*args, **kwargs)
+                args = clone_if_tensor(args)
+                return args
+
+            return wrapped_setup
+
+        BackwardHook.setup_input_hook = wrap_hook_setup(BackwardHook.setup_input_hook)
+        BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
         return
 
     def _remove_all_hooks(self, optimizer):
@@ -764,17 +776,28 @@ class TrainerMon:
             bwd_context.reset()
         self.grad_context.reset()  # both weight gradients and activation gradients are kept here
 
-
-
-
-
+        if self.origin_start_grad_sync:  # megatron
+            try:
+                from megatron.core.distributed.param_and_grad_buffer import Bucket
+                Bucket.start_grad_sync = self.origin_start_grad_sync
+                logger.info("remove Bucket start_grad_sync")
+            except ImportError:
+                pass
+            try:
+                from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
+                _ParamAndGradBucketGroup.start_grad_sync = self.origin_start_grad_sync
+                logger.info("remove _ParamAndGradBucketGroup start_grad_sync")
+            except ImportError:
+                pass
+        else:  # not megatron
+            for handle in self.handles['wgrads']:
+                handle.remove()
+            self.handles['wgrads'].clear()
+            self.weight_hooked = False
 
-        if
+        if self.optimizer_hooked:
             optimizer.__class__.step = self.origin_step_func
-
-            for handle in self.handles['optimizer']:
-                handle.remove()
-            self.handles['optimizer'].clear()
+
         for _, context in self.optimizer_context.items():
             context.reset()
         self.optimizer_hooked = False
@@ -782,6 +805,7 @@ class TrainerMon:
             for handle in self.handles['cc']:
                 handle.remove()
             self.handles['cc'].clear()
+            api_register.restore_api()
             for _, context in self.cc_context.items():
                 context.reset()
 
@@ -800,17 +824,17 @@ class TrainerMon:
 
     def _remove_all_hooks_final(self, optimizer):
         if self.dynamic_enable:
-            # 结束后自动重置
+            # automatically reset dynamic_on to False when finished; the user must set it to True again to re-enable
             try:
                 config = load_json(self.config_file_path)
-                config['
+                config['dynamic_on'] = False
                 save_json(self.config_file_path, config, indent=2)
                 config_timestamp = os.path.getmtime(self.config_file_path)
                 self.config_timestamp = config_timestamp
                 logger.info(
-                    "Finish monitor, set config'
+                    "Finish monitor, set config'dynamic_on=False, will restart by set it to True and update config")
             except Exception as e:
-                logger.warning(f"Finish monitor, set config'
+                logger.warning(f"Finish monitor, set config'dynamic_on=False fail because {e}, please check!!!")
         logger.info("Finish monitor")
         self._remove_all_hooks(optimizer)
 
@@ -871,7 +895,7 @@ class TrainerMon:
 
     def _register_param_name(self):
         for vpp_stage, model_chunk in enumerate(self.model):
-            prefix = f'{vpp_stage}{MonitorConst.
+            prefix = f'{vpp_stage}{MonitorConst.NAME_SEP}'
             self._register_chunk(model_chunk, prefix)
 
     def _is_target_module(self, module_name, targets, vpp_stage):
@@ -900,35 +924,37 @@ class TrainerMon:
             context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
             if not context.struct:
                 context.struct = {
-
-
+                    Const.INPUT: get_param_struct(module_input),
+                    Const.OUTPUT: get_param_struct(module_output)
                 }
             if self.print_struct:
                 self.module_struct[context.module_name].update(context.struct)
                 return
             if not context.format_by_arg:
-                context.set_format_by_arg(
-                context.set_format_by_arg(
+                context.set_format_by_arg(Const.INPUT, self.config['targets'])
+                context.set_format_by_arg(Const.OUTPUT, self.config['targets'])
            if not context.format_by_arg:
                 return
             if not context.verified:
-                context.focused_in_col = validate_config_spec(context.format_by_arg[
+                context.focused_in_col = validate_config_spec(context.format_by_arg[Const.INPUT],
                                                               module_input, context.module_name,
-
+                                                              Const.INPUT)
-                context.focused_out_col = validate_config_spec(context.format_by_arg[
+                context.focused_out_col = validate_config_spec(context.format_by_arg[Const.OUTPUT],
                                                                module_output, context.module_name,
-
+                                                               Const.OUTPUT)
                 context.verified = True
             # expect output be tensor type
             tbtag_tensor_map = {}
             cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
             tbtag_tensor_map.update(
-                self.build_tbtag_tensor_map(
-
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, cared_input))
             cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
             tbtag_tensor_map.update(
-                self.build_tbtag_tensor_map(
-
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, cared_output))
 
             get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
             context.micro_step += 1
@@ -940,35 +966,37 @@ class TrainerMon:
             context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
             if not context.struct:
                 context.struct = {
-                    MonitorConst.
-                    MonitorConst.
+                    MonitorConst.INPUT_GRAD: get_param_struct(input_grad),
+                    MonitorConst.OUTPUT_GRAD: get_param_struct(output_grad)
                 }
             if self.print_struct:
                 self.module_struct[context.module_name].update(context.struct)
                 return
            if not context.format_by_arg:
-                context.set_format_by_arg(MonitorConst.
-                context.set_format_by_arg(MonitorConst.
+                context.set_format_by_arg(MonitorConst.INPUT_GRAD, self.config['targets'])
+                context.set_format_by_arg(MonitorConst.OUTPUT_GRAD, self.config['targets'])
             if not context.format_by_arg:
                 return
             if not context.verified:
-                context.focused_in_col = validate_config_spec(
-
-
-                context.focused_out_col = validate_config_spec(
-
-
+                context.focused_in_col = validate_config_spec(
+                    context.format_by_arg[MonitorConst.INPUT_GRAD],
+                    input_grad, context.module_name, MonitorConst.INPUT_GRAD)
+                context.focused_out_col = validate_config_spec(
+                    context.format_by_arg[MonitorConst.OUTPUT_GRAD],
+                    output_grad, context.module_name, MonitorConst.OUTPUT_GRAD)
                 context.verified = True
 
             tbtag_tensor_map = {}
             cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
             tbtag_tensor_map.update(
-                self.build_tbtag_tensor_map(
-
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, cared_input_grad))
             cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
             tbtag_tensor_map.update(
-                self.build_tbtag_tensor_map(
-
+                self.build_tbtag_tensor_map(
+                    f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}',
+                    MonitorConst.ACTV, cared_output_grad))
 
             if context.micro_step == 0 and context.actvgrad:
                 logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
@@ -1006,7 +1034,10 @@ class TrainerMon:
         def patch_sync(sync_grad_func):
             def wrapper(bucket):
                 grad_dict = {}
-
+                # Megatron between core_r0.6.0 and core_r0.8.0, this bucket is Bucket.
+                # When megatron is core_r0.9.0, this bucket is _ParamAndGradBucketGroup.
+                # In megatron version core_r0.9.0, func start_grad_sync from Bucket moved to _ParamAndGradBucketGroup.
+                bucket_params_id_list = [id(params) for params in bucket.params]
                 for param, name in self.param2name.items():
                     if id(param) not in bucket_params_id_list:
                         continue
@@ -1025,18 +1056,28 @@ class TrainerMon:
 
             return wrapper
 
+        if not self.wg_distribution:
+            return
+
         try:
             from megatron.core.distributed.param_and_grad_buffer import Bucket
+            self.origin_start_grad_sync = Bucket.start_grad_sync
+            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)
             self.enable_megatron = True
+            logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0")
        except ImportError:
             self.enable_megatron = False
 
-
-
+        try:
+            from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
+            self.origin_start_grad_sync = _ParamAndGradBucketGroup.start_grad_sync
+            _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync)
+            self.enable_megatron = True
+            logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
+        except ImportError:
+            self.enable_megatron = False | self.enable_megatron
 
-        if self.enable_megatron:
-            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)  # differ in different megatron version
-        else:
+        if not self.enable_megatron:
             self._hook_weights()
 
     def _hook_weights(self):
@@ -1049,10 +1090,14 @@ class TrainerMon:
             if param.micro_step == self.micro_batch_number:
                 param.micro_step = 0
                 if self.params_have_main_grad:
-
+                    grad = param.main_grad
                 else:
-
+                    grad = param.grad
+                if is_float8_tensor(grad):
+                    grad = grad.float()
+                context_dict[key] = grad.clone()
 
+        logger.info("hooking weights.")
         for param, name in self.param2name.items():
             key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
             setattr(param, 'micro_step', 0)
```