mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
  3. msprobe/README.md +27 -22
  4. msprobe/core/common/const.py +129 -60
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/inplace_ops.yaml +1 -0
  9. msprobe/core/common/utils.py +43 -33
  10. msprobe/core/compare/acc_compare.py +43 -74
  11. msprobe/core/compare/check.py +2 -6
  12. msprobe/core/compare/highlight.py +2 -0
  13. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  14. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  15. msprobe/core/compare/merge_result/merge_result.py +16 -9
  16. msprobe/core/compare/merge_result/utils.py +81 -0
  17. msprobe/core/compare/multiprocessing_compute.py +19 -12
  18. msprobe/core/compare/npy_compare.py +30 -12
  19. msprobe/core/compare/utils.py +30 -10
  20. msprobe/core/data_dump/api_registry.py +176 -0
  21. msprobe/core/data_dump/data_collector.py +58 -13
  22. msprobe/core/data_dump/data_processor/base.py +94 -10
  23. msprobe/core/data_dump/data_processor/factory.py +3 -0
  24. msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
  25. msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
  26. msprobe/core/data_dump/json_writer.py +61 -40
  27. msprobe/core/grad_probe/constant.py +1 -0
  28. msprobe/core/grad_probe/grad_compare.py +1 -1
  29. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  30. msprobe/docs/01.installation.md +27 -1
  31. msprobe/docs/02.config_introduction.md +27 -23
  32. msprobe/docs/03.config_examples.md +24 -0
  33. msprobe/docs/05.data_dump_PyTorch.md +103 -16
  34. msprobe/docs/06.data_dump_MindSpore.md +76 -32
  35. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  36. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  37. msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
  38. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  39. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  40. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  41. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  42. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  43. msprobe/docs/18.online_dispatch.md +1 -1
  44. msprobe/docs/19.monitor.md +332 -273
  45. msprobe/docs/21.visualization_PyTorch.md +42 -13
  46. msprobe/docs/22.visualization_MindSpore.md +43 -13
  47. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  48. msprobe/docs/27.dump_json_instruction.md +301 -27
  49. msprobe/docs/28.debugger_save_instruction.md +94 -0
  50. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  51. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  52. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  53. msprobe/docs/FAQ.md +3 -11
  54. msprobe/docs/img/compare_result.png +0 -0
  55. msprobe/docs/img/merge_result.png +0 -0
  56. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  57. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  58. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  59. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  60. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  61. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  63. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  64. msprobe/mindspore/__init__.py +4 -2
  65. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
  66. msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
  67. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  68. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  69. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  70. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  71. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  72. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  73. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
  74. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  75. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  76. msprobe/mindspore/common/const.py +61 -0
  77. msprobe/mindspore/common/utils.py +48 -18
  78. msprobe/mindspore/compare/ms_compare.py +27 -19
  79. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  80. msprobe/mindspore/debugger/debugger_config.py +31 -6
  81. msprobe/mindspore/debugger/precision_debugger.py +45 -14
  82. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  83. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  86. msprobe/mindspore/dump/jit_dump.py +21 -15
  87. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  88. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  89. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  90. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  91. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  92. msprobe/mindspore/grad_probe/global_context.py +2 -0
  93. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  94. msprobe/mindspore/grad_probe/hook.py +2 -4
  95. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  96. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  97. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  98. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  99. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  100. msprobe/mindspore/monitor/features.py +63 -0
  101. msprobe/mindspore/monitor/module_hook.py +873 -0
  102. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  103. msprobe/mindspore/monitor/utils.py +309 -0
  104. msprobe/mindspore/ms_config.py +8 -2
  105. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  106. msprobe/mindspore/service.py +114 -34
  107. msprobe/pytorch/__init__.py +0 -1
  108. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  109. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
  110. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  111. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  112. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  116. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  117. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  118. msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
  119. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
  120. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  121. msprobe/pytorch/common/utils.py +97 -4
  122. msprobe/pytorch/debugger/debugger_config.py +19 -9
  123. msprobe/pytorch/debugger/precision_debugger.py +24 -1
  124. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  125. msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
  126. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  127. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  132. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  133. msprobe/pytorch/function_factory.py +8 -2
  134. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  135. msprobe/pytorch/hook_module/api_register.py +131 -0
  136. msprobe/pytorch/hook_module/hook_module.py +19 -14
  137. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  138. msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
  139. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  140. msprobe/pytorch/monitor/csv2tb.py +18 -14
  141. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  142. msprobe/pytorch/monitor/module_hook.py +238 -193
  143. msprobe/pytorch/monitor/module_metric.py +9 -6
  144. msprobe/pytorch/monitor/optimizer_collect.py +100 -67
  145. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  146. msprobe/pytorch/monitor/utils.py +76 -44
  147. msprobe/pytorch/online_dispatch/compare.py +0 -2
  148. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  149. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  150. msprobe/pytorch/online_dispatch/utils.py +3 -0
  151. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  152. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  153. msprobe/pytorch/pt_config.py +30 -29
  154. msprobe/pytorch/service.py +114 -32
  155. msprobe/visualization/builder/graph_builder.py +75 -10
  156. msprobe/visualization/builder/msprobe_adapter.py +7 -6
  157. msprobe/visualization/compare/graph_comparator.py +42 -38
  158. msprobe/visualization/compare/mode_adapter.py +0 -19
  159. msprobe/visualization/graph/base_node.py +11 -3
  160. msprobe/visualization/graph/distributed_analyzer.py +71 -3
  161. msprobe/visualization/graph/graph.py +0 -11
  162. msprobe/visualization/graph/node_op.py +4 -3
  163. msprobe/visualization/graph_service.py +4 -5
  164. msprobe/visualization/utils.py +12 -35
  165. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
  166. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  167. msprobe/pytorch/hook_module/api_registry.py +0 -166
  168. msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
  169. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  171. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  172. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  173. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  174. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  175. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  176. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  177. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/mindspore/monitor/module_hook.py (new file)
@@ -0,0 +1,873 @@
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import re
+ import uuid
+ from collections import defaultdict
+ from datetime import datetime
+
+ import pytz
+ import mindspore as ms
+ from mindspore import Tensor, mint
+ from mindspore import nn, _no_grad
+ from mindspore.communication import get_rank
+
+ from msprobe.core.common.log import logger
+ from msprobe.core.common.const import MonitorConst
+ from msprobe.core.common.file_utils import load_json, save_json
+ from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \
+     is_skip_step, get_metrics, get_single_metrics, get_target_output_dir
+ from msprobe.mindspore.monitor.module_spec_verifier import validate_config_spec
+ from msprobe.mindspore.monitor.anomaly_detect import AnomalyScanner, AnomalyDataFactory, \
+     CSVWriterWithAD, BaseWriterWithAD, WriterInput
+ from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
+     get_process_group
+
+ FORMAT_MAPPING = {
+     MonitorConst.CSV: CSVWriterWithAD,
+     MonitorConst.API: BaseWriterWithAD
+ }
+
+
+ def get_output_base_dir():
+     return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
+
+
+ def get_param_struct(param):
+     res = {}
+     if isinstance(param, (tuple, list)):
+         res['config'] = f'{type(param).__name__}[{len(param)}]'
+         for i, x in enumerate(param):
+             res[i] = f'size={tuple(x.shape)}, dtype={x.dtype}' if isinstance(x, Tensor) else f'{type(x)}'
+     elif isinstance(param, Tensor):
+         res['config'] = 'tensor'
+         res['tensor'] = f'size={tuple(param.shape)}, dtype={param.dtype}'
+     else:
+         res['config'] = f'{type(param)}'
+         logger.warning(f'Not support type({type(param)}) now, please check the type of param {param}')
+     return res
+
+
+ def param_is_not_tensor_parallel_duplicate(param, tp_group):
+     return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
+         mint.distributed.get_rank(group=tp_group) == 0
+     )
+
+
+ def param_is_data_parallel_duplicate(dp_group):
+     return mint.distributed.get_rank(group=dp_group) != 0
+
+
+ def squash_param_name(param_name):
+     for pattern in ['layers?\.(.*)', 'embeddings?\.(.*)', 'final.*', 'output.*', 'norm.*']:
+         match = re.findall(pattern, param_name)
+         if match:
+             return match[0]
+     return param_name
+
+
+ # Used For Module Forward & Backward Collect
+ class ModuleHookContext:
+     def __init__(self, module_name) -> None:
+         self.step = 0
+         self.micro_step = 0
+         self.actv = defaultdict(dict)
+         self.actvgrad = []
+         self.module_name = module_name
+         self.struct = {}
+         self.format_by_arg = {}
+         self.verified = False
+         self.focused_in_col = 0
+         self.focused_out_col = 0
+         self.ignore_in = False  # no need to care when no key 'input' or 'input_grad' found
+
+     def set_format_by_arg(self, key_name: str, target_config: dict):
+         cared = target_config.get(self.module_name, self.struct)
+         if key_name in cared:
+             if isinstance(cared[key_name], dict):
+                 # current cared is self.struct
+                 config = cared[key_name].get('config')
+                 self.format_by_arg[key_name] = config
+             else:
+                 # current cared is target_config[self.module_name]
+                 self.format_by_arg[key_name] = cared[key_name]
+         elif key_name in ['input', 'input_grad']:
+             self.ignore_in = True
+
+     def reset(self):
+         self.actv.clear()
+         self.actvgrad.clear()
+
+
+ start_step = 0
+
+
+ # Used For Optimizer Weight Grad & M/V Collect
+ class OptimizerContext:
+     def __init__(self) -> None:
+         self.step = start_step
+         self.param_mg_direction = defaultdict(float)
+         self.param_adam_update = defaultdict()
+         self.param_adam_ratio = defaultdict()
+         self.param_weight_grad = defaultdict()
+         self.param_exp_avg = defaultdict()
+         self.exp_avg_metric = {}
+         self.param_exp_avg_sq = defaultdict()
+         self.exp_avg_sq_metric = {}
+         self.metric_dict = {}
+         self.param_metric = {}
+
+     def reset(self) -> None:
+         self.param_mg_direction.clear()
+         self.param_adam_update.clear()
+         self.param_adam_ratio.clear()
+         self.param_weight_grad.clear()
+         self.param_exp_avg.clear()
+         self.exp_avg_metric.clear()
+         self.param_exp_avg_sq.clear()
+         self.exp_avg_sq_metric.clear()
+         self.metric_dict.clear()
+         self.param_metric.clear()
+
+
+ # Used For Weight Grad Collect
+ class GradContext:
+     def __init__(self) -> None:
+         self.pre = {}
+         self.post = {}
+         self.acc_metric = {}
+         self.acc = {}
+         self.actv = {}
+
+     def reset(self):
+         self.pre.clear()
+         self.post.clear()
+         self.acc_metric.clear()
+         self.acc.clear()
+         self.actv.clear()
+
+
+ class CommunicationContext:
+     def __init__(self) -> None:
+         self.data = {}
+
+     @staticmethod
+     def _agg(data):
+         aggregated_data = {}
+         for tag, op2tensorlist in data.items():
+             aggregated_data[tag] = {}
+             for op, tensorlist in op2tensorlist.items():
+                 aggregated_data[tag][op] = op_aggregate(op, tensorlist)
+         return aggregated_data
+
+     def reset(self):
+         self.data = {}
+
+     def aggregate(self):
+         self.data = self._agg(self.data)
+
+
+ class TrainerMon:
+     def __init__(self, config_file_path, process_group=None, params_have_main_grad=True) -> None:
+         # TYPE1: variables initialized only here; they are not reset when the config changes during training
+         self.config_file_path = config_file_path
+         self.process_group = process_group
+         self.params_have_main_grad = params_have_main_grad
+         self.config_timestamp = 0  # the timestamp is validated later; the first run does not need to touch the config file just to refresh it, monitoring can be enabled directly via the dynamic_on switch
+         self.config = load_json(config_file_path)
+         validate_config(self.config)
+
+         local_tz = pytz.timezone("Asia/Shanghai")  # adjust to the target timezone if needed
+         cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S')
+         self.unique_id = str(uuid.uuid4())[:8]
+         self.output_base_dir = get_output_base_dir()
+         time_tags = self.config.get("append_output", [])
+         try:
+             self.rank = get_rank()
+             if time_tags:
+                 output_append_dirs = get_target_output_dir(self.output_base_dir, time_tags[0], time_tags[1])
+                 if str(self.rank) in output_append_dirs:
+                     self.tensorboard_dir = output_append_dirs[str(self.rank)]
+                     logger.info(f"Append rank({self.rank}) result to {self.tensorboard_dir}")
+             else:
+                 self.tensorboard_dir = os.path.join(self.output_base_dir,
+                                                     f"{cur_time}-rank{self.rank}-{self.unique_id}")
+         except Exception as e:
+             self.rank = 0
+             self.tensorboard_dir = os.path.join(self.output_base_dir, f"{cur_time}-rank{self.rank}-{self.unique_id}")
+
+         self.pp_stage = 0
+         self.group_mates = [0]
+
+         # TYPE2: variables assigned only from the set_monitor() entry call
+         self.model = None
+         self.vpp = False
+         self.dp_group = None
+         self.tp_group = None
+         self.micro_batch_number = 1
+
+         # TYPE3: variables reset when the config is updated mid-training or the monitoring state changes
+         self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
+         self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
+         self.optimizer_context = defaultdict(OptimizerContext)
+         self.cc_context = defaultdict(CommunicationContext)
+         self.grad_context = GradContext()
+         self.handles = defaultdict(list)
+         self.param2name = defaultdict(str)
+         self.name2index = defaultdict()
+         self.name2indices = defaultdict()
+         self.name2param = {}
+         self.duplicate_param = {}
+         self.name2tag = {}
+         self.param_name_call_id = {}
+         self.call_id = 0
+         self.module_struct = defaultdict(dict)
+         self.grad_accs = []
+         self.weight_hooked = False
+         self.optimizer_hooked = False
+         self.param_registered = False
+         self.struct_printed = False
+
+         # distinguish dynamic and static monitoring
+         self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
+         if self.dynamic_enable:
+             logger.warning(f"DYNAMIC_MONITOR is set, "
+                            f"please make sure you have 'dynamic_on' and 'collect_times' in {self.config_file_path}")
+             self.monitoring = False
+         else:
+             self.set_config()
+             # in static mode with collect_times > 0, self.monitoring can already be True at step 0; dynamic mode starts at the next step by default
+             if self.collect_times > 0:
+                 self.monitoring = True
+
+     def set_config(self):
+         self.start_step = self.config.get("start_step", 0)
+         self.collect_times = self.config.get("collect_times", 100000000)  # defaults to a large value so collection never stops
+         self.step_interval = self.config.get("step_interval", 1)
+         self.has_collect_times = 0  # reset the collection counter
+         self.print_struct = self.config.get("print_struct", False)
+         self.targets = self.config.get("targets", None)
+         self.is_select = self.config.get("is_select", False)
+         self.module_rank_list = self.config.get("module_ranks", [])
+         self.format = self.config.get('format', MonitorConst.CSV)  # only csv supported in mindspore
+         self.eps = self.config.get('eps', 1e-8)
+         self.ops = self.config.get('ops', [])  # monitor mean/max/norm/min/nan...
+         self.ndigits = self.config.get('ndigits', 6)
+         self.all_xy = self.config.get('all_xy', False)
+         self.xy_distribution = self.config.get('xy_distribution', False)
+         self.forward_only = self.config.get('forward_only', False)
+         self.backward_only = self.config.get('backward_only', False)
+         self.ur_distribution = self.config.get('ur_distribution', False)  # vector and ratio vector of adam
+         self.mv_distribution = self.config.get("mv_distribution", False)  # m/v of adam
+         self.wg_distribution = self.config.get("wg_distribution", False)
+         self.param_distribution = self.config.get("param_distribution", False)
+         self.mg_direction = self.config.get('mg_direction', False)  # main grad direction
+         self.cc_distribution = self.config.get("cc_distribution", {})  # communication ops
+         if not self.cc_distribution.get('enable', False):
+             self.cc_log_only = False
+         else:
+             self.cc_codeline = self.cc_distribution.get('cc_codeline', [])
+             self.cc_log_only = self.cc_distribution.get('cc_log_only', False)
+             self.cc_logged_stack = defaultdict(set)
+             self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False)
+         self.common_info()
+
+         # initialize the AnomalyData factory
+         alert_setting = self.config.get('alert', {"rules": []})
+         self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])
+         self.anomaly_data_factory = None
+         if alert_setting.get('dump', False):
+             self.anomaly_data_factory = AnomalyDataFactory(self.rank, self.pp_stage, self.group_mates)
+
+         # initialize the writer and create the output directory
+         if self.format not in FORMAT_MAPPING:
+             logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}")
+             self.format = MonitorConst.CSV
+         writer = FORMAT_MAPPING[self.format]
+         self.step_count_per_record = self.config.get('step_count_per_record', 1)
+         self.summary_writer = writer(
+             WriterInput(
+                 self.tensorboard_dir,
+                 self.alert_rules,
+                 self.unique_id,
+                 self.anomaly_data_factory,
+                 self.ndigits,
+                 self.step_count_per_record
+             )
+         )
+
+     def common_info(self):
+         if not self.xy_distribution:
+             logger.info("> module input/output input_grad/output_grad is not monitored. ")
+         if self.forward_only:
+             logger.info("> only module forward is monitored. ")
+         if not self.ur_distribution:
+             logger.info("> update vector and ratio vector of adam is not monitored. ")
+         if not self.mv_distribution:
+             logger.info("> momentum and variance of adam is not monitored. ")
+         if not self.wg_distribution:
+             logger.info("> weight grad of specified module is not monitored. ")
+         if not self.mg_direction:
+             logger.info('> grad and momentum direction will not be compared.')
+         if not self.cc_distribution.get('enable', False):
+             logger.info("> cc operator is not monitored.")
+
+     def set_monitor(
+             self,
+             model,
+             optimizer,
+             grad_acc_steps=1,
+             tp_group=None,
+             dp_group=None,
+             start_iteration=0
+     ):
+         global start_step
+         start_step = start_iteration
+         self.micro_batch_number = grad_acc_steps
+         self.dp_group = dp_group
+         self.tp_group = tp_group
+         self.hook_step_final(optimizer)
+         if not isinstance(model, list):
+             model = [model]
+         self.model = model
+         if len(model) > 1:
+             self.vpp = True
+             logger.info('vpp enabled')
+         if not self.dynamic_enable:
+             self.register_hooks(optimizer)
+
+     def hook_step_final(self, optimizer):
+         def step_final_hook(optimizer, *args, **kwargs):
+             context = self.optimizer_context[optimizer]
+             # static mode can already save at step 0; dynamic mode cannot, because it is designed to start at the step after a reset, so self.monitoring is still False at step 0
+             if self.monitoring:
+                 module_rank_valid = self.is_target_rank()
+                 step_condition = (context.step >= self.start_step and (
+                         context.step - self.start_step) % self.step_interval == 0)
+                 if module_rank_valid and step_condition:
+                     self.has_collect_times += 1
+                     self.write_xy_tb(context.step)
+                     self.write_grad_tb(context.step)
+                     self.write_mv_tb(context)
+                     self.write_param_tb(context)
+
+                     if context.metric_dict:
+                         self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other')
+                     context.metric_dict.clear()
+
+                     self.summary_writer.clear_anomalies()
+                     self.call_id = 0
+                     self.param_name_call_id.clear()
+
+                     if self.has_collect_times >= self.collect_times:
+                         self._remove_all_hooks_final(optimizer)
+
+             context.step += 1
+             self.dynamic_monitor(optimizer)
+
+         optimizer.register_forward_hook(step_final_hook)
+         return
+
+     def dynamic_monitor(self, optimizer):
+         """
+         If dynamic monitor enabled and config.json updated,
+         remove hooks and register new hooks according to new configuration.
+         """
+         context = self.optimizer_context[optimizer]
+         if not self.dynamic_enable:
+             return
+         try:
+             # if the file timestamp has not changed, skip reading to save time
+             config_timestamp = os.path.getmtime(self.config_file_path)
+             if config_timestamp == self.config_timestamp:
+                 return
+             # record the latest modification timestamp of the config file
+             self.config_timestamp = config_timestamp
+             config = load_json(self.config_file_path)
+         except Exception as e:
+             logger.error(f"get config.json wrong because {e}, not updated, please check!!!")
+             return
+
+         if config.get("dynamic_on", False):
+             try:
+                 validate_config(config)
+                 self.config = config
+                 self.set_config()
+                 self.start_step = context.step  # dynamic start/stop ignores the original start_step and always starts from the next step
+                 logger.warning(f"config is updated at step{context.step - 1}, "
+                                f"will start new hook at step{context.step}.")
+             except Exception as e:
+                 logger.error(f"set config wrong because {e}, not updated, please check!!!")
+                 return
+
+             self._remove_all_hooks()
+             self.register_hooks(optimizer)
+
+     def register_hooks(self, optimizer):
+         self._register_param_name()
+         self.hook_modules()
+         self.hook_optimizer(optimizer)
+         self._patch_grad_sync()
+         if self.cc_distribution.get('enable', False):
+             self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self))
+             api_register.redirect_api()
+         self.monitoring = True
+
+     def hook_modules(self):
+         if not self.is_target_rank():
+             return
+         module_in_all_stage = [key for key in self.targets.keys() if MonitorConst.NAME_SEP not in key]
+
+         for key in module_in_all_stage:
+             struct = self.targets.pop(key)
+             self.targets.update(
+                 {f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))})
+
+         hooked_count = 0
+         for vpp_stage, model_chunk in enumerate(self.model):
+             if not isinstance(model_chunk, nn.Cell):
+                 logger.info("Target Model is not Cell")
+                 continue
+             vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}'
+             targets = [x for x, _ in model_chunk.cells_and_names()] if self.print_struct else self.targets.keys()
+             hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
+         logger.info(f"> {hooked_count} modules are monitored.")
+
+     def hook_optimizer(self, optimizer):
+         def optimizer_pre_hook_function(opt, grad_names, gradients):
+             context = self.optimizer_context[opt]
+             if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
+                             self.collect_times):
+                 return
+             gradient_list = gradients[0] if isinstance(gradients, tuple) else gradients
+             is_select = self.is_select
+             for idx, grad in enumerate(gradient_list):
+                 grad_name = grad_names[idx]
+                 if is_select and grad_name not in self.targets:
+                     continue
+                 get_single_metrics(self.ops, grad_name, grad, context.param_weight_grad)
+
+             if self.mv_distribution:
+                 # fetch mean
+                 for param in m_list:
+                     name = param.name
+                     if is_select and name not in self.targets:
+                         continue
+                     get_single_metrics(self.ops, name, param, context.exp_avg_metric)
+                 # fetch variance
+                 for param in v_list:
+                     name = param.name
+                     if is_select and name not in self.targets:
+                         continue
+                     get_single_metrics(self.ops, name, param, context.exp_avg_sq_metric)
+             if self.param_distribution:
+                 for param in param_list:
+                     get_single_metrics(self.ops, param.name, param, context.param_metric)
+             self.generate_wgrad_metrics()
+             metric_dict = {}
+             for cc in self.cc_context.values():
+                 cc.aggregate()
+                 metric_dict.update(cc.data)
+                 cc.reset()
+
+             if not metric_dict:
+                 return
+             context.metric_dict = metric_dict
+             return
+
+         def optimizer_pre_hook_wrapper(func, grad_names):
+             def wrapper(opt, gradients):
+                 return func(opt, grad_names, gradients)
+             return wrapper
+
+         if self.optimizer_hooked or not self.is_target_rank():
+             return
+
+         m_list = []
+         v_list = []
+         param_list = []
+         grad_names = []
+         for param in optimizer.get_parameters():
+             if MonitorConst.EXP_AVG_SQ in param.name:
+                 v_list.append(param)
+             elif MonitorConst.EXP_AVG in param.name:
+                 m_list.append(param)
+             elif param.name in ['global_step', 'learning_rate']:
+                 pass
+             else:
+                 param_list.append(param)
+                 grad_names.append(param.name)
+
+         handle = optimizer.register_forward_pre_hook(
+             optimizer_pre_hook_wrapper(optimizer_pre_hook_function, grad_names))
+         self.handles['optimizer'].append(handle)
+         self.optimizer_hooked = True
+         return
+
+     def generate_wgrad_metrics(self):
+         if not self.wg_distribution:
+             return {}, {}
+
+         if self.weight_hooked:
+             try:
+                 get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
+             except Exception as e:
+                 logger.warning(f"An error occurred while generating wgrad pre metrics")
+                 return {}, {}
+
+         grad_dict = {}
+         for param, name in self.param2name.items():
+             if self.duplicate_param.get(name, False):
+                 continue
+             grad = param.main_grad if self.params_have_main_grad else param.grad
+             if grad is None:
+                 logger.warning(f"grad is None: {name}, maybe something wrong happened.")
+                 continue
+             tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+             self._register_param_call_id("hook_optimizer", tag)
+             grad_dict[tag] = grad
+         try:
+             get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
+         except Exception as e:
+             logger.warning(f"An error occurred while generating wgrad post metrics")
+             return {}, {}
+         return self.grad_context.post, self.grad_context.pre
+
+     def write_xy_tb(self, step):
+         if not self.xy_distribution:
+             return
+         for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+             if len(fwd_context.actv) == 0:
+                 continue
+             self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, 'actv')
+             fwd_context.actv.clear()
+         if self.grad_context.actv:
+             self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, 'actv_grad')
+
+     def write_param_tb(self, opt_context):
+         if not self.param_distribution:
+             return
+         self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, 'param')
+
+     def write_mv_tb(self, opt_context):
+         if not self.mv_distribution:
+             return
+         self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, 'exp_avg')
+         self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq')
+
+     def write_grad_tb(self, step):
+         if not self.wg_distribution:
+             return
+
+         self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced')
+         self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
+
+     def is_target_rank(self):
+         if self.module_rank_list and (self.rank not in self.module_rank_list):
+             return False
+         return True
+
+     def build_tbtag_tensor_map(self, module_name, tag, tensor):
+         metrics = {}
+         key = get_summary_writer_tag_name(module_name, tag, str(self.rank))
+         if isinstance(tensor, Tensor):
+             self._register_param_call_id("_hook_module", key)
+             metrics[key] = tensor
+         return metrics
+
+     def _register_param_name(self):
+         for vpp_stage, model_chunk in enumerate(self.model):
+             prefix = f'{vpp_stage}{MonitorConst.NAME_SEP}'
+             self._register_chunk(model_chunk, prefix)
+
+     def _register_chunk(self, model_chunk, prefix):
+         index = 0
+         for param in model_chunk.get_parameters():
+             param_name = param.name
+             if not param.requires_grad:
+                 continue
+             if self._is_target_param(param_name, param, prefix):
+                 name = prefix + squash_param_name(param_name)
+                 if name in self.param2name.values():
+                     name = prefix + param_name
+                 self.param2name[param] = name
+                 self.name2param[name] = param
+                 self.name2index[name] = index
+
+                 if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group):
+                     self.duplicate_param[name] = True
+                 if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
+                     self.duplicate_param[name] = True
+                 self.name2tag[name] = {
+                     MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank),
+                     MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank)
+                 }
+             index += 1
+
+     def _hook_module(self, target_names, module, vpp_stage=''):
+         if not isinstance(module, nn.Cell):
+             # nothing to hook
+             return 0
+
+         def fwd_hook_fun(module, module_input, module_output, name):
+             if module not in self.module_fwd_hook_context_by_module:
+                 self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
+             context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
+             if not context.struct:
+                 context.struct = {
+                     MonitorConst.ACTV_IN: get_param_struct(module_input),
+                     MonitorConst.ACTV_OUT: get_param_struct(module_output)
+                 }
+             if self.print_struct:
+                 self.module_struct[context.module_name].update(context.struct)
+                 return
+             if not module.training:
+                 return
+             if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
+                             self.collect_times):
+                 step_accumulates_one(context, self.micro_batch_number)
+                 return
+             if not context.format_by_arg:
+                 context.set_format_by_arg(MonitorConst.ACTV_IN, self.targets)
+                 context.set_format_by_arg(MonitorConst.ACTV_OUT, self.targets)
+             if not context.format_by_arg:
+                 return
+             if not context.verified:
+                 if not context.ignore_in:
+                     context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN],
+                                                                   module_input, context.module_name,
+                                                                   MonitorConst.ACTV_IN)
+                 context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT],
+                                                                module_output, context.module_name,
+                                                                MonitorConst.ACTV_OUT)
+                 context.verified = True
+
+             tbtag_tensor_map = {}
+             if not context.ignore_in:
+                 cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
+                 tbtag_tensor_map.update(
+                     self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN,
+                                                 cared_input))
+             cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
+             tbtag_tensor_map.update(
+                 self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT,
+                                             cared_output))
+             try:
+                 get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
+             except Exception as e:
+                 logger.warning(f"An error occurred while generating forward activation metrics")
+
+             step_accumulates_one(context, self.micro_batch_number)
+             return
+
+         def bwd_hook_fun(module, input_grad, output_grad):
+             context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
+             if not context.struct:
+                 context.struct = {
+                     MonitorConst.ACTVGRAD_IN: get_param_struct(input_grad),
+                     MonitorConst.ACTVGRAD_OUT: get_param_struct(output_grad)
+                 }
+             if self.print_struct:
+                 self.module_struct[context.module_name].update(context.struct)
+                 return
+
+             if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
+                             self.collect_times):
+                 step_accumulates_one(context, self.micro_batch_number)
+                 return
+
+             if not context.format_by_arg:
+                 context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.targets)
+                 context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.targets)
+             if not context.format_by_arg:
+                 return
+             if not context.verified:
+                 if not context.ignore_in:
+                     context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN],
+                                                                   input_grad, context.module_name,
+                                                                   MonitorConst.ACTVGRAD_IN)
+                 context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT],
+                                                                output_grad, context.module_name,
+                                                                MonitorConst.ACTVGRAD_OUT)
+                 context.verified = True
+
+             tbtag_tensor_map = {}
+             if not context.ignore_in:
+                 cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
+                 tbtag_tensor_map.update(
+                     self.build_tbtag_tensor_map(
+                         f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad))
+             cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
+             tbtag_tensor_map.update(
+                 self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT,
+                                             cared_output_grad))
+
+             if context.micro_step == 0 and context.actvgrad:
+                 logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, "
+                                f"maybe something wrong happened. Now clear it.")
+                 context.actvgrad.clear()
+             try:
+                 get_metrics(self.ops, tbtag_tensor_map, self.eps, self.grad_context.actv)
+             except Exception as e:
+                 logger.warning(f"An error occurred while generating backward activation metrics: {e}")
+
+             step_accumulates_one(context, self.micro_batch_number)
+             return
+
+         def fwd_hook_fun_wrapper(fwd_hook_fun, name):
+             def wrapper(module, module_input, module_output):
+                 return fwd_hook_fun(module, module_input, module_output, name)
+             return wrapper
+
+         if self.backward_only and self.forward_only:
+             logger.warning('not enable backward_only and forward_only simultaneously')
+         hooked_count = 0
+         if self.xy_distribution or self.print_struct:
+             for module_name, submodule in module.cells_and_names():
+                 name = self._is_target_module(module_name, target_names, vpp_stage)
+                 if not name:
+                     continue
+                 if not self.backward_only:
+                     handle = submodule.register_forward_hook(fwd_hook_fun_wrapper(fwd_hook_fun, name=name))
+                     self.handles['xy'].append(handle)
+                 if not self.forward_only:
+                     handle = submodule.register_backward_hook(bwd_hook_fun)
+                     self.handles['xy'].append(handle)
+                     self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
+                 logger.info(f"> {name} is monitored successfully")
+                 hooked_count += 1
+         return hooked_count
+
+     def _patch_grad_sync(self):
+         if not self.wg_distribution:
+             return
+         self._hook_weights()
+
+     def _hook_weights(self):
+         context = self.grad_context
+
+         @_no_grad()
+         def param_hook(grad, context_dict, param, key):
+             param.micro_step += 1
+             self._register_param_call_id("param_hook", key)
+             if param.micro_step == self.micro_batch_number:
+                 param.micro_step = 0
+                 context_dict[key] = grad
+
+         def param_hook_wrapper(param_hook, context_dict, param, key):
+             def wrapper(grad):
+                 return param_hook(grad, context_dict, param, key)
+             return wrapper
+
+         for param, name in self.param2name.items():
+             key = get_summary_writer_tag_name(name, 'acc_grad', self.rank)
+             setattr(param, 'micro_step', 0)
+             handle = param.register_hook(param_hook_wrapper(param_hook, context_dict=context.acc, param=param, key=key))
+             self.handles['wgrads'].append(handle)
+         self.weight_hooked = True
+
+     def _is_target_param(self, param_name, param, prefix):
+         if not self.targets:
+             return True
+         squash_name = prefix + squash_param_name(param_name)
+         name = prefix + param_name
+         for target in self.targets.keys():
+             if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target):
+                 setattr(param, "zero_out_wgrad", True)
+                 return True
+         return False
+
+     def _is_target_module(self, module_name, targets, vpp_stage):
+         if self.all_xy or self.print_struct:
+             return vpp_stage + squash_param_name(module_name)
+         for pattern in [
+             vpp_stage + squash_param_name(module_name),
+             vpp_stage + module_name,
+         ]:
+             if pattern in targets:
+                 return pattern
+         return ""
+
+     def _register_param_call_id(self, hook_name: str, key: str):
+         """
+         :param hook_name:
+         :param key: str, '0:relu_0/output_grad'
+         :return:
+         """
+         logger.debug(f"{hook_name} {key}: {self.call_id}")
+         self.param_name_call_id[key] = self.call_id
+         self.call_id += 1
+
+     def _remove_all_hooks(self):
+         # remove hook handles
+         for handle in self.handles['xy']:
+             handle.remove()
+         self.handles['xy'].clear()
+         # clear the corresponding context caches
+         for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+             fwd_context.reset()
+         for _, bwd_context in self.module_bwd_hook_context_by_module.items():
+             bwd_context.reset()
+         self.grad_context.reset()  # both weight grads and activation grads live here
+
+         for handle in self.handles['wgrads']:
+             handle.remove()
+         self.handles['wgrads'].clear()
+         self.weight_hooked = False
+
+         if self.optimizer_hooked:
+             for handle in self.handles['optimizer']:
+                 handle.remove()
+             self.handles['optimizer'].clear()
+         for _, context in self.optimizer_context.items():
+             context.reset()
+         self.optimizer_hooked = False
+
+         for handle in self.handles['cc']:
+             handle.remove()
+         self.handles['cc'].clear()
+         api_register.restore_api()
+         for _, context in self.cc_context.items():
+             context.reset()
+
+         # clear node caches
+         self.param2name.clear()
+         self.name2index.clear()
+         self.name2indices.clear()
+         self.name2param.clear()
+         self.duplicate_param.clear()
+         self.name2tag.clear()
+         self.module_struct.clear()
+         self.grad_accs.clear()
+
+         # turn off collection
+         self.monitoring = False
+
+     def _remove_all_hooks_final(self, optimizer):
+         if self.dynamic_enable:
+             # after finishing, automatically reset dynamic_on to False and wait for the user to re-enable it
+             try:
+                 config = load_json(self.config_file_path)
+                 config['dynamic_on'] = False
+                 save_json(self.config_file_path, config, indent=2)
+                 config_timestamp = os.path.getmtime(self.config_file_path)
+                 self.config_timestamp = config_timestamp
+                 logger.info(
+                     "Finish monitor, set config'dynamic_on=False, will restart by set it to True and update config")
+             except Exception as e:
+                 logger.warning(f"Finish monitor, set config'dynamic_on=False fail because {e}, please check!!!")
+         logger.info("Finish monitor")
+         self._remove_all_hooks()
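
For readers who want to try the new MindSpore monitor, the following is a minimal sketch of a monitor config, assembled only from the keys that `TrainerMon.set_config` above reads via `self.config.get(...)`. It is not an official template: the concrete values, the target module name, and the `"csv"` literal for `format` are illustrative assumptions, and any key not read by the code should not be expected to work.

```python
# Hypothetical monitor config sketch; keys mirror TrainerMon.set_config above,
# values are placeholders and must be adapted to the actual model.
import json

monitor_config = {
    "start_step": 0,              # first step at which metrics are collected
    "collect_times": 10,          # how many collections before hooks are removed
    "step_interval": 1,           # collect every N optimizer steps
    "targets": {"network.layers.0": {}},  # assumed module name, adjust to your model
    "ops": ["min", "max", "mean", "norm"],
    "format": "csv",              # per the inline comment, only csv is supported on MindSpore
    "ndigits": 6,
    "xy_distribution": True,      # module input/output and their gradients
    "wg_distribution": True,      # weight gradients
    "mv_distribution": False,     # Adam m/v state
    "param_distribution": False,  # parameter values
    "cc_distribution": {"enable": False},
    "alert": {"rules": []},
    "dynamic_on": False,          # only consulted when DYNAMIC_MONITOR=true
}

with open("monitor_config.json", "w") as f:
    json.dump(monitor_config, f, indent=2)
```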
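And a minimal call sequence for wiring the monitor into a training script, again only a sketch based on the `TrainerMon.__init__` and `set_monitor` signatures shown in this diff; the tiny `nn.Dense` network, the `nn.Adam` optimizer, and the config file name are placeholders, not documented usage.

```python
# Usage sketch for the new MindSpore TrainerMon, based on the API in the diff above.
from mindspore import nn

from msprobe.mindspore.monitor.module_hook import TrainerMon

# Placeholder model and optimizer so the sketch is self-contained.
network = nn.Dense(4, 2)
optimizer = nn.Adam(network.trainable_params(), learning_rate=1e-3)

# The config file is assumed to be the one written in the previous sketch.
monitor = TrainerMon("monitor_config.json", params_have_main_grad=False)
monitor.set_monitor(
    model=network,        # a Cell, or a list of Cells (a list enables the vpp handling)
    optimizer=optimizer,  # hooks are attached via register_forward_pre_hook / register_forward_hook
    grad_acc_steps=1,     # micro-batch count used to decide when a training step is complete
    tp_group=None,        # optional tensor-parallel group for duplicate filtering
    dp_group=None,        # optional data-parallel group
    start_iteration=0,
)

# Train as usual afterwards; results land under get_output_base_dir()
# (the MONITOR_OUTPUT_DIR environment variable or the package default).
```

Setting `DYNAMIC_MONITOR=true` before constructing `TrainerMon` switches to the dynamic mode described in the code, where hooks are (re)registered only after `dynamic_on` is set to true in the config file.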