mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in the supported public registries. It is provided for informational purposes only and reflects the changes between those published versions.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
- msprobe/README.md +25 -20
- msprobe/core/common/const.py +110 -66
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/utils.py +30 -34
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -2
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +20 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_processor/base.py +2 -2
- msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
- msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
- msprobe/core/data_dump/json_writer.py +38 -35
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +2 -1
- msprobe/docs/02.config_introduction.md +17 -15
- msprobe/docs/05.data_dump_PyTorch.md +70 -2
- msprobe/docs/06.data_dump_MindSpore.md +33 -12
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +124 -62
- msprobe/docs/21.visualization_PyTorch.md +32 -13
- msprobe/docs/22.visualization_MindSpore.md +32 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +31 -19
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +6 -4
- msprobe/mindspore/debugger/precision_debugger.py +22 -10
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +14 -9
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +354 -302
- msprobe/mindspore/monitor/utils.py +46 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +23 -17
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/common/utils.py +29 -7
- msprobe/pytorch/debugger/precision_debugger.py +10 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
- msprobe/pytorch/monitor/csv2tb.py +8 -2
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +131 -105
- msprobe/pytorch/monitor/module_metric.py +3 -0
- msprobe/pytorch/monitor/optimizer_collect.py +55 -4
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +68 -1
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +11 -7
- msprobe/pytorch/service.py +11 -8
- msprobe/visualization/builder/graph_builder.py +44 -5
- msprobe/visualization/builder/msprobe_adapter.py +0 -1
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +8 -1
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +1 -1
- msprobe/visualization/utils.py +2 -33
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/parse.py +0 -19
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/mindspore/monitor/module_hook.py

@@ -21,16 +21,15 @@ from datetime import datetime

 import pytz
 import mindspore as ms
-
-from mindspore import Tensor, ops, mint
+from mindspore import Tensor, mint
 from mindspore import nn, _no_grad
 from mindspore.communication import get_rank

 from msprobe.core.common.log import logger
 from msprobe.core.common.const import MonitorConst
-from msprobe.core.common.file_utils import load_json
+from msprobe.core.common.file_utils import load_json, save_json
 from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \
-    is_skip_step, get_metrics, get_single_metrics
+    is_skip_step, get_metrics, get_single_metrics, get_target_output_dir
 from msprobe.mindspore.monitor.module_spec_verifier import validate_config_spec
 from msprobe.mindspore.monitor.anomaly_detect import AnomalyScanner, AnomalyDataFactory, \
     CSVWriterWithAD, BaseWriterWithAD, WriterInput
@@ -108,6 +107,10 @@ class ModuleHookContext:
         elif key_name in ['input', 'input_grad']:
             self.ignore_in = True

+    def reset(self):
+        self.actv.clear()
+        self.actvgrad.clear()
+

 start_step = 0

@@ -116,7 +119,6 @@ start_step = 0
 class OptimizerContext:
     def __init__(self) -> None:
         self.step = start_step
-        self.param_effective_rank = defaultdict(float)
         self.param_mg_direction = defaultdict(float)
         self.param_adam_update = defaultdict()
         self.param_adam_ratio = defaultdict()
@@ -131,6 +133,7 @@ class OptimizerContext:
     def reset(self) -> None:
         self.param_mg_direction.clear()
         self.param_adam_update.clear()
+        self.param_adam_ratio.clear()
         self.param_weight_grad.clear()
         self.param_exp_avg.clear()
         self.exp_avg_metric.clear()
@@ -179,50 +182,100 @@ class CommunicationContext:

 class TrainerMon:
     def __init__(self, config_file_path, process_group=None, params_have_main_grad=True) -> None:
+        # TYPE1: variables initialized only here; they are not reset when the config changes mid-training
+        self.config_file_path = config_file_path
+        self.process_group = process_group
+        self.params_have_main_grad = params_have_main_grad
+        self.config_timestamp = 0  # the timestamp is checked later; the first run can be enabled via the dynamic_on switch without touching the config file's mtime
+        self.config = load_json(config_file_path)
+        validate_config(self.config)
+
+        local_tz = pytz.timezone("Asia/Shanghai")  # adjust to the target timezone as needed
+        cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
+        self.unique_id = str(uuid.uuid4())[:8]
+        self.output_base_dir = get_output_base_dir()
+        time_tags = self.config.get("append_output", [])
+        try:
+            self.rank = get_rank()
+            if time_tags:
+                output_append_dirs = get_target_output_dir(self.output_base_dir, time_tags[0], time_tags[1])
+                if str(self.rank) in output_append_dirs:
+                    self.tensorboard_dir = output_append_dirs[str(self.rank)]
+                    logger.info(f"Append rank({self.rank}) result to {self.tensorboard_dir}")
+            else:
+                self.tensorboard_dir = os.path.join(self.output_base_dir,
+                                                    f"{cur_time}-rank{self.rank}-{self.unique_id}")
+        except Exception as e:
+            self.rank = 0
+            self.tensorboard_dir = os.path.join(self.output_base_dir, f"{cur_time}-rank{self.rank}-{self.unique_id}")
+
+        self.pp_stage = 0
+        self.group_mates = [0]
+
+        # TYPE2: variables assigned only from the set_monitor() entry point
+        self.model = None
+        self.vpp = False
+        self.dp_group = None
+        self.tp_group = None
+        self.micro_batch_number = 1
+
+        # TYPE3: variables reset whenever the config is updated mid-training or the monitoring state changes
         self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
         self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
         self.optimizer_context = defaultdict(OptimizerContext)
         self.cc_context = defaultdict(CommunicationContext)
         self.grad_context = GradContext()
-        self.params_have_main_grad = params_have_main_grad
         self.handles = defaultdict(list)
-        self.
-
+        self.param2name = defaultdict(str)
+        self.name2index = defaultdict()
+        self.name2indices = defaultdict()
+        self.name2param = {}
+        self.duplicate_param = {}
+        self.name2tag = {}
+        self.param_name_call_id = {}
+        self.call_id = 0
+        self.module_struct = defaultdict(dict)
+        self.grad_accs = []
+        self.weight_hooked = False
+        self.optimizer_hooked = False
+        self.param_registered = False
+        self.struct_printed = False
+
+        # distinguish dynamic from static monitoring
+        self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
+        if self.dynamic_enable:
+            logger.warning(f"DYNAMIC_MONITOR is set, "
+                           f"please make sure you have 'dynamic_on' and 'collect_times' in {self.config_file_path}")
+            self.monitoring = False
+        else:
+            self.set_config()
+            # in static mode with collect_times > 0, self.monitoring can already be True at step 0; dynamic mode starts at the next step by default
+            if self.collect_times > 0:
+                self.monitoring = True

+    def set_config(self):
         self.start_step = self.config.get("start_step", 0)
         self.collect_times = self.config.get("collect_times", 100000000)  # large default so collection never stops
         self.step_interval = self.config.get("step_interval", 1)
-        self.has_collect_times = 0
-
-        # monitor target in module, such as layer, weight, grad
+        self.has_collect_times = 0  # reset the collection counter
+        self.print_struct = self.config.get("print_struct", False)
         self.targets = self.config.get("targets", None)
         self.is_select = self.config.get("is_select", False)
         self.module_rank_list = self.config.get("module_ranks", [])
-        # only csv supported in mindspore
-        self.format = self.config.get('format', MonitorConst.CSV)
+        self.format = self.config.get('format', MonitorConst.CSV)  # only csv supported in mindspore
         self.eps = self.config.get('eps', 1e-8)
-        # monitor mean/max/norm/min/nan...
-        self.ops = self.config.get('ops', [])
+        self.ops = self.config.get('ops', [])  # monitor mean/max/norm/min/nan...
         self.ndigits = self.config.get('ndigits', 6)
         self.all_xy = self.config.get('all_xy', False)
-        # module input/output input_grad/output_grad
         self.xy_distribution = self.config.get('xy_distribution', False)
-        # activation forward
         self.forward_only = self.config.get('forward_only', False)
-        # activation backward
         self.backward_only = self.config.get('backward_only', False)
-        #
-        self.ur_distribution = self.config.get('ur_distribution', False)
-        # m/v of adam
-        self.mv_distribution = self.config.get("mv_distribution", False)
-        # weight grad
+        self.ur_distribution = self.config.get('ur_distribution', False)  # vector and ratio vector of adam
+        self.mv_distribution = self.config.get("mv_distribution", False)  # m/v of adam
         self.wg_distribution = self.config.get("wg_distribution", False)
-        # optimizer param
         self.param_distribution = self.config.get("param_distribution", False)
-        # main grad direction
-        self.mg_direction = self.config.get('mg_direction', False)
-        # communication ops
-        self.cc_distribution = self.config.get("cc_distribution", {})
+        self.mg_direction = self.config.get('mg_direction', False)  # main grad direction
+        self.cc_distribution = self.config.get("cc_distribution", {})  # communication ops
         if not self.cc_distribution.get('enable', False):
             self.cc_log_only = False
         else:
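For orientation, the rework above makes `TrainerMon.__init__` read a JSON config once and re-read it through `set_config()`. Below is a minimal sketch of such a config, using only key names that appear in the `self.config.get(...)` calls in the diff; the module name, operator list, and file path are illustrative placeholders, not values taken from the package.

```python
# Hypothetical monitor config for TrainerMon; key names come from set_config(),
# the values are illustrative only.
from msprobe.core.common.file_utils import save_json  # imported in the diff above

monitor_config = {
    "start_step": 0,                 # first step to collect
    "collect_times": 10,             # stop after this many collections
    "step_interval": 1,              # collect every step
    "targets": {"norm": {}},         # placeholder module name; the spec format is checked by validate_config_spec
    "ops": ["min", "max", "norm"],   # statistics to record (the diff mentions mean/max/norm/min/nan...)
    "format": "csv",                 # only csv is supported in MindSpore
    "xy_distribution": True,         # monitor module input/output and their grads
    "wg_distribution": True,         # monitor weight gradients
    "dynamic_on": False,             # flipped to True to (re)start monitoring when DYNAMIC_MONITOR is set
}

# save_json(path, data, indent=2) matches the call shape used in the diff.
save_json("./monitor_config.json", monitor_config, indent=2)
```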
@@ -230,140 +283,173 @@ class TrainerMon:
             self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
             self.cc_logged_stack = defaultdict(set)
             self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
-            self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
-            api_register.redirect_api()
         self.common_info()

+        # initialize the AnomalyData factory
         alert_setting = self.config.get('alert', {"rules": []})
         self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])
-
-        local_tz = pytz.timezone("Asia/Shanghai")  # adjust to the target timezone as needed
-
-        cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
-        unique_id = str(uuid.uuid4())[:8]
-        output_base_dir = get_output_base_dir()
-
-        time_tags = self.config.get("append_output", [])
-        if time_tags:
-            output_append_dirs = get_target_output_dir(output_base_dir, time_tags[0], time_tags[1])
-        try:
-            rank = get_rank()
-        except Exception as e:
-            rank = 0
-            tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}")
-            logger.error(f"Failed to get rank, setting tensorboard_dir to {tensorboard_dir}")
-            pp_stage = 0
-            group_mates = [0]
-        else:
-            if time_tags and str(rank) in output_append_dirs:
-                tensorboard_dir = output_append_dirs[str(rank)]
-                logger.info(f"Append rank({rank}) result to {tensorboard_dir}")
-            else:
-                tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}")
-            pp_stage = 0
-            group_mates = [0]
-
-        self.rank = rank
-
-        # initialize the AnomalyData factory
         self.anomaly_data_factory = None
         if alert_setting.get('dump', False):
-            self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates)
+            self.anomaly_data_factory = AnomalyDataFactory(self.rank, self.pp_stage, self.group_mates)

+        # initialize the writer and create the output directory
         if self.format not in FORMAT_MAPPING:
             logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
             self.format = MonitorConst.CSV
         writer = FORMAT_MAPPING[self.format]
         self.step_count_per_record = self.config.get('step_count_per_record', 1)
-
         self.summary_writer = writer(
             WriterInput(
-                tensorboard_dir,
+                self.tensorboard_dir,
                 self.alert_rules,
-                unique_id,
+                self.unique_id,
                 self.anomaly_data_factory,
                 self.ndigits,
                 self.step_count_per_record
             )
         )

-
-
-
-        self.
-
-        self.
-
-        self.
-
-        self.
-
-        self.
-
-        self.
-
-        self.param_name_call_id = {}
-        self.duplicate_param = {}
-        self.name2tag = {}
-        self.call_id = 0
-        self.grad_accs = []
-        self.handles = defaultdict(list)
-
-        self.print_struct = self.config.get("print_struct", False)
-        self.struct_printed = False
-        self.module_struct = defaultdict(dict)
+    def common_info(self):
+        if not self.xy_distribution:
+            logger.info("> module input/output input_grad/output_grad is not monitored. ")
+        if self.forward_only:
+            logger.info("> only module forward is monitored. ")
+        if not self.ur_distribution:
+            logger.info("> update vector and ratio vector of adam is not monitored. ")
+        if not self.mv_distribution:
+            logger.info("> momentum and variance of adam is not monitored. ")
+        if not self.wg_distribution:
+            logger.info("> weight grad of specified module is not monitored. ")
+        if not self.mg_direction:
+            logger.info('> grad and momentum direction will not be compared.')
+        if not self.cc_distribution.get('enable', False):
+            logger.info("> cc operator is not monitored.")

-    # Start
     def set_monitor(
             self,
             model,
+            optimizer,
             grad_acc_steps=1,
-            optimizer=None,
             tp_group=None,
             dp_group=None,
-            start_iteration=0):
+            start_iteration=0
+    ):
         global start_step
         start_step = start_iteration
-        logger.info(f'grad acc steps {grad_acc_steps}')
-        self.hook_optimizer(optimizer)
         self.micro_batch_number = grad_acc_steps
         self.dp_group = dp_group
         self.tp_group = tp_group
+        self.hook_step_final(optimizer)
+        if not isinstance(model, list):
+            model = [model]
+        self.model = model
+        if len(model) > 1:
+            self.vpp = True
+            logger.info('vpp enabled')
+        if not self.dynamic_enable:
+            self.register_hooks(optimizer)
+
+    def hook_step_final(self, optimizer):
+        def step_final_hook(optimizer, *args, **kwargs):
+            context = self.optimizer_context[optimizer]
+            # static mode can already save at step 0; dynamic mode cannot, since by design it starts at the step after a reset, so self.monitoring is still False at step 0
+            if self.monitoring:
+                module_rank_valid = self.is_target_rank()
+                step_condition = (context.step >= self.start_step and (
+                        context.step - self.start_step) % self.step_interval == 0)
+                if module_rank_valid and step_condition:
+                    self.has_collect_times += 1
+                    self.write_xy_tb(context.step)
+                    self.write_grad_tb(context.step)
+                    self.write_mv_tb(context)
+                    self.write_param_tb(context)
+
+                    if context.metric_dict:
+                        self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
+                    context.metric_dict.clear()
+
+                    self.summary_writer.clear_anomalies()
+                    self.call_id = 0
+                    self.param_name_call_id.clear()
+
+                    if self.has_collect_times >= self.collect_times:
+                        self._remove_all_hooks_final(optimizer)

-
-
+            context.step += 1
+            self.dynamic_monitor(optimizer)

-
-
-
-        def
-
-
+        optimizer.register_forward_hook(step_final_hook)
+        return
+
+    def dynamic_monitor(self, optimizer):
+        """
+        If dynamic monitor enabled and config.json updated,
+        remove hooks and register new hooks according to new configuration.
+        """
+        context = self.optimizer_context[optimizer]
+        if not self.dynamic_enable:
             return
+        try:
+            # if the file timestamp has not changed, skip re-reading to save time
+            config_timestamp = os.path.getmtime(self.config_file_path)
+            if config_timestamp == self.config_timestamp:
+                return
+            # record the config file's latest modification timestamp
+            self.config_timestamp = config_timestamp
+            config = load_json(self.config_file_path)
+        except Exception as e:
+            logger.error(f"get config.json wrong because {e}, not updated, please check!!!")
+            return
+
+        if config.get("dynamic_on", False):
+            try:
+                validate_config(config)
+                self.config = config
+                self.set_config()
+                self.start_step = context.step  # dynamic start/stop ignores the original start_step and always begins at the next step
+                logger.warning(f"config is updated at step{context.step - 1}, "
+                               f"will start new hook at step{context.step}.")
+            except Exception as e:
+                logger.error(f"set config wrong because {e}, not updated, please check!!!")
+                return
+
+            self._remove_all_hooks()
+            self.register_hooks(optimizer)
+
+    def register_hooks(self, optimizer):
+        self._register_param_name()
+        self.hook_modules()
+        self.hook_optimizer(optimizer)
+        self._patch_grad_sync()
+        if self.cc_distribution.get('enable', False):
+            self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
+            api_register.redirect_api()
+        self.monitoring = True

+    def hook_modules(self):
         if not self.is_target_rank():
             return
+        module_in_all_stage = [key for key in self.targets.keys() if MonitorConst.NAME_SEP not in key]

-
-
-
-
-        for param in optimizer.get_parameters():
-            if MonitorConst.EXP_AVG_SQ in param.name:
-                v_list.append(param)
-            elif MonitorConst.EXP_AVG in param.name:
-                m_list.append(param)
-            else:
-                param_list.append(param)
-                grad_names.append(param.name)
+        for key in module_in_all_stage:
+            struct = self.targets.pop(key)
+            self.targets.update(
+                {f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})

-
-
-
-
+        hooked_count = 0
+        for vpp_stage, model_chunk in enumerate(self.model):
+            if not isinstance(model_chunk, nn.Cell):
+                logger.info("Target Model is not Cell")
+                continue
+            vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
+            targets = [x for x, _ in model_chunk.cells_and_names()] if self.print_struct else self.targets.keys()
+            hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
+        logger.info(f"> {hooked_count} modules are monitored.")
+
+    def hook_optimizer(self, optimizer):
         def optimizer_pre_hook_function(opt, grad_names, gradients):
             context = self.optimizer_context[opt]
-            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
+            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
                             self.collect_times):
                 return
             gradient_list = gradients[0] if isinstance(gradients, tuple) else gradients
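The hunk above changes the public entry point: `optimizer` is now a required positional argument of `set_monitor`, and all metric writing is driven by forward hooks registered on the optimizer instead of a manual `hook_optimizer` call. A minimal wiring sketch follows; `TrainerMon` and `set_monitor()` match the signatures shown in the diff, while the network, optimizer, and config path are illustrative stand-ins.

```python
# Hypothetical usage of the reworked MindSpore monitor entry point (sketch only).
from mindspore import nn
from msprobe.mindspore.monitor.module_hook import TrainerMon

net = nn.Dense(16, 4)                                          # stand-in for the trained model
optimizer = nn.Adam(net.trainable_params(), learning_rate=1e-3)

monitor = TrainerMon("./monitor_config.json", params_have_main_grad=False)
# optimizer is now a required positional argument; metrics are flushed from its step-final hook.
monitor.set_monitor(net, optimizer, grad_acc_steps=1, start_iteration=0)
# training then proceeds as usual; collection stops once collect_times is exhausted.
```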
@@ -402,46 +488,64 @@ class TrainerMon:
             context.metric_dict = metric_dict
             return

-        def optimizer_post_hook_function(opt, args, gradients, outputs):
-            context = self.optimizer_context[opt]
-            step_skip = is_skip_step(context.step, self.start_step, self.step_interval, \
-                                     self.has_collect_times, self.collect_times)
-            if step_skip:
-                context.step += 1
-                return
-            self.write_xy_tb(context.step)
-            self.write_grad_tb(context.step)
-            self.write_mv_tb(context)
-            self.write_param_tb(context)
-
-            if context.metric_dict:
-                self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
-            context.metric_dict.clear()
-            self.has_collect_times += 1
-            context.step += 1
-            if self.anomaly_data_factory:
-                self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
-            self.summary_writer.clear_anomalies()
-            self.call_id = 0
-            self.param_name_call_id.clear()
-            return
-
         def optimizer_pre_hook_wrapper(func, grad_names):
             def wrapper(opt, gradients):
                 return func(opt, grad_names, gradients)
             return wrapper

-
-
-                return func(opt, args, gradients, outputs)
-            return wrapper
+        if self.optimizer_hooked or not self.is_target_rank():
+            return

-
-
+        m_list = []
+        v_list = []
+        param_list = []
+        grad_names = []
+        for param in optimizer.get_parameters():
+            if MonitorConst.EXP_AVG_SQ in param.name:
+                v_list.append(param)
+            elif MonitorConst.EXP_AVG in param.name:
+                m_list.append(param)
+            elif param.name in ['global_step', 'learning_rate']:
+                pass
+            else:
+                param_list.append(param)
+                grad_names.append(param.name)

+        handle = optimizer.register_forward_pre_hook(
+            optimizer_pre_hook_wrapper(optimizer_pre_hook_function, grad_names))
+        self.handles['optimizer'].append(handle)
         self.optimizer_hooked = True
         return

+    def generate_wgrad_metrics(self):
+        if not self.wg_distribution:
+            return {}, {}
+
+        if self.weight_hooked:
+            try:
+                get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
+            except Exception as e:
+                logger.warning(f"An error occurred while generating wgrad pre metrics")
+                return {}, {}
+
+        grad_dict = {}
+        for param, name in self.param2name.items():
+            if self.duplicate_param.get(name, False):
+                continue
+            grad = param.main_grad if self.params_have_main_grad else param.grad
+            if grad is None:
+                logger.warning(f"grad is None: {name}, maybe something wrong happened.")
+                continue
+            tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+            self._register_param_call_id("hook_optimizer", tag)
+            grad_dict[tag] = grad
+        try:
+            get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
+        except Exception as e:
+            logger.warning(f"An error occurred while generating wgrad post metrics")
+            return {}, {}
+        return self.grad_context.post, self.grad_context.pre
+
     def write_xy_tb(self, step):
         if not self.xy_distribution:
             return
@@ -468,121 +572,27 @@ class TrainerMon:
         if not self.wg_distribution:
             return

-
-            self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced')
-        else:
-            self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
+        self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
         self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')

-    def common_info(self):
-        if not self.xy_distribution:
-            logger.info("> module input/output input_grad/output_grad is not monitored. ")
-        if self.forward_only:
-            logger.info("> only module forward is monitored. ")
-        if not self.ur_distribution:
-            logger.info("> update vector and ratio vector of adam is not monitored. ")
-        if not self.mv_distribution:
-            logger.info("> momentum and variance of adam is not monitored. ")
-        if not self.wg_distribution:
-            logger.info("> weight grad of specified module is not monitored. ")
-        if not self.mg_direction:
-            logger.info('> grad and momentum direction will not be compared.')
-        if not self.cc_distribution.get('enable', False):
-            logger.info("> cc operator is not monitored.")
-
     def is_target_rank(self):
-
-        if self.module_rank_list and (rank_id not in self.module_rank_list):
+        if self.module_rank_list and (self.rank not in self.module_rank_list):
             return False
         return True

-    def hook_modules(self, model, grad_acc_steps):
-        if not self.is_target_rank():
-            return
-        if not isinstance(model, list):
-            model = [model]
-        self.model = model  # list
-        self._register_param_name(model)
-        self.micro_batch_number = grad_acc_steps
-        module_in_all_stage = [key for key in self.targets.keys() if MonitorConst.NAME_SEP not in key]
-
-        for key in module_in_all_stage:
-            struct = self.targets.pop(key)
-            self.targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(model))})
-
-        hooked_count = 0
-        for vpp_stage, model_chunk in enumerate(model):
-            if not isinstance(model_chunk, nn.Cell):
-                logger.info("Target Model is not Cell")
-                continue
-            vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
-            targets = [x for x, _ in model_chunk.cells_and_names()] if self.print_struct else self.targets.keys()
-            hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
-        logger.info(f"> {hooked_count} modules are monitored.")
-
     def build_tbtag_tensor_map(self, module_name, tag, tensor):
-        rank_id = str(get_rank())
         metrics = {}
-        key = get_summary_writer_tag_name(module_name, tag, rank_id)
+        key = get_summary_writer_tag_name(module_name, tag, str(self.rank))
         if isinstance(tensor, Tensor):
             self._register_param_call_id("_hook_module", key)
             metrics[key] = tensor
         return metrics

-    def generate_wgrad_metrics(self):
-        if not self.wg_distribution:
-            return {}, {}
-
-        if self.weight_hooked:
-            try:
-                get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
-            except Exception as e:
-                logger.warning(f"An error occurred while generating wgrad pre metrics")
-                return {}, {}
-
-        grad_dict = {}
-        for param, name in self.param2name.items():
-            if self.duplicate_param.get(name, False):
-                continue
-            grad = param.main_grad if self.params_have_main_grad else param.grad
-            if grad is None:
-                logger.warning(f"grad is None: {name}, maybe something wrong happened.")
-                continue
-            tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
-            self._register_param_call_id("hook_optimizer", tag)
-            grad_dict[tag] = grad
-        try:
-            get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
-        except Exception as e:
-            logger.warning(f"An error occurred while generating wgrad post metrics")
-            return {}, {}
-        return self.grad_context.post, self.grad_context.pre
-
-    def _register_param_name(self, model):
-        if self.param_registered:
-            return
-
-        if len(model) > 1:
-            self.vpp = True
-            logger.info('vpp enabled')
-
-        for vpp_stage, model_chunk in enumerate(model):
+    def _register_param_name(self):
+        for vpp_stage, model_chunk in enumerate(self.model):
             prefix = f'{vpp_stage}{MonitorConst.NAME_SEP}'
             self._register_chunk(model_chunk, prefix)

-        self.param_registered = True
-
-    def _is_target_param(self, param_name, param, prefix):
-        if not self.targets:
-            return True
-        squash_name = prefix + squash_param_name(param_name)
-        name = prefix + param_name
-        for target in self.targets.keys():
-            if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
-                setattr(param, "zero_out_wgrad", True)
-                return True
-        return False
-
     def _register_chunk(self, model_chunk, prefix):
         index = 0
         for param in model_chunk.get_parameters():
@@ -607,17 +617,6 @@ class TrainerMon:
             }
             index += 1

-    def _is_target_module(self, module_name, targets, vpp_stage):
-        if self.all_xy or self.print_struct:
-            return vpp_stage + squash_param_name(module_name)
-        for pattern in [
-            vpp_stage + squash_param_name(module_name),
-            vpp_stage + module_name,
-        ]:
-            if pattern in targets:
-                return pattern
-        return ""
-
     def _hook_module(self, target_names, module, vpp_stage=''):
         if not isinstance(module, nn.Cell):
             # nothing to hook
@@ -637,7 +636,7 @@ class TrainerMon:
                 return
             if not module.training:
                 return
-            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
+            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
                             self.collect_times):
                 step_accumulates_one(context, self.micro_batch_number)
                 return
@@ -685,7 +684,7 @@ class TrainerMon:
                 self.module_struct[context.module_name].update(context.struct)
                 return

-            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
+            if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
                             self.collect_times):
                 step_accumulates_one(context, self.micro_batch_number)
                 return
@@ -752,50 +751,10 @@ class TrainerMon:
             hooked_count += 1
         return hooked_count

-    def _register_param_call_id(self, hook_name: str, key: str):
-        """
-        :param hook_name:
-        :param key: str, '0:relu_0/output_grad'
-        :return:
-        """
-        logger.debug(f"{hook_name} {key}: {self.call_id}")
-        self.param_name_call_id[key] = self.call_id
-        self.call_id += 1
-
     def _patch_grad_sync(self):
-        # MindSpore does not use Megatron for now
-        def patch_sync(sync_grad_func):
-            def wrapper(bucket):
-                grad_dict = {}
-                for param, name in self.param2name.items():
-                    if param not in bucket.params_list:
-                        continue
-                    grad = param.main_grad if self.params_have_main_grad else param.grad
-                    if grad is None:
-                        logger.warning(f"grad is None: {name}, maybe something wrong happened.")
-                        continue
-                    tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
-                    if tag is None:
-                        continue
-                    grad_dict[tag] = grad
-                try:
-                    get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
-                except Exception as e:
-                    logger.warning(f"An error occurred while generating weight grad metrics")
-                out = sync_grad_func(bucket)
-                return out
-
-            return wrapper
-
-        self.enable_megatron = False
-
         if not self.wg_distribution:
             return
-
-        if self.enable_megatron:
-            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)  # differ in different megatron version
-        else:
-            self._hook_weights()
+        self._hook_weights()

     def _hook_weights(self):
         context = self.grad_context
@@ -819,3 +778,96 @@ class TrainerMon:
             handle = param.register_hook(param_hook_wrapper(param_hook, context_dict=context.acc, param=param, key=key))
             self.handles['wgrads'].append(handle)
         self.weight_hooked = True
+
+    def _is_target_param(self, param_name, param, prefix):
+        if not self.targets:
+            return True
+        squash_name = prefix + squash_param_name(param_name)
+        name = prefix + param_name
+        for target in self.targets.keys():
+            if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
+                setattr(param, "zero_out_wgrad", True)
+                return True
+        return False
+
+    def _is_target_module(self, module_name, targets, vpp_stage):
+        if self.all_xy or self.print_struct:
+            return vpp_stage + squash_param_name(module_name)
+        for pattern in [
+            vpp_stage + squash_param_name(module_name),
+            vpp_stage + module_name,
+        ]:
+            if pattern in targets:
+                return pattern
+        return ""
+
+    def _register_param_call_id(self, hook_name: str, key: str):
+        """
+        :param hook_name:
+        :param key: str, '0:relu_0/output_grad'
+        :return:
+        """
+        logger.debug(f"{hook_name} {key}: {self.call_id}")
+        self.param_name_call_id[key] = self.call_id
+        self.call_id += 1
+
+    def _remove_all_hooks(self):
+        # clear hook handles
+        for handle in self.handles['xy']:
+            handle.remove()
+        self.handles['xy'].clear()
+        # clear the corresponding context caches
+        for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+            fwd_context.reset()
+        for _, bwd_context in self.module_bwd_hook_context_by_module.items():
+            bwd_context.reset()
+        self.grad_context.reset()  # both weight grads and activation grads live here
+
+        for handle in self.handles['wgrads']:
+            handle.remove()
+        self.handles['wgrads'].clear()
+        self.weight_hooked = False
+
+        if self.optimizer_hooked:
+            for handle in self.handles['optimizer']:
+                handle.remove()
+            self.handles['optimizer'].clear()
+        for _, context in self.optimizer_context.items():
+            context.reset()
+        self.optimizer_hooked = False
+
+        for handle in self.handles['cc']:
+            handle.remove()
+        self.handles['cc'].clear()
+        api_register.restore_api()
+        for _, context in self.cc_context.items():
+            context.reset()
+
+        # clear node caches
+        self.param2name.clear()
+        self.name2index.clear()
+        self.name2indices.clear()
+        self.name2param.clear()
+        self.duplicate_param.clear()
+        self.name2tag.clear()
+        self.module_struct.clear()
+        self.grad_accs.clear()
+
+        # turn off the collection state
+        self.monitoring = False
+
+    def _remove_all_hooks_final(self, optimizer):
+        if self.dynamic_enable:
+            # when finished, automatically reset dynamic_on to False and wait for the user to re-enable it
+            try:
+                config = load_json(self.config_file_path)
+                config['dynamic_on'] = False
+                save_json(self.config_file_path, config, indent=2)
+                config_timestamp = os.path.getmtime(self.config_file_path)
+                self.config_timestamp = config_timestamp
+                logger.info(
+                    "Finish monitor, set config'dynamic_on=False, will restart by set it to True and update config")
+            except Exception as e:
+                logger.warning(f"Finish monitor, set config'dynamic_on=False fail because {e}, please check!!!")
+        logger.info("Finish monitor")
+        self._remove_all_hooks()