mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff shows the changes between package versions that have been publicly released to one of the supported registries, exactly as they appear in those registries, and is provided for informational purposes only.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
msprobe/pytorch/monitor/module_hook.py

```diff
@@ -15,18 +15,21 @@
 import json
 import os
 import uuid
+import importlib
 from collections import defaultdict
 from datetime import datetime
 from functools import partial
+from itertools import cycle
 
 import pytz
 import torch
 import torch.distributed as dist
 import pandas as pd
 from torch.utils.hooks import BackwardHook
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
 from msprobe.core.common.const import MonitorConst, Const
-from msprobe.core.common.file_utils import load_json, save_json
+from msprobe.core.common.file_utils import load_json, save_json, make_dir
 from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
 from msprobe.core.common.file_utils import write_df_to_csv
@@ -39,9 +42,9 @@ from msprobe.pytorch.monitor.utils import get_param_struct
 from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput
 from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
     get_process_group
-from msprobe.pytorch.monitor.features import get_sign_matches
+from msprobe.pytorch.monitor.features import get_sign_matches, cal_qkt
 from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
-    TensorMetrics, squash_param_name
+    TensorMetrics, squash_param_name, get_entropy_metric, get_sr_metric
 from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
 from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
 
@@ -56,6 +59,7 @@ FORMAT_MAPPING = {
     MonitorConst.CSV: CSVWriterWithAD,
     MonitorConst.API: BaseWriterWithAD
 }
+start_step = 0
 
 
 def param_is_not_tensor_parallel_duplicate(param, tp_group):
@@ -82,7 +86,17 @@ class ModuleHookContext:
         self.actvgrad.clear()
 
 
-
+class FeatureHookContext:
+    def __init__(self, module_name):
+        self.step = 0
+        self.micro_step = 0
+        self.attention_feature = {}
+        self.linear_feature = {}
+        self.module_name = module_name
+
+    def reset(self):
+        self.attention_feature.clear()
+        self.linear_feature.clear()
 
 
 class OptimizerContext:
@@ -159,8 +173,8 @@ class TrainerMon:
         self.params_have_main_grad = params_have_main_grad
         self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer)
         self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
-        self.origin_start_grad_sync = None
         self.fsdp_post_backward_hook = None
+        self.fsdp2_foreach_reduce = None
        self.config_timestamp = 0  # the timestamp is validated later; the first monitoring run does not need to touch it just to refresh the config file, monitoring can be switched on directly via dynamic_on
         self.config = load_json(config_file_path)
         validate_config(self.config)
@@ -195,7 +209,9 @@ class TrainerMon:
         self.dp_group = None
         self.tp_group = None
         self.enable_megatron = False
+        self.enable_deepspeed = False
         self.fsdp_wrapped_module = False
+        self.fsdp2_wrapped_module = False
         self.micro_batch_number = 1
         self.optimizer_mon = None
         self.optimizer_trans = None
@@ -203,6 +219,7 @@ class TrainerMon:
         # TYPE3: variables that are reset when the config is updated mid-training or the monitoring state changes
         self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
         self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
+        self.feature_hook_context_by_module = defaultdict(FeatureHookContext)
         self.optimizer_context = defaultdict(OptimizerContext)
         self.cc_context = defaultdict(CommunicationContext)
         self.grad_context = GradContext()
@@ -210,9 +227,12 @@ class TrainerMon:
         self.param2name = defaultdict(str)
         self.name2indices = defaultdict()
         self.name2param = {}
+        self.origin2squash = {}
         self.duplicate_param = {}
         self.name2tag = {}
         self.param_name_call_id = {}
+        self.flat_prefix_names = []
+        self.flat_prefix_reverse_iter = None
         self.call_id = 0
         self.module_struct = defaultdict(dict)
         self.grad_accs = []
@@ -270,6 +290,18 @@ class TrainerMon:
             cc_tensor.reset()
         return metrics
 
+    @staticmethod
+    def get_linear_hook_target(module):
+        if isinstance(module, torch.nn.Embedding):
+            return ''
+        if hasattr(module, "num_embeddings") or hasattr(module, "vocab_start_index"):
+            return ''
+        for weight_name in ["weight", "wg"]:
+            if hasattr(module, weight_name) and isinstance(getattr(module, weight_name), torch.Tensor):
+                if getattr(module, weight_name).dim() == 2:
+                    return weight_name
+        return ''
+
     def set_config(self):
         logger.info(f"current config: {self.config}")
         self.start_step = self.config.get("start_step", 0)
```
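The new `get_linear_hook_target` helper above decides which modules feed the linear-feature ("sr"/"kernel_norm") hook: embedding-like modules are skipped, and a module qualifies only if its `weight` (or `wg`) attribute is a 2-D tensor. A standalone restatement of that rule, applied to a few ordinary `torch.nn` modules purely for illustration (the module choices are not taken from the package):

```python
import torch

# Standalone restatement of the selection rule added above, applied to plain torch.nn modules.
def linear_hook_target(module):
    if isinstance(module, torch.nn.Embedding):
        return ''
    if hasattr(module, "num_embeddings") or hasattr(module, "vocab_start_index"):
        return ''
    for weight_name in ["weight", "wg"]:
        weight = getattr(module, weight_name, None)
        if isinstance(weight, torch.Tensor) and weight.dim() == 2:
            return weight_name
    return ''


print(linear_hook_target(torch.nn.Linear(16, 32)))      # 'weight' -> eligible for the linear-feature hook
print(linear_hook_target(torch.nn.Embedding(100, 32)))  # ''       -> embeddings are skipped
print(linear_hook_target(torch.nn.LayerNorm(32)))       # ''       -> 1-D weight, skipped
```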
```diff
@@ -294,6 +326,8 @@ class TrainerMon:
         self.cc_distribution = self.config.get("cc_distribution", {})
         self.stack_info = self.config.get('stack_info', False)
         self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False)
+        self.recording_l2_features = self.config.get("recording_l2_features", False)
+        self.sa_order = self.config.get("sa_order", "s,b,h,d")
 
         if not self.cc_distribution.get('enable', False):
             self.cc_log_only = False
```
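`recording_l2_features` and `sa_order` are read with `config.get(...)` defaults, and the hooking code further down pulls `l2_targets` from the same config (`l2_target_names = self.config.get('l2_targets', '')`) and indexes it by hook name (`l2_target_names[hook_name]`). A hedged sketch of how these keys might sit together in the monitor config; the nested structure of `l2_targets` and all concrete target names, including the vpp-stage prefix and separator, are assumptions inferred from `_is_recording_module`:

```python
# Hypothetical monitor config fragment; the key structure is inferred from the diff,
# and all module/target names (and the "0:" vpp-stage prefix) are illustrative only.
monitor_config = {
    "targets": {"module.decoder.layers.0": {}},
    "recording_l2_features": True,   # read via config.get("recording_l2_features", False)
    "sa_order": "s,b,h,d",           # layout string handed to cal_qkt by the attention hook
    "l2_targets": {                  # indexed as l2_target_names[hook_name] in _hook_module
        "attention_hook": ["0:layers.0.self_attention.core_attention"],
        "linear_hook": ["0:layers.0.mlp.linear_fc1"],
    },
}
```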
```diff
@@ -352,6 +386,8 @@ class TrainerMon:
             logger.info_on_rank_0("> momentum and variance of adam is not monitored. ")
         if not self.wg_distribution:
             logger.info_on_rank_0("> weight grad of specified module is not monitored. ")
+        if not self.recording_l2_features:
+            logger.info_on_rank_0("> l2 features of specified module is not monitored. ")
         if not self.mg_direction:
             logger.info_on_rank_0('> grad and momentum direction will not be compared.')
         if not self.cc_distribution.get('enable', False):
@@ -533,6 +569,24 @@ class TrainerMon:
         if self.grad_context.actv:
             self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD)
 
+    def write_metrics_if_not_empty(self, features, metrics, step, hook_name):
+        if not features or len(features) == 0:
+            return
+        use_micro_step = hook_name not in ["linear_hook"]
+        self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=use_micro_step)
+        features.clear()
+
+    def write_features_tb(self, step):
+        if not self.recording_l2_features:
+            return
+        for context in self.feature_hook_context_by_module.values():
+            num_features = len(context.attention_feature) + len(context.linear_feature)
+            if num_features == 0:
+                continue
+            self.write_metrics_if_not_empty(context.attention_feature, ["entropy", "softmax_max"],
+                                            step, "attention_hook")
+            self.write_metrics_if_not_empty(context.linear_feature, ["sr", "kernel_norm"], step, "linear_hook")
+
     def write_param_tb(self, opt_context):
         if not self.param_distribution:
             return
@@ -687,6 +741,7 @@ class TrainerMon:
         if self.anomaly_data_factory:
             self.anomaly_data_factory.set_call_id(self.param_name_call_id)
         self.write_xy_tb(context.step)
+        self.write_features_tb(context.step)
         self.write_grad_tb(context.step)
         self.write_mv_tb(context)
         self.write_param_tb(context)
@@ -756,7 +811,8 @@ class TrainerMon:
             vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
             targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[
                 'targets'].keys()
-
+            l2_target_names = self.config.get('l2_targets', '')
+            hooked_count += self._hook_module(targets, l2_target_names, model_chunk, vpp_stage)
 
         logger.info_on_rank_0(f"> {hooked_count} modules are monitored.")
 
@@ -797,6 +853,9 @@ class TrainerMon:
             for handle in self.handles['xy']:
                 handle.remove()
             self.handles['xy'].clear()
+            for handle in self.handles['L2_features']:
+                handle.remove()
+            self.handles['L2_features'].clear()
             # clear the corresponding context caches
             for _, fwd_context in self.module_fwd_hook_context_by_module.items():
                 fwd_context.reset()
@@ -804,22 +863,14 @@ class TrainerMon:
                 bwd_context.reset()
             self.grad_context.reset()  # both weight grads and activation grads live here
 
-
-
-                from megatron.core.distributed.param_and_grad_buffer import Bucket
-                Bucket.start_grad_sync = self.origin_start_grad_sync
-                logger.info("remove Bucket start_grad_sync")
-            except ImportError:
-                pass
-            try:
-                from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
-                _ParamAndGradBucketGroup.start_grad_sync = self.origin_start_grad_sync
-                logger.info("remove _ParamAndGradBucketGroup start_grad_sync")
-            except ImportError:
-                pass
-            elif self.fsdp_post_backward_hook:  # fsdp
+            self.optimizer_mon.restore_grad_sync(self)
+            if self.fsdp_post_backward_hook:  # fsdp
                 torch.distributed.fsdp._runtime_utils._post_backward_hook = self.fsdp_post_backward_hook
                 logger.info("remove patch_post_backward_hook in fsdp.")
+            if self.fsdp2_foreach_reduce:  # fsdp2
+                torch.distributed.fsdp._fully_shard._fsdp_collectives.foreach_reduce = self.fsdp2_foreach_reduce
+                importlib.reload(torch.distributed.fsdp._fully_shard._fsdp_param_group)
+                logger.info("remove patch_foreach_reduce_hook in fsdp2.")
             else:  # not megatron and not fsdp
                 for handle in self.handles['wgrads']:
                     handle.remove()
@@ -881,14 +932,11 @@ class TrainerMon:
         logger.info(msg)
 
     def _save_module_struct(self):
-
-
-
-
-
-        module_struct_file = os.path.realpath(os.path.join(get_output_base_dir(), 'module_struct.json'))
-        save_json(module_struct_file, self.module_struct, indent=2)
-        logger.info(f"> save module struct to {module_struct_file}")
+        output_dir = os.path.join(get_output_base_dir(), 'module_struct', f'rank{self.rank}')
+        make_dir(output_dir)
+        module_struct_file = os.path.realpath(os.path.join(output_dir, 'module_struct.json'))
+        save_json(module_struct_file, self.module_struct, indent=2)
+        logger.info(f"> save module struct to {module_struct_file}")
         self.struct_printed = True
 
     def _is_target_param(self, param_name, param, prefix):
@@ -896,23 +944,32 @@ class TrainerMon:
         squash_name = prefix + squash_param_name(param_name, self.squash_name)
         for target in self.config['targets'].keys():
             if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
-                setattr(param, "zero_out_wgrad", True)
                 return True
 
         return False
 
     def _register_chunk(self, model_chunk, prefix):
+        if isinstance(model_chunk, FSDP):
+            if not model_chunk._use_orig_params:
+                raise ValueError("Only Support fsdp1 with use_orig_params=True")
+            self.fsdp_wrapped_module = True
         for (param_name, param) in model_chunk.named_parameters():
             if not param.requires_grad:
                 continue
-            if not self.
-            self.
+            if not self.fsdp2_wrapped_module and param.__class__.__name__ == "DTensor":
+                self.fsdp2_wrapped_module = True
+            if self.fsdp_wrapped_module:  # FSDP1 must record the full flat-weight prefix names, not limited by targets, so the flat param can be unpacked later
+                flat_prefix_name, _ = param_name.rsplit(MonitorConst.FSDP_FLAT_SEP, 1)
+                if flat_prefix_name not in self.flat_prefix_names:
+                    self.flat_prefix_names.append(flat_prefix_name)
+
             if self._is_target_param(param_name, param, prefix):
                 name = prefix + squash_param_name(param_name, self.squash_name)
                 if name in self.param2name.values():
                     name = prefix + param_name
                 self.param2name[param] = name
                 self.name2param[name] = param
+                self.origin2squash[param_name] = name
 
                 if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group):
                     self.duplicate_param[name] = True
@@ -929,6 +986,8 @@ class TrainerMon:
                 k: get_summary_writer_tag_name(name, k, self.rank)
                 for k in keywords
             }
+        if self.fsdp_wrapped_module:
+            self.flat_prefix_reverse_iter = cycle(reversed(self.flat_prefix_names))  # post_backward_hook is invoked in reverse order
 
     def _register_param_name(self):
         for vpp_stage, model_chunk in enumerate(self.model):
@@ -946,7 +1005,20 @@ class TrainerMon:
                 return pattern
         return ""
 
-    def
+    def _is_recording_module(self, module_name, l2_targets, vpp_stage, hook_name):
+
+        if len(l2_targets) > 0:
+            for pattern in [
+                vpp_stage + squash_param_name(module_name, self.squash_name),
+                vpp_stage + module_name,
+            ]:
+                if pattern in l2_targets:
+                    return pattern
+        elif hook_name in ["linear_hook"]:
+            return vpp_stage + squash_param_name(module_name, self.squash_name)
+        return ""
+
+    def _hook_module(self, target_names, l2_target_names, module: torch.nn.Module, vpp_stage=''):
         if '_modules' not in module.__dict__:
             # nothing to hook
             return 0
@@ -1025,6 +1097,61 @@ class TrainerMon:
                 context.micro_step = 0
             return
 
+        def extract_attention_feature_hook(module, module_input, module_output, name):
+            if is_recomputation() or not module.training:
+                return
+
+            if module not in self.feature_hook_context_by_module:
+                self.feature_hook_context_by_module[module] = FeatureHookContext(name)
+            context: FeatureHookContext = self.feature_hook_context_by_module[module]
+            tbtag_tensor_map = {}
+            if len(module_input) < 2:
+                logger.warning(
+                    f"Length of module_input in attention hook ({name}) is {len(module_input)}, "
+                    "expected >= 2. Skipping feature extraction for this module."
+                )
+                return
+            q_h = module_input[0]
+            k_h = module_input[1]
+            qkt = cal_qkt(q_h, k_h, order=self.sa_order)
+            tbtag_tensor_map.update(
+                self.build_tbtag_tensor_map(f'{context.module_name}.attention',
+                                            f'{MonitorConst.NAME_SEP}{context.micro_step}', 'qkt', qkt)
+            )
+            get_entropy_metric(tbtag_tensor_map, context.attention_feature)
+
+            context.micro_step += 1
+            if context.micro_step == self.micro_batch_number:
+                context.micro_step = 0
+                context.step += 1
+            return
+
+        def extract_linear_sr_hook(module, module_input, module_output, name):
+            if is_recomputation() or not module.training:
+                return
+            weight_name = self.get_linear_hook_target(module)
+            if weight_name == '':
+                return
+
+            if module not in self.feature_hook_context_by_module:
+                self.feature_hook_context_by_module[module] = FeatureHookContext(name)
+            context: FeatureHookContext = self.feature_hook_context_by_module[module]
+
+            if context.micro_step == (self.micro_batch_number - 1):
+                tbtag_tensor_map = {}
+                value = getattr(module, weight_name).data
+                tbtag_tensor_map.update(
+                    self.build_tbtag_tensor_map(f'{context.module_name}.linear',
+                                                '', 'sr', value)
+                )
+                get_sr_metric(tbtag_tensor_map, context.linear_feature)
+
+            context.micro_step += 1
+            if context.micro_step == self.micro_batch_number:
+                context.micro_step = 0
+                context.step += 1
+            return
+
         def stack_hook(module, args, kwargs, module_output, name):
             if module not in self.module_fwd_hook_context_by_module:
                 self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
@@ -1056,34 +1183,29 @@ class TrainerMon:
                 self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
                 logger.info_on_rank_0(f"> {name} is monitored successfully")
                 hooked_count += 1
-
-
-
-
-
-
-
-
-            # In megatron version core_r0.9.0, func start_grad_sync from Bucket moved to _ParamAndGradBucketGroup.
-            bucket_params_id_list = [id(params) for params in bucket.params]
-            for param, name in self.param2name.items():
-                if id(param) not in bucket_params_id_list:
-                    continue
-                grad = param.main_grad if self.params_have_main_grad else param.grad
-                if grad is None:
-                    logger.warning(f"grad is None: {name}, maybe something wrong happened.")
-                    continue
-                tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
-                if tag is None:
+        if not self.print_struct and self.recording_l2_features:
+            for module_name, submodule in module.named_modules():
+                func_map = {
+                    "attention_hook": extract_attention_feature_hook,
+                    "linear_hook": extract_linear_sr_hook,
+                }
+                for hook_name in func_map.keys():
+                    if hook_name not in l2_target_names:
                         continue
-
-                self.
-
-
-
+                    temp_names = l2_target_names[hook_name]
+                    name = self._is_recording_module(module_name, temp_names, vpp_stage, hook_name)
+                    if name:
+                        handle = submodule.register_forward_hook(partial(func_map[hook_name], name=name))
+                        print_feature_name = hook_name.split('_')[0]
+                        logger.info_on_rank_0(
+                            f'> {print_feature_name} features of {name} is monitored successfully')
+                        self.handles["L2_features"].append(handle)
+                        hooked_count += 1
+                        continue
 
-
+        return hooked_count
 
+    def _patch_grad_sync(self):
         if not self.wg_distribution:
             return
         if self.fsdp_wrapped_module:
```
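The registration loop above binds each hook's trailing `name` argument with `functools.partial` before passing it to `register_forward_hook`, so PyTorch still calls the hook with the standard `(module, input, output)` arguments while the callback knows which monitored module it belongs to. A small self-contained illustration of the same pattern (the model and names here are made up):

```python
import torch
from functools import partial

def feature_hook(module, module_input, module_output, name):
    # `name` is pre-bound via partial; PyTorch only supplies the first three arguments.
    print(f"{name}: output norm = {module_output.norm().item():.4f}")

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 4))
handles = []
for module_name, submodule in model.named_modules():
    if isinstance(submodule, torch.nn.Linear):
        handles.append(submodule.register_forward_hook(partial(feature_hook, name=module_name)))

model(torch.randn(2, 8))
for handle in handles:  # remove the hooks again, as the diff does for self.handles["L2_features"]
    handle.remove()
```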
```diff
@@ -1091,27 +1213,18 @@ class TrainerMon:
             self._patch_fsdp_post_backward_hook()
             return
 
+        if self.fsdp2_wrapped_module:
+            # patch fsdp2 _fully_shard._fsdp_collectives.foreach_reduce
+            self._patch_fsdp2_foreach_reduce()
+            return
+
         if self.monitor_mbs_grad:
             self._hook_weights()
             return
-
-
-            self.origin_start_grad_sync = Bucket.start_grad_sync
-            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)
-            self.enable_megatron = True
-            logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0")
-        except ImportError:
-            self.enable_megatron = False
+
+        self.optimizer_mon.patch_grad_sync(self)
 
-
-            from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
-            self.origin_start_grad_sync = _ParamAndGradBucketGroup.start_grad_sync
-            _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync)
-            self.enable_megatron = True
-            logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
-        except ImportError:
-            self.enable_megatron = False | self.enable_megatron
-        if self.enable_megatron:
+        if self.enable_megatron or self.enable_deepspeed:
             return
 
         # default hook weights
@@ -1124,17 +1237,22 @@ class TrainerMon:
         In each forward phase, FSDP re-registers hooks on AccumulateGrad, so hooks registered by the monitor tool do not take effect;
         therefore _post_backward_hook is patched so that gradients are collected after backward and before reduce_scatter.
         """
+
         def patch_post_backward_hook(_post_backward_hook):
             def wrapper(state, handle, *unused):
                 grad_dict = {}
-
-
-
-
+                local_names = handle.flat_param._fqns
+                offsets = handle._get_flat_param_offsets()
+                shapes = handle.flat_param._shapes
+                flat_prefix = next(self.flat_prefix_reverse_iter)
+                for local_name, (start, end), local_shape in zip(local_names, offsets, shapes):
+                    grad_clip = handle.flat_param.grad[start:end + 1]
+                    grad = grad_clip.reshape(local_shape)
+                    total_name = f"{flat_prefix}{MonitorConst.FSDP_FLAT_SEP}{local_name}"
+                    if total_name not in self.origin2squash:
+                        logger.warning(f"{total_name} not in model.named_parameters(), skip.")
                         continue
-
-                offset += limit
-                tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+                    tag = self.name2tag.get(self.origin2squash[total_name], {}).get(MonitorConst.PRE_GRAD)
                     if tag is None:
                         continue
                     grad_dict[tag] = grad
@@ -1150,6 +1268,28 @@ class TrainerMon:
         torch.distributed.fsdp._runtime_utils._post_backward_hook = \
             patch_post_backward_hook(torch.distributed.fsdp._runtime_utils._post_backward_hook)
 
+    def _patch_fsdp2_foreach_reduce(self):
+        def patch_foreach_reduce(foreach_reduce):
+            def wrapper(fsdp_params, unsharded_grads, *unused):
+                grad_dict = {}
+                for param, grad in zip(fsdp_params, unsharded_grads):
+                    tag = self.name2tag.get(self.origin2squash[param._param_fqn], {}).get(MonitorConst.PRE_GRAD)
+                    if tag is None:
+                        continue
+                    grad_dict[tag] = grad
+                    self.register_param_call_id("foreach_reduce", tag)
+                get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
+                out = foreach_reduce(fsdp_params, unsharded_grads, *unused)
+                return out
+            return wrapper
+
+        logger.info("Patch fsdp2 foreach_reduce, collect pre_grad metrics.")
+        import torch.distributed.fsdp._fully_shard._fsdp_param_group as _fsdp_param_group
+        import torch.distributed.fsdp._fully_shard._fsdp_collectives as _fsdp_collectives
+        self.fsdp2_foreach_reduce = _fsdp_collectives.foreach_reduce
+        _fsdp_collectives.foreach_reduce = patch_foreach_reduce(_fsdp_collectives.foreach_reduce)
+        importlib.reload(_fsdp_param_group)  # critical step; otherwise the patch fails because torch imported foreach_reduce at load time
+
     def _hook_weights(self):
         """
         Iterate over each parameter's gradient-generating function (grad_acc) and attach hooks so that, once all gradients of the parameter are computed, gradients are collected before communication aggregation.
```
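The `importlib.reload(_fsdp_param_group)` call, flagged as the critical step in the comment, is needed because `_fsdp_param_group` binds `foreach_reduce` via a `from ... import` at import time, so reassigning the attribute on `_fsdp_collectives` alone is not seen by the already-imported reference. A generic, self-contained illustration of that Python behaviour with two throwaway modules (the module names are hypothetical, not the torch internals):

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Two throwaway modules mimicking the torch layout: collectives.py defines foreach_reduce,
# and group.py binds it with `from collectives import foreach_reduce` at import time.
tmp = Path(tempfile.mkdtemp())
(tmp / "collectives.py").write_text("def foreach_reduce():\n    return 'original'\n")
(tmp / "group.py").write_text(
    "from collectives import foreach_reduce\n\n\ndef run():\n    return foreach_reduce()\n"
)
sys.path.insert(0, str(tmp))

import collectives
import group

collectives.foreach_reduce = lambda: "patched"  # patch only the defining module
print(group.run())                              # 'original': group still holds the old reference

importlib.reload(group)                         # rebind the from-import, as the diff reloads _fsdp_param_group
print(group.run())                              # 'patched'
```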
msprobe/pytorch/monitor/module_metric.py

```diff
@@ -17,6 +17,7 @@ import re
 import torch
 
 from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
+from msprobe.pytorch.monitor.features import cal_entropy, cal_stable_rank
 from msprobe.pytorch.monitor.utils import get_nan_tensor
 
 
@@ -31,7 +32,8 @@ def squash_param_name(param_name, enable=True):
     if not enable:
         return param_name
     name = ''
-    for pattern in ['layers
+    for pattern in ['^.*\.(layers?\..*)', '^.*\.(embeddings?\..*)', '^.*\.(final.*)', '^.*\.(output.*)',
+                    '^.*\.(norm.*)']:
         match = re.findall(pattern, param_name)
         if match:
             name += match[0]
```
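Each of the widened patterns keeps everything from the first matching component (`layer(s)`, `embedding(s)`, `final*`, `output*`, `norm*`) onward. A quick standalone check of the regex behaviour with a made-up Megatron-style parameter name (raw strings are used here for clarity; the real `squash_param_name` also applies further processing not shown in this hunk):

```python
import re

# Raw-string copies of the new patterns from the hunk above.
patterns = [r'^.*\.(layers?\..*)', r'^.*\.(embeddings?\..*)', r'^.*\.(final.*)', r'^.*\.(output.*)',
            r'^.*\.(norm.*)']

param_name = 'module.module.language_model.encoder.layers.0.self_attention.linear_qkv.weight'
for pattern in patterns:
    match = re.findall(pattern, param_name)
    if match:
        print(match[0])  # layers.0.self_attention.linear_qkv.weight
        break
```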
```diff
@@ -184,3 +186,27 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
             fun_metric = config_metric_registry.get(metric_name)
             out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
     return out_dict
+
+
+def get_sr_metric(tag2tensor, out_dict=None):
+    if out_dict is None:
+        out_dict = {}
+    for tag, tensor in tag2tensor.items():
+        if "sr" not in tag:
+            continue
+        if tag not in out_dict:
+            out_dict[tag] = {}
+        sr, eig = cal_stable_rank(tensor)
+        out_dict[tag]['sr'] = sr
+        out_dict[tag]['kernel_norm'] = eig
+
+
+def get_entropy_metric(tag2tensor, out_dict=None):
+    if out_dict is None:
+        out_dict = {}
+    for tag, tensor in tag2tensor.items():
+        if tag not in out_dict:
+            out_dict[tag] = {}
+        entropy, softmax_max = cal_entropy(tensor)
+        out_dict[tag]['entropy'] = entropy
+        out_dict[tag]['softmax_max'] = softmax_max
```