mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
- msprobe/README.md +25 -20
- msprobe/core/common/const.py +110 -66
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/utils.py +30 -34
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -2
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +20 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_processor/base.py +2 -2
- msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
- msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
- msprobe/core/data_dump/json_writer.py +38 -35
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +2 -1
- msprobe/docs/02.config_introduction.md +17 -15
- msprobe/docs/05.data_dump_PyTorch.md +70 -2
- msprobe/docs/06.data_dump_MindSpore.md +33 -12
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +124 -62
- msprobe/docs/21.visualization_PyTorch.md +32 -13
- msprobe/docs/22.visualization_MindSpore.md +32 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +31 -19
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +6 -4
- msprobe/mindspore/debugger/precision_debugger.py +22 -10
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +14 -9
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +354 -302
- msprobe/mindspore/monitor/utils.py +46 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +23 -17
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/common/utils.py +29 -7
- msprobe/pytorch/debugger/precision_debugger.py +10 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
- msprobe/pytorch/monitor/csv2tb.py +8 -2
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +131 -105
- msprobe/pytorch/monitor/module_metric.py +3 -0
- msprobe/pytorch/monitor/optimizer_collect.py +55 -4
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +68 -1
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +11 -7
- msprobe/pytorch/service.py +11 -8
- msprobe/visualization/builder/graph_builder.py +44 -5
- msprobe/visualization/builder/msprobe_adapter.py +0 -1
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +8 -1
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +1 -1
- msprobe/visualization/utils.py +2 -33
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/parse.py +0 -19
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
Hunks below are likely from msprobe/pytorch/monitor/csv2tb.py:

@@ -22,13 +22,15 @@ from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
 from msprobe.core.common.const import MonitorConst
-from msprobe.core.common.file_utils import read_csv, create_directory, remove_path
+from msprobe.core.common.file_utils import read_csv, create_directory, remove_path, recursive_chmod
 from msprobe.core.common.utils import is_int
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.monitor.utils import get_target_output_dir
 
 all_data_type_list = ["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param"]
 CSV_FILE_SUFFIX = r"_\d+-\d+\.csv"
+MAX_PROCESS_NUM = 128
 
 
 def parse_step_line(line, ops):
@@ -76,6 +78,7 @@ def write_step(output_dirpath, parse_step_result, rank, data_type):
             writer.add_scalar(tag, value, step)
 
 
+@recursion_depth_decorator("update_dict", max_depth=50)
 def update_dict(dict1, dict2):
     for key, value in dict2.items():
         if key in dict1:
@@ -115,11 +118,13 @@ def csv2tb_by_step_work(target_output_dirs, output_dirpath, data_type_list):
 def check_process_num(process_num):
     if not is_int(process_num) or process_num <= 0:
         raise ValueError(f"process_num({process_num}) is not a positive integer")
+    if process_num > MAX_PROCESS_NUM:
+        raise ValueError(f"The maximum supported process_num is {MAX_PROCESS_NUM}, current value: {process_num}.")
 
 
 def check_data_type_list(data_type_list):
     if data_type_list is None:
-        logger.info(f"data_type_list is None, use
+        logger.info(f"data_type_list is None, use default all_data_type_list: {all_data_type_list}")
         return
     if not isinstance(data_type_list, list):
         raise ValueError(f"data_type_list({data_type_list}) is not a list")
@@ -161,4 +166,5 @@ def csv2tensorboard_by_step(
         p.start()
     for p in processes:
         p.join()
+    recursive_chmod(output_dirpath)
     logger.info(f"output has been saved to: {output_dirpath}")
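The notable additions here are the process-count cap and the recursion_depth_decorator("update_dict", max_depth=50) guard. The decorator itself lives in msprobe/core/common/decorator.py (+50 lines, not shown in this extract), so the sketch below is only an illustration of the idea with a hypothetical depth counter; the real implementation may differ.

```python
# Illustrative sketch only: the real recursion_depth_decorator is defined in
# msprobe/core/common/decorator.py and is not shown in this diff.
import functools

_DEPTH = {}


def recursion_depth_decorator(func_info, max_depth=50):
    """Fail fast when a recursive helper such as update_dict nests too deeply."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            depth = _DEPTH.get(func_info, 0) + 1
            if depth > max_depth:
                raise RecursionError(f"{func_info} exceeded max recursion depth {max_depth}")
            _DEPTH[func_info] = depth
            try:
                return func(*args, **kwargs)
            finally:
                _DEPTH[func_info] = depth - 1
        return wrapper
    return decorator


@recursion_depth_decorator("update_dict", max_depth=50)
def update_dict(dict1, dict2):
    # merge dict2 into dict1, recursing into nested dictionaries
    for key, value in dict2.items():
        if key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict):
            update_dict(dict1[key], value)
        else:
            dict1[key] = value
```

Since update_dict merges nested dictionaries recursively, a pathologically deep structure would otherwise only surface as a generic recursion failure; a guard like this turns it into an explicit, bounded error.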
Hunks below are likely from msprobe/pytorch/monitor/distributed/wrap_distributed.py:

@@ -24,6 +24,7 @@ import torch.nn as nn
 from msprobe.core.common.const import MonitorConst
 from msprobe.core.common.file_utils import load_yaml
 from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name
+from msprobe.pytorch.common.log import logger
 
 try:
     import torch_npu
@@ -37,6 +38,7 @@ WrapDistributedOps = load_yaml(OpsPath).get("distributed", [])
 
 StackBlackListPath = os.path.join(os.path.dirname(__file__), "stack_blacklist.yaml")
 StackBlackList = load_yaml(StackBlackListPath).get("stack", [])
+MAX_STRING_LENGTH = 1000
 
 distributed_func = {}
 for f in dir(dist):
@@ -139,6 +141,8 @@ def get_process_group(process_group):
 
 
 def stack_filter(stack):
+    if len(stack) > MAX_STRING_LENGTH:
+        logger.warning(f'The character string contains more than {MAX_STRING_LENGTH}. re match is skipped.')
    for pattern in StackBlackList:
         if re.search(pattern, stack):
             return False
@@ -188,10 +192,12 @@ def update_data(old, new):
 
 
 def is_target_line(codeline):
-    stack = get_callstack()
-    whole_stack = ';'.join(stack)
     if codeline == []:
         return True
+    stack = get_callstack()
+    whole_stack = ';'.join(stack)
+    if len(whole_stack) > MAX_STRING_LENGTH:
+        logger.warning(f'The character string contains more than {MAX_STRING_LENGTH}. re match is skipped.')
     for pattern in codeline:
         if re.search(pattern, whole_stack):
             return True
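Both new guards warn when the joined call-stack string exceeds MAX_STRING_LENGTH = 1000 before the regular-expression checks run. A self-contained illustration of the same length-check pattern, using a hypothetical blacklist in place of the real stack_blacklist.yaml:

```python
import re
import traceback

MAX_STRING_LENGTH = 1000
STACK_BLACKLIST = [r"site-packages/torch/"]  # hypothetical patterns, not the shipped YAML


def stack_filter(stack):
    # Mirror of the guarded helper: warn (printed here) when the string is
    # suspiciously long, then fall through to the regex blacklist check.
    if len(stack) > MAX_STRING_LENGTH:
        print(f"string longer than {MAX_STRING_LENGTH} characters, regex matching may be costly")
    for pattern in STACK_BLACKLIST:
        if re.search(pattern, stack):
            return False
    return True


whole_stack = ';'.join(line.strip() for line in traceback.format_stack())
print(stack_filter(whole_stack))
```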
Hunks below are likely from msprobe/pytorch/monitor/module_hook.py (Chinese comments translated):

@@ -26,8 +26,9 @@ from torch.utils.hooks import BackwardHook
 
 from msprobe.core.common.const import MonitorConst, Const
 from msprobe.core.common.file_utils import load_json, save_json
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import is_recomputation
+from msprobe.pytorch.common.utils import is_recomputation, is_float8_tensor
 from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter
 from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \
     CSVWriterWithAD, BaseWriterWithAD, WriterInput
@@ -39,7 +40,7 @@ from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_write
 from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
 from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
 from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \
-    get_output_base_dir, get_target_output_dir
+    get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor
 from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
@@ -176,7 +177,8 @@ class GradContext:
 class TrainerMon:
     tensor_metrics = TensorMetrics()
 
-
+    # keep the original opt_ty parameter for compatibility with versions before msprobe 1.2.2
+    def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None:
         # TYPE1: variables initialized only here; not reset when the config changes mid-training
         self.config_file_path = config_file_path
         self.process_group = get_process_group(process_group)
@@ -222,6 +224,7 @@ class TrainerMon:
         self.micro_batch_number = 1
         self.optimizer_class = None
         self.optimizer_mon = None
+        self.optimizer_trans = None
 
         # TYPE3: variables reset when the config is updated mid-training or the monitoring state changes
         self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
@@ -322,8 +325,6 @@ class TrainerMon:
         self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
         self.cc_logged_stack = defaultdict(set)
         self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
-        self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
-        api_register.redirect_api()
 
         self.common_info()
 
@@ -336,11 +337,11 @@ class TrainerMon:
 
         # initialize the writer and create the output directory
         if self.format not in FORMAT_MAPPING:
-            logger.
+            logger.warning(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
             self.format = MonitorConst.CSV
 
         if self.ur_distribution and self.format != 'tensorboard':
-            logger.
+            logger.warning("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution")
             self.ur_distribution = False
 
         writer = FORMAT_MAPPING[self.format]
@@ -363,19 +364,6 @@ class TrainerMon:
                 self.rank)
             self.anomaly_data_writer.init_detected_json()
 
-    def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
-        rank = None
-        if dist.is_initialized():
-            rank = dist.get_rank()
-        if (rank not in rank_list) and len(rank_list) != 0:
-            return
-        self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
-
-    def build_tbtag_tensor_map(self, module_name, tag, tensor):
-        key = get_summary_writer_tag_name(module_name, tag, self.rank)
-        self._register_param_call_id("_hook_module", key)
-        return {key: tensor}
-
     def common_info(self):
         if not self.xy_distribution:
             logger.info_on_rank_0("> module input/output input_grad/output_grad is not monitored. ")
@@ -392,94 +380,31 @@ class TrainerMon:
         if not self.cc_distribution.get('enable', False):
             logger.info_on_rank_0("> cc operator is not monitored.")
 
-
-
-
-
-
-
-
-
-
-
-        hooked_count = 0
-        for vpp_stage, model_chunk in enumerate(self.model):
-            vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
-            targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
-                'targets'].keys()
-            hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
-
-        logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
-
-        def clone_if_tensor(args):
-            if isinstance(args, tuple):
-                return tuple([clone_if_tensor(arg) for arg in args])
-            elif isinstance(args, torch.Tensor):
-                return args.clone()
-            else:
-                return args
-
-        @torch.no_grad
-        def wrap_hook_setup(setup):
-            def wrapped_setup(*args, **kwargs):
-                args = setup(*args, **kwargs)
-                args = clone_if_tensor(args)
-                return args
-
-            return wrapped_setup
-
-        BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
-
-        return
-
-    def generate_param_metrics(self, opt_context):
-        if not self.param_distribution:
-            return
-        get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
-
-    def generate_mv_metrics(self, opt_context):
-        if not self.mv_distribution:
-            return
-        opt_context.exp_avg_metric = {}
-        opt_context.exp_avg_sq_metric = {}
-        m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
-        v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
-        get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
-        get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
-
-    def generate_wgrad_metrics(self):
-        if not self.wg_distribution:
-            return {}, {}
-
-        if self.weight_hooked:
-            get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
-
-        grad_dict = {}
-        for param, name in self.param2name.items():
-            if self.duplicate_param.get(name, False):
-                continue
-            grad = param.main_grad if self.params_have_main_grad else param.grad
-            if grad is None:
-                logger.warning(f"grad is None: {name}, maybe something wrong happened.")
-                continue
-            tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
-            self._register_param_call_id("hook_optimizer", tag)
-            grad_dict[tag] = grad
+    # keep the original interface for compatibility with versions before msprobe 1.2.2
+    def monitor_gnorm_with_ad(self, model, optimizer=None, grad_acc_steps=1, tp_group=None, dp_group=None,
+                              start_iteration=0):
+        if optimizer is None:
+            optimizer = getattr(self, "optimizer_trans", None)  # old versions allowed None; fall back to the optimizer from set_wrapped_optimizer
+            if optimizer is None:
+                logger.error("monitor_gnorm_with_ad: please set_wrapped_optimizer before it or input optimizer!=None")
+                return
+        self.set_monitor(model, optimizer, grad_acc_steps, tp_group, dp_group, start_iteration)
 
-
-
-
+    # keep the original interface for compatibility with versions before msprobe 1.2.2
+    def set_wrapped_optimizer(self, optimizer):
+        self.optimizer_trans = optimizer
 
     def set_monitor(
             self,
             model,
+            optimizer,
             grad_acc_steps=1,
-            optimizer=None,
             tp_group=None,
             dp_group=None,
             start_iteration=0
     ):
         """External interface"""
+        grad_acc_steps, start_iteration = validate_set_monitor(grad_acc_steps, start_iteration)
         global start_step
         start_step = start_iteration
         logger.info(f'grad acc steps {grad_acc_steps}')
@@ -502,8 +427,24 @@ class TrainerMon:
         self.hook_optimizer(optimizer)
         self._patch_grad_sync()
         self.hook_modules()
+        if self.cc_distribution.get('enable', False):
+            self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
+            api_register.redirect_api()
         self.monitoring = True
 
+    def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list):
+        rank = None
+        if dist.is_initialized():
+            rank = dist.get_rank()
+        if (rank not in rank_list) and len(rank_list) != 0:
+            return
+        self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank)
+
+    def build_tbtag_tensor_map(self, module_name, tag, tensor):
+        key = get_summary_writer_tag_name(module_name, tag, self.rank)
+        self._register_param_call_id("_hook_module", key)
+        return {key: tensor}
+
     def generate_param_map(self, tag, param_tensor):
         metrics = {}
         for name in self.param2name.values():
@@ -514,6 +455,44 @@ class TrainerMon:
             metrics[key] = param_tensor[name]
         return metrics
 
+    def generate_param_metrics(self, opt_context):
+        if not self.param_distribution:
+            return
+        get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric)
+
+    def generate_mv_metrics(self, opt_context):
+        if not self.mv_distribution:
+            return
+        opt_context.exp_avg_metric = {}
+        opt_context.exp_avg_sq_metric = {}
+        m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg)
+        v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq)
+        get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
+        get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
+
+    def generate_wgrad_metrics(self):
+        if not self.wg_distribution:
+            return {}, {}
+
+        if self.weight_hooked:
+            get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
+
+        grad_dict = {}
+        for param, name in self.param2name.items():
+            if self.duplicate_param.get(name, False):
+                continue
+            grad = param.main_grad if self.params_have_main_grad else param.grad
+            if grad is None:
+                logger.warning(f"grad is None: {name}, maybe something wrong happened.")
+                continue
+            tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+            self._register_param_call_id("hook_optimizer", tag)
+            grad_dict[tag] = grad
+
+        get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
+        unreduced_grad = self.grad_context.acc_metric if self.weight_hooked else self.grad_context.pre
+        return self.grad_context.post, unreduced_grad
+
     def generate_xy_metrics(self):
         actv = {}
         for fwd_context in self.module_fwd_hook_context_by_module.values():
@@ -557,9 +536,9 @@ class TrainerMon:
     def write_mv_tb(self, opt_context):
         if not self.mv_distribution:
             return
-        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric,
+        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric,
                                           opt_context.step, MonitorConst.EXP_AVG)
-        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric,
+        self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric,
                                           opt_context.step, MonitorConst.EXP_AVG_SQ)
 
     def write_grad_tb(self, step):
@@ -572,7 +551,7 @@ class TrainerMon:
         self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
         self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
 
-    def hook_optimizer(self, optimizer
+    def hook_optimizer(self, optimizer):
         # in DDP by default use params_have_main_grad
         def optimizer_pre_step_hook(optimizer, args, kwargs):
             context = self.optimizer_context[optimizer]
@@ -638,7 +617,6 @@ class TrainerMon:
                 optimizer_pre_step_hook(optimizer, args, kwargs)
                 out = func(*args, **kwargs)
                 return out
-
             return wrapper
 
         if self.optimizer_hooked:
@@ -674,6 +652,7 @@ class TrainerMon:
                 validate_config(config)
                 self.config = config
                 self.set_config()
+                self.start_step = context.step  # on dynamic start/stop, ignore the original start_step and always start from the next step
                 logger.warning(f"config is updated at step{context.step - 1}, "
                                f"will start new hook at step{context.step}.")
             except Exception as e:
@@ -721,6 +700,9 @@ class TrainerMon:
             if self.anomaly_data_factory:
                 self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies())
             self.summary_writer.clear_anomalies()
+
+            if self.format == MonitorConst.TENSORBOARD:
+                chmod_tensorboard_dir(self.tensorboard_dir)
             self.call_id = 0
             self.param_name_call_id.clear()
 
@@ -739,7 +721,47 @@ class TrainerMon:
 
         optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
         self.origin_step_func = optimizer.__class__.step
+        return
+
+    def hook_modules(self):
+        if self.module_rank_list and (self.rank not in self.module_rank_list):
+            return
+
+        targets = self.config['targets']
+        module_in_all_stage = [key for key in targets.keys() if MonitorConst.NAME_SEP not in key]
+        for key in module_in_all_stage:
+            struct = targets.pop(key)
+            targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})
+
+        hooked_count = 0
+        for vpp_stage, model_chunk in enumerate(self.model):
+            vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
+            targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
+                'targets'].keys()
+            hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
+
+        logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
+
+        @recursion_depth_decorator('msprobe.pytorch.monitor.clone_if_tensor')
+        def clone_if_tensor(args):
+            if isinstance(args, tuple):
+                return tuple([clone_if_tensor(arg) for arg in args])
+            elif isinstance(args, torch.Tensor) and not is_float8_tensor(args):
+                return args.clone()
+            else:
+                return args
 
+        @torch.no_grad
+        def wrap_hook_setup(setup):
+            def wrapped_setup(*args, **kwargs):
+                args = setup(*args, **kwargs)
+                args = clone_if_tensor(args)
+                return args
+
+            return wrapped_setup
+
+        BackwardHook.setup_input_hook = wrap_hook_setup(BackwardHook.setup_input_hook)
+        BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
         return
 
     def _remove_all_hooks(self, optimizer):
@@ -783,6 +805,7 @@ class TrainerMon:
             for handle in self.handles['cc']:
                 handle.remove()
             self.handles['cc'].clear()
+            api_register.restore_api()
         for _, context in self.cc_context.items():
             context.reset()
 
@@ -956,7 +979,7 @@ class TrainerMon:
                 return
             if not context.verified:
                 context.focused_in_col = validate_config_spec(
-                    context.format_by_arg[MonitorConst.INPUT_GRAD],
+                    context.format_by_arg[MonitorConst.INPUT_GRAD],
                    input_grad, context.module_name, MonitorConst.INPUT_GRAD)
                 context.focused_out_col = validate_config_spec(
                     context.format_by_arg[MonitorConst.OUTPUT_GRAD],
@@ -1052,7 +1075,7 @@ class TrainerMon:
                 self.enable_megatron = True
                 logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
             except ImportError:
-                self.enable_megatron = False
+                self.enable_megatron = False | self.enable_megatron
 
         if not self.enable_megatron:
             self._hook_weights()
@@ -1067,9 +1090,12 @@ class TrainerMon:
             if param.micro_step == self.micro_batch_number:
                 param.micro_step = 0
                 if self.params_have_main_grad:
-
+                    grad = param.main_grad
                 else:
-
+                    grad = param.grad
+                if is_float8_tensor(grad):
+                    grad = grad.float()
+                context_dict[key] = grad.clone()
 
         logger.info("hooking weights.")
         for param, name in self.param2name.items():
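One easy-to-miss change in these hunks is the except ImportError branch now assigning `self.enable_megatron = False | self.enable_megatron` instead of plain False, apparently so that a flag already set to True by an earlier, successful megatron import probe is not clobbered. A minimal demonstration of why the or-form behaves differently from plain assignment:

```python
# `False | flag` keeps an already-True flag; plain assignment would clobber it.
enable_megatron = True             # e.g. set by an earlier try-block whose import succeeded
enable_megatron = False | enable_megatron
print(enable_megatron)             # True

enable_megatron = False            # nothing imported successfully yet
enable_megatron = False | enable_megatron
print(enable_megatron)             # False
```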
Hunks below are likely from msprobe/pytorch/monitor/module_metric.py:

@@ -16,6 +16,7 @@ import re
 
 import torch
 
+from msprobe.pytorch.common.utils import is_float8_tensor
 from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
 from msprobe.pytorch.monitor.utils import get_nan_tensor
 
@@ -166,6 +167,8 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
             # Non-tensor in/output filled with nan.
             out_dict[tag].update({metric_name: get_nan_tensor() for metric_name in ops})
             continue
+        if is_float8_tensor(tensor):
+            tensor = tensor.float()
         for metric_name in ops:
             fun_metric = config_metric_registry.get(metric_name)
             out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
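Both this file and module_hook.py now upcast float8 tensors before computing statistics or cloning them, since most reduction kernels are not implemented for the torch float8 dtypes. The helper itself is in msprobe/pytorch/common/utils.py, which is not part of this extract, so the version below is only an assumed sketch of its behaviour:

```python
import torch


def is_float8_tensor(tensor):
    # Assumed behaviour of the helper added to msprobe.pytorch.common.utils:
    # True for the torch float8 dtypes (available in torch >= 2.1), else False.
    float8_dtypes = {getattr(torch, name) for name in ("float8_e4m3fn", "float8_e5m2") if hasattr(torch, name)}
    return tensor.dtype in float8_dtypes


x = torch.randn(4)
if hasattr(torch, "float8_e4m3fn"):
    x = x.to(torch.float8_e4m3fn)   # simulate a float8 activation or gradient

if is_float8_tensor(x):
    x = x.float()                   # metrics are computed on an fp32 copy
print(x.norm(), x.mean())
```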
Hunks below are likely from msprobe/pytorch/monitor/optimizer_collect.py:

@@ -185,7 +185,7 @@ class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
         for opt in torch_opt.chained_optimizers:
             self.map_fp16_tp_fp32_param(opt)
 
-        if not isinstance(torch_opt, torch.optim.Optimizer):
+        if not isinstance(torch_opt, torch.optim.Optimizer) and not hasattr(torch_opt, 'state'):
             torch_opt.state = {}
             for opt in torch_opt.chained_optimizers:
                 torch_opt.state.update(opt.optimizer.state)
@@ -198,7 +198,7 @@ class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
         for opt in torch_opt.chained_optimizers:
             self.map_fp16_tp_fp32_param(opt)
 
-        if not isinstance(torch_opt, torch.optim.Optimizer):
+        if not isinstance(torch_opt, torch.optim.Optimizer) and not hasattr(torch_opt, 'state'):
             torch_opt.state = {}
             for opt in torch_opt.chained_optimizers:
                 torch_opt.state.update(opt.optimizer.state)
@@ -206,9 +206,60 @@ class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
 
 
 class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon):
-    def
-
+    def get_group_index(self, torch_opt):
+        bit16_groups = torch_opt.bf16_groups
+        param2group = defaultdict()
+        for group_idx, bit16_group in enumerate(bit16_groups):
+            for param in bit16_group:
+                param2group[param] = group_idx
+        return param2group
+
+    def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
+        param2group = self.get_group_index(torch_opt)
+        exp_avg_dict = defaultdict(float)
+        exp_avg_sq_dict = defaultdict(float)
+        update_dict = defaultdict()
+        ratio_dict = defaultdict()
+
+        param_slice_mappings = torch_opt.state_dict()['param_slice_mappings']
+        for param, name in params2name.items():
+            group_idx = param2group[param]
+            state = torch_opt.optimizer.state[torch_opt.fp32_groups_flat_partition[group_idx]]
+            if state.get('exp_avg', None) is None:
+                logger.warning(f"optimizer state is None. Something is wrong if this is not the first step")
+                break
+            param_slice_mapping = param_slice_mappings[group_idx]
+            hp_address = param_slice_mapping.get(torch_opt.param_names[param])
+            if hp_address is None:
+                continue
+            start = hp_address.start
+            numel = hp_address.numel
 
+            if monitor.mv_distribution:
+                exp_avg_dict[name] = state['exp_avg'].narrow(0, start, numel)
+                exp_avg_sq_dict[name] = state['exp_avg_sq'].narrow(0, start, numel)
+            if monitor.mg_direction:
+                exp_avg_dict[name] = state['exp'].narrow(0, start, numel)
+            if monitor.ur_distribution:
+                if len(torch_opt.param_groups) > 1:
+                    logger.info(f"the length of torch_opt.param_groups is {len(torch_opt.param_groups)}.")
+                if 'step' in state:
+                    step = state['step']  # Optimizer from pytorch or FusedAdam from apex(used by megatron)
+                elif 'step' in torch_opt.param_groups[0]:
+                    step = torch_opt.param_groups[0]['step']  # AdamW from mindspeed
+                else:
+                    logger.warning(f"step of {name} is None, maybe something wrong happened.")
+                    continue
+                exp_avg = state['exp_avg'].narrow(0, start, numel)
+                exp_avg_sq = state['exp_avg_sq'].narrow(0, start, numel)
+                exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
+                exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
+                update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
+                ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
+                monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
+                monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
+        return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
+
 
 class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
     def get_param_index(self, params2name, name2index, torch_opt):
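The new DeepSpeedZeroOptimizerStage0Mon.fetch_mv reconstructs the Adam update and ratio from the sliced optimizer state using the standard bias-corrected moments: m_hat = m / (1 - beta1**step), v_hat = v / (1 - beta2**step), update = m_hat / (sqrt(v_hat) + eps), ratio = m_hat / sqrt(v_hat). A standalone check of those formulas with made-up values:

```python
import torch

beta1, beta2, eps, step = 0.9, 0.999, 1e-8, 10
exp_avg = torch.tensor([0.02, -0.01])       # m: first-moment slice for one parameter
exp_avg_sq = torch.tensor([4e-4, 1e-4])     # v: second-moment slice

exp_avg_hat = exp_avg / (1 - beta1 ** step)
exp_avg_sq_hat = exp_avg_sq / (1 - beta2 ** step)
update = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + eps)   # what update_dict[name] stores
ratio = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)            # what ratio_dict[name] stores
print(update, ratio)
```

When ur_distribution is enabled, these per-parameter tensors feed the update and ratio heatmap visualizers via pre_cal, as shown in the hunk above.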
Hunk below is likely from msprobe/pytorch/monitor/unittest/test_monitor.py:

@@ -92,7 +92,7 @@ def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel):
     if errors:
         logger.info(errors)
     else:
-        logger.info(f'grad mean is in consist between unreduced grad and reduced grad
+        logger.info(f'grad mean is in consist between unreduced grad and reduced grad monitored.')
 
 
 def assert_equal(a, b):