mindstudio-probe 8.1.1__py3-none-any.whl → 8.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/RECORD +95 -94
  3. msprobe/core/common/const.py +3 -0
  4. msprobe/core/common/file_utils.py +45 -5
  5. msprobe/core/common/utils.py +117 -13
  6. msprobe/core/common_config.py +15 -1
  7. msprobe/core/compare/acc_compare.py +21 -9
  8. msprobe/core/compare/compare_cli.py +10 -2
  9. msprobe/core/compare/merge_result/merge_result.py +1 -1
  10. msprobe/core/compare/utils.py +8 -2
  11. msprobe/core/config_check/checkers/base_checker.py +2 -0
  12. msprobe/core/config_check/checkers/hyperparameter_checker.py +5 -4
  13. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +4 -1
  14. msprobe/core/config_check/config_check_cli.py +1 -1
  15. msprobe/core/config_check/config_checker.py +1 -2
  16. msprobe/core/data_dump/data_collector.py +4 -1
  17. msprobe/core/data_dump/data_processor/mindspore_processor.py +23 -1
  18. msprobe/core/data_dump/data_processor/pytorch_processor.py +3 -25
  19. msprobe/core/debugger/precision_debugger.py +13 -8
  20. msprobe/core/hook_manager.py +112 -82
  21. msprobe/core/monitor/utils.py +338 -0
  22. msprobe/core/service.py +2 -1
  23. msprobe/core/single_save/single_comparator.py +5 -3
  24. msprobe/docs/01.installation.md +1 -0
  25. msprobe/docs/05.data_dump_PyTorch.md +4 -4
  26. msprobe/docs/07.accuracy_checker_PyTorch.md +14 -11
  27. msprobe/docs/09.accuracy_checker_MindSpore.md +13 -11
  28. msprobe/docs/10.accuracy_compare_PyTorch.md +3 -1
  29. msprobe/docs/11.accuracy_compare_MindSpore.md +4 -2
  30. msprobe/docs/12.overflow_check_PyTorch.md +3 -2
  31. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  32. msprobe/docs/14.data_parse_PyTorch.md +35 -32
  33. msprobe/docs/21.visualization_PyTorch.md +9 -8
  34. msprobe/docs/22.visualization_MindSpore.md +1 -0
  35. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  36. msprobe/docs/24.code_mapping_Mindspore.md +6 -5
  37. msprobe/docs/31.config_check.md +15 -5
  38. msprobe/docs/33.generate_operator_MindSpore.md +2 -2
  39. msprobe/docs/34.RL_collect.md +18 -9
  40. msprobe/docs/35.nan_analyze.md +4 -3
  41. msprobe/docs/FAQ.md +3 -0
  42. msprobe/docs/img/ms_layer.png +0 -0
  43. msprobe/mindspore/api_accuracy_checker/api_runner.py +29 -1
  44. msprobe/mindspore/cell_processor.py +35 -14
  45. msprobe/mindspore/code_mapping/bind.py +23 -4
  46. msprobe/mindspore/code_mapping/graph_parser.py +6 -4
  47. msprobe/mindspore/common/utils.py +3 -0
  48. msprobe/mindspore/compare/common_dir_compare.py +32 -12
  49. msprobe/mindspore/compare/ms_graph_compare.py +7 -2
  50. msprobe/mindspore/compare/utils.py +9 -1
  51. msprobe/mindspore/debugger/debugger_config.py +13 -11
  52. msprobe/mindspore/debugger/precision_debugger.py +67 -45
  53. msprobe/mindspore/dump/dump_tool_factory.py +2 -0
  54. msprobe/mindspore/dump/hook_cell/hook_cell.py +14 -9
  55. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +12 -7
  56. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +27 -13
  57. msprobe/mindspore/dump/jit_dump.py +6 -3
  58. msprobe/mindspore/dump/kernel_kbyk_dump.py +13 -6
  59. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +6 -5
  60. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +2 -2
  61. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -0
  62. msprobe/mindspore/mindspore_service.py +2 -2
  63. msprobe/mindspore/monitor/common_func.py +1 -1
  64. msprobe/mindspore/monitor/module_hook.py +3 -3
  65. msprobe/mindspore/monitor/utils.py +0 -252
  66. msprobe/mindspore/ms_config.py +0 -1
  67. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  68. msprobe/nan_analyze/graph.py +4 -0
  69. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +15 -6
  70. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +1 -1
  71. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +1 -1
  72. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -4
  73. msprobe/pytorch/common/utils.py +0 -16
  74. msprobe/pytorch/compare/pt_compare.py +5 -0
  75. msprobe/pytorch/debugger/debugger_config.py +12 -5
  76. msprobe/pytorch/debugger/precision_debugger.py +8 -1
  77. msprobe/pytorch/dump/module_dump/hook_wrapper.py +1 -3
  78. msprobe/pytorch/dump/module_dump/module_processer.py +44 -13
  79. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +2 -0
  80. msprobe/pytorch/hook_module/hook_module.py +9 -9
  81. msprobe/pytorch/hook_module/pt_hook_manager.py +7 -7
  82. msprobe/pytorch/monitor/csv2tb.py +3 -10
  83. msprobe/pytorch/monitor/features.py +5 -0
  84. msprobe/pytorch/monitor/module_hook.py +6 -7
  85. msprobe/pytorch/monitor/module_metric.py +0 -3
  86. msprobe/pytorch/monitor/optimizer_collect.py +1 -1
  87. msprobe/pytorch/monitor/utils.py +1 -317
  88. msprobe/pytorch/online_dispatch/dispatch.py +1 -1
  89. msprobe/pytorch/online_dispatch/dump_compare.py +7 -1
  90. msprobe/pytorch/parse_tool/lib/utils.py +2 -4
  91. msprobe/visualization/graph_service.py +1 -1
  92. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/LICENSE +0 -0
  93. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/WHEEL +0 -0
  94. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/entry_points.txt +0 -0
  95. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,7 @@ import os
18
18
  from msprobe.core.common.const import Const, FileCheckConst, MsgConst
19
19
  from msprobe.core.common.exceptions import MsprobeException
20
20
  from msprobe.core.common.file_utils import FileChecker, load_json
21
- from msprobe.core.common.utils import get_real_step_or_rank, check_init_step
21
+ from msprobe.core.common.utils import get_real_step_or_rank, check_init_step, ThreadSafe
22
22
  from msprobe.core.common_config import CommonConfig
23
23
 
24
24
 
@@ -27,13 +27,14 @@ class BasePrecisionDebugger:
27
27
  tasks_not_need_debugger = [Const.GRAD_PROBE]
28
28
 
29
29
  def __new__(cls, *args, **kwargs):
30
- if cls._instance is None:
31
- cls._instance = super(BasePrecisionDebugger, cls).__new__(cls)
32
- cls._instance.config = None
33
- cls._instance.enable_dataloader = False
34
- cls._instance.initialized = False
35
- cls.service = None
36
- cls.first_start = False
30
+ if not cls._instance:
31
+ with ThreadSafe():
32
+ if not cls._instance:
33
+ cls._instance = super(BasePrecisionDebugger, cls).__new__(cls)
34
+ cls._instance.config = None
35
+ cls._instance.initialized = False
36
+ cls.service = None
37
+ cls.first_start = False
37
38
  return cls._instance
38
39
 
39
40
  def __init__(
@@ -83,11 +84,13 @@ class BasePrecisionDebugger:
83
84
  raise NotImplementedError("Subclass must implement _get_task_config")
84
85
 
85
86
  @classmethod
87
+ @ThreadSafe.synchronized
86
88
  def forward_backward_dump_end(cls):
87
89
  instance = cls._instance
88
90
  instance.stop()
89
91
 
90
92
  @classmethod
93
+ @ThreadSafe.synchronized
91
94
  def set_init_step(cls, step):
92
95
  instance = cls._instance
93
96
  if not instance:
@@ -97,6 +100,7 @@ class BasePrecisionDebugger:
97
100
  instance.service.loop = 0
98
101
 
99
102
  @classmethod
103
+ @ThreadSafe.synchronized
100
104
  def register_custom_api(cls, module, api, api_prefix=None):
101
105
  if not api_prefix:
102
106
  api_prefix = getattr(module, "__name__", "Custom")
@@ -112,6 +116,7 @@ class BasePrecisionDebugger:
112
116
  instance.service.register_custom_api(module, api, api_prefix)
113
117
 
114
118
  @classmethod
119
+ @ThreadSafe.synchronized
115
120
  def restore_custom_api(cls, module, api):
116
121
  if not hasattr(module, api):
117
122
  raise MsprobeException(
@@ -13,12 +13,14 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
-
17
- from abc import ABC, abstractmethod
18
16
  import os
17
+ import threading
18
+ from abc import ABC, abstractmethod
19
+ from collections import defaultdict
19
20
 
21
+ from msprobe.core.common.log import logger
20
22
  from msprobe.core.common.runtime import Runtime
21
- from msprobe.core.common.utils import Const
23
+ from msprobe.core.common.utils import Const, ThreadSafe
22
24
  from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs)
23
25
 
24
26
 
@@ -31,7 +33,7 @@ class HookSet:
31
33
 
32
34
 
33
35
  class BaseHookManager(ABC):
34
- inner_switch = False
36
+ inner_switch = defaultdict(bool)
35
37
  hook_handle_dict = {}
36
38
  params_grad_info = {}
37
39
 
@@ -86,7 +88,7 @@ class BaseHookManager(ABC):
86
88
  grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
87
89
  # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook
88
90
  setattr(module, 'params_grad_name', grad_name)
89
- # data_mode为forward时,不注册参数hook
91
+ # data_mode为forward时,不注册参数hook
90
92
  if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
91
93
  for param_name, param in params_dict.items():
92
94
  if param.requires_grad:
@@ -116,7 +118,7 @@ class BaseHookManager(ABC):
116
118
  # 记录当前模块的参数梯度信息已占位
117
119
  BaseHookManager.params_grad_info[grad_name] = True
118
120
 
119
- def _should_execute_hook(self, hook_type, module, is_forward):
121
+ def _should_execute_hook(self, hook_type, module, is_forward, tid):
120
122
  is_module_hook = hook_type == Const.MODULE
121
123
  if is_module_hook and not Runtime.is_running:
122
124
  return False
@@ -124,7 +126,7 @@ class BaseHookManager(ABC):
124
126
  return False
125
127
  elif not is_module_hook and not is_forward and not module.forward_data_collected:
126
128
  return False
127
- if BaseHookManager.inner_switch:
129
+ if BaseHookManager.inner_switch[tid]:
128
130
  return False
129
131
  if not self.data_collector or self.data_collector.data_processor.is_terminated:
130
132
  return False
@@ -132,111 +134,139 @@ class BaseHookManager(ABC):
132
134
 
133
135
  def _build_grad_hook(self, module, ori_name, param_name):
134
136
  def hook_fn(grad):
135
- if not self._should_execute_hook(Const.MODULE, module, False):
137
+ tid = threading.get_ident()
138
+ if not self._should_execute_hook(Const.MODULE, module, False, tid):
136
139
  return
137
- BaseHookManager.inner_switch = True
138
- self.data_collector.params_data_collect(ori_name, param_name, self._pid, grad)
139
- BaseHookManager.inner_switch = False
140
+ with ThreadSafe():
141
+ BaseHookManager.inner_switch[tid] = True
142
+ self.data_collector.params_data_collect(ori_name, param_name, self._pid, grad)
143
+ BaseHookManager.inner_switch[tid] = False
140
144
  return
145
+
141
146
  return hook_fn
142
147
 
143
148
  def _build_forward_pre_hook(self, hook_type, full_name, api_name):
144
149
  def forward_pre_hook(module, args, kwargs=None):
150
+ """
151
+ 为确保多线程场景下 L1 级别数据采集的正确性,每个封装后的 API 的 init 方法和 forward_pre_hook 需要确保在一个线程内连续完成,
152
+ 因此在 API 的 init 方法执行 ThreadSafe.acquire() 加锁操作,
153
+ 并且在 API 的 forward_pre_hook 方法执行 ThreadSafe.release() 释放锁操作。
154
+ """
145
155
  if hook_type == Const.MODULE:
146
156
  return
147
- if not self._should_execute_hook(hook_type, module, True):
157
+
158
+ tid = threading.get_ident()
159
+ if not self._should_execute_hook(hook_type, module, True, tid):
160
+ ThreadSafe.release()
148
161
  return
162
+
163
+ module.forward_data_collected = True
164
+ self._add_count(api_name)
165
+ if getattr(self.config, "online_run_ut", False):
166
+ ThreadSafe.release()
167
+ return
168
+
169
+ BaseHookManager.inner_switch[tid] = True
149
170
  if kwargs is None:
150
171
  kwargs = module.msprobe_input_kwargs if hasattr(module, 'msprobe_input_kwargs') else {}
151
- with self._no_grad_context():
152
- BaseHookManager.inner_switch = False
153
- module.forward_data_collected = True
154
- self._add_count(api_name)
155
- module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
156
- self.data_collector.update_api_or_module_name(full_name)
157
- if getattr(self.config, "online_run_ut", False):
158
- BaseHookManager.inner_switch = False
159
- return
160
- self.data_collector.forward_input_data_collect(
161
- full_name,
162
- module,
163
- self._pid,
164
- module_input_output,
165
- self._is_recompute
166
- )
167
- BaseHookManager.inner_switch = False
168
- return forward_pre_hook
169
-
170
- def _build_forward_hook(self, hook_type, full_name):
171
- def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
172
- if not self._should_execute_hook(hook_type, module, True):
173
- self._clear_input_kwargs(module)
174
- return None
175
- kwargs, output = self._process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs)
176
- BaseHookManager.inner_switch = True
177
- self.data_collector.update_api_or_module_name(full_name)
178
- module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
179
- with self._no_grad_context():
180
- if getattr(self.config, "online_run_ut", False):
181
- if self.data_collector.scope and not self.data_collector.scope.check(full_name):
182
- return None
183
- if self.attl_manager:
184
- self.attl_manager.attl_send(full_name, args, kwargs, output)
185
- BaseHookManager.inner_switch = False
186
- return None
187
- if hook_type == Const.MODULE:
188
- params_dict = self._get_params_dict(module)
189
- setattr(module_input_output, Const.PARAMS, params_dict)
190
- if params_dict:
191
- self._register_param_hook(full_name, module, params_dict)
172
+ try:
173
+ with self._no_grad_context():
174
+ module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
192
175
  self.data_collector.update_api_or_module_name(full_name)
193
- self.data_collector.forward_data_collect(
194
- full_name,
195
- module,
196
- self._pid,
197
- module_input_output,
198
- self._is_recompute
199
- )
200
- self._init_params_grad_info(module, params_dict)
201
- else:
202
- self.data_collector.forward_output_data_collect(
176
+ self.data_collector.forward_input_data_collect(
203
177
  full_name,
204
178
  module,
205
179
  self._pid,
206
180
  module_input_output,
207
181
  self._is_recompute
208
182
  )
183
+ except Exception as e:
184
+ logger.error(f"The forward pre hook execution of the {full_name} API failed.")
185
+ raise e
186
+ finally:
187
+ BaseHookManager.inner_switch[tid] = False
188
+ ThreadSafe.release()
189
+
190
+ return forward_pre_hook
191
+
192
+ def _build_forward_hook(self, hook_type, full_name):
193
+ def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
194
+ tid = threading.get_ident()
195
+ if not self._should_execute_hook(hook_type, module, True, tid):
209
196
  self._clear_input_kwargs(module)
197
+ return None
210
198
 
211
- if self.data_collector.if_return_forward_new_output():
212
- forward_new_output = self.data_collector.get_forward_new_output()
213
- BaseHookManager.inner_switch = False
214
- return forward_new_output
199
+ with ThreadSafe():
200
+ kwargs, output = self._process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs)
201
+ BaseHookManager.inner_switch[tid] = True
202
+ self.data_collector.update_api_or_module_name(full_name)
203
+ module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
204
+ with self._no_grad_context():
205
+ if getattr(self.config, "online_run_ut", False):
206
+ if self.data_collector.scope and not self.data_collector.scope.check(full_name):
207
+ return None
208
+ if self.attl_manager:
209
+ self.attl_manager.attl_send(full_name, args, kwargs, output)
210
+ BaseHookManager.inner_switch[tid] = False
211
+ return None
212
+ if hook_type == Const.MODULE:
213
+ params_dict = self._get_params_dict(module)
214
+ setattr(module_input_output, Const.PARAMS, params_dict)
215
+ if params_dict:
216
+ self._register_param_hook(full_name, module, params_dict)
217
+ self.data_collector.update_api_or_module_name(full_name)
218
+ self.data_collector.forward_data_collect(
219
+ full_name,
220
+ module,
221
+ self._pid,
222
+ module_input_output,
223
+ self._is_recompute
224
+ )
225
+ self._init_params_grad_info(module, params_dict)
226
+ else:
227
+ self.data_collector.forward_output_data_collect(
228
+ full_name,
229
+ module,
230
+ self._pid,
231
+ module_input_output,
232
+ self._is_recompute
233
+ )
234
+ self._clear_input_kwargs(module)
235
+
236
+ if self.data_collector.if_return_forward_new_output():
237
+ forward_new_output = self.data_collector.get_forward_new_output()
238
+ BaseHookManager.inner_switch[tid] = False
239
+ return forward_new_output
240
+
241
+ BaseHookManager.inner_switch[tid] = False
242
+ return output
215
243
 
216
- BaseHookManager.inner_switch = False
217
- return output
218
244
  return forward_hook
219
245
 
220
246
  def _build_backward_hook(self, hook_type, full_name):
221
247
  def backward_hook(module, grad_input, grad_output):
222
- if not self._should_execute_hook(hook_type, module, False):
223
- return
224
- BaseHookManager.inner_switch = True
225
- self.data_collector.update_api_or_module_name(full_name)
226
- if getattr(self.config, "online_run_ut", False):
227
- BaseHookManager.inner_switch = False
248
+ tid = threading.get_ident()
249
+ if not self._should_execute_hook(hook_type, module, False, tid):
228
250
  return
229
- need_exchange = self._need_exchange(module) if hook_type == Const.MODULE else True
230
- if need_exchange:
231
- module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
232
- else:
233
- module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
234
- self.data_collector.backward_data_collect(
251
+
252
+ with ThreadSafe():
253
+ BaseHookManager.inner_switch[tid] = True
254
+ self.data_collector.update_api_or_module_name(full_name)
255
+ if getattr(self.config, "online_run_ut", False):
256
+ BaseHookManager.inner_switch[tid] = False
257
+ return
258
+ need_exchange = self._need_exchange(module) if hook_type == Const.MODULE else True
259
+ if need_exchange:
260
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
261
+ else:
262
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
263
+ self.data_collector.backward_data_collect(
235
264
  full_name,
236
265
  module,
237
266
  self._pid,
238
267
  module_input_output,
239
268
  self._is_recompute
240
269
  )
241
- BaseHookManager.inner_switch = False
270
+ BaseHookManager.inner_switch[tid] = False
271
+
242
272
  return backward_hook
@@ -0,0 +1,338 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from collections import namedtuple
16
+ from datetime import timezone, timedelta
17
+ from functools import wraps
18
+ from datetime import datetime
19
+ import os
20
+ import re
21
+
22
+ from msprobe.core.common.const import MonitorConst
23
+ from msprobe.core.common.log import logger
24
+ from msprobe.core.common.utils import is_int
25
+ from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod
26
+
27
+
28
+ beijing_tz = timezone(timedelta(hours=8))
29
+ MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio"))
30
+
31
+
32
+ class MsgConst:
33
+ """
34
+ Class for log messages const
35
+ """
36
+ SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"]
37
+
38
+
39
+ def get_output_base_dir():
40
+ return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
41
+
42
+
43
+ def filter_special_chars(func):
44
+ @wraps(func)
45
+ def func_level(msg):
46
+ for char in MsgConst.SPECIAL_CHAR:
47
+ msg = msg.replace(char, '_')
48
+ return func(msg)
49
+
50
+ return func_level
51
+
52
+
53
+ def validate_ops(ops):
54
+ if not isinstance(ops, list):
55
+ raise TypeError("ops should be a list")
56
+ valid_ops = []
57
+ for op in ops:
58
+ if op not in MonitorConst.OP_LIST:
59
+ logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")
60
+ continue
61
+ valid_ops.append(op)
62
+ if not valid_ops:
63
+ default_op = MonitorConst.OP_LIST[0]
64
+ valid_ops.append(default_op)
65
+ logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used")
66
+ # 增加默认shape和dtype参数
67
+ if "shape" not in valid_ops:
68
+ valid_ops.append("shape")
69
+ if "dtype" not in valid_ops:
70
+ valid_ops.append("dtype")
71
+ return valid_ops
72
+
73
+
74
+ def validate_ndigits(ndigits):
75
+ if not ndigits:
76
+ return
77
+ if not is_int(ndigits) or ndigits <= 0:
78
+ raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.")
79
+ if ndigits > MonitorConst.MAX_NDIGITS:
80
+ raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.")
81
+
82
+
83
+ def validate_ranks(ranks):
84
+ if not isinstance(ranks, list):
85
+ raise TypeError("module_ranks should be a list")
86
+ for rank in ranks:
87
+ if not isinstance(rank, int) or isinstance(rank, bool):
88
+ raise TypeError(f"element in module_ranks should be a int, get {type(rank)}")
89
+
90
+
91
+ def validate_targets(targets):
92
+ if not isinstance(targets, dict):
93
+ raise TypeError('targets in config.json should be a dict')
94
+ for module_name, field in targets.items():
95
+ if not isinstance(module_name, str):
96
+ raise TypeError('key of targets should be module_name[str] in config.json')
97
+ if not isinstance(field, dict):
98
+ raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json')
99
+
100
+
101
+ def validate_print_struct(print_struct):
102
+ if not isinstance(print_struct, bool):
103
+ raise TypeError("print_struct should be a bool")
104
+
105
+
106
+ def validate_ur_distribution(ur_distribution):
107
+ if not isinstance(ur_distribution, bool):
108
+ raise TypeError('ur_distribution should be a bool')
109
+
110
+
111
+ def validate_xy_distribution(xy_distribution):
112
+ if not isinstance(xy_distribution, bool):
113
+ raise TypeError('xy_distribution should be a bool')
114
+
115
+
116
+ def validate_wg_distribution(wg_distribution):
117
+ if not isinstance(wg_distribution, bool):
118
+ raise TypeError('wg_distribution should be a bool')
119
+
120
+
121
+ def validate_mg_distribution(mg_distribution):
122
+ if not isinstance(mg_distribution, bool):
123
+ raise TypeError('mg_distribution should be a bool')
124
+
125
+
126
+ def validate_param_distribution(param_distribution):
127
+ if not isinstance(param_distribution, bool):
128
+ raise TypeError('param_distribution should be a bool')
129
+
130
+
131
+ def validate_cc_distribution(cc_distribution):
132
+ if not isinstance(cc_distribution, dict):
133
+ raise TypeError('cc_distribution should be a dictionary')
134
+ for key, value in cc_distribution.items():
135
+ if key == 'enable':
136
+ if not isinstance(value, bool):
137
+ raise TypeError('cc_distribution enable should be a bool')
138
+ elif key == 'cc_codeline':
139
+ if not isinstance(value, list):
140
+ raise TypeError('cc_distribution cc_codeline should be a list')
141
+ elif key == 'cc_pre_hook':
142
+ if not isinstance(value, bool):
143
+ raise TypeError('cc_distribution cc_pre_hook should be a bool')
144
+ elif key == 'cc_log_only':
145
+ if not isinstance(value, bool):
146
+ raise TypeError('cc_distribution cc_log_only should be a bool')
147
+ else:
148
+ raise TypeError(f'{key} of cc_distribution is not supported.')
149
+
150
+
151
+ def validate_squash_name(squash_name):
152
+ if not isinstance(squash_name, bool):
153
+ raise TypeError('squash_name should be a bool')
154
+
155
+
156
+ def validate_alert(alert):
157
+ if not isinstance(alert, dict):
158
+ raise TypeError('alert should be a dictionary')
159
+ rules = alert.get('rules')
160
+ if rules and isinstance(rules, list):
161
+ for rule in rules:
162
+ rule_name = rule.get("rule_name")
163
+ if rule_name and rule_name not in MonitorConst.RULE_NAME:
164
+ raise TypeError(f"{rule_name} is not supported")
165
+ args = rule.get("args")
166
+ if args and isinstance(args, dict):
167
+ threshold = args.get("threshold")
168
+ if not isinstance(threshold, (float, int)) or threshold < 0:
169
+ raise TypeError('threshold must be float and not less than 0')
170
+ dump = alert.get('dump')
171
+ if dump and not isinstance(dump, bool):
172
+ raise TypeError('dump must be bool.')
173
+
174
+
175
+ def validate_step_count_per_record(step_count_per_record):
176
+ if not is_int(step_count_per_record):
177
+ raise TypeError('step_count_per_record must be int.')
178
+ if step_count_per_record < 1:
179
+ raise ValueError("step_count_per_record must greater than 0")
180
+ if step_count_per_record > 1e6:
181
+ raise ValueError("step_count_per_record must smaller than 1e6")
182
+
183
+
184
+ def validate_dynamic_on(dynamic_on):
185
+ if not isinstance(dynamic_on, bool):
186
+ raise TypeError('dynamic_on should be a bool')
187
+
188
+
189
+ def validate_monitor_mbs_grad(monitor_mbs_grad):
190
+ if not isinstance(monitor_mbs_grad, bool):
191
+ logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.')
192
+ return False
193
+ return monitor_mbs_grad
194
+
195
+
196
+ def validate_append_output(append_output):
197
+ if not isinstance(append_output, list):
198
+ raise TypeError('append_output should be a list')
199
+ if len(append_output) > 0 and len(append_output) != 2:
200
+ raise ValueError('append_output should be empty or contain exactly 2 elements')
201
+
202
+
203
+ def validate_config(config):
204
+ config['ops'] = validate_ops(config.get('ops', []))
205
+
206
+ ndigits = config.get('ndigits')
207
+ validate_ndigits(ndigits)
208
+
209
+ eps = config.get('eps', 1e-8)
210
+ if not isinstance(eps, float):
211
+ raise TypeError("eps should be a float")
212
+
213
+ ranks = config.get("module_ranks", [])
214
+ validate_ranks(ranks)
215
+
216
+ targets = config.get("targets", {})
217
+ validate_targets(targets)
218
+
219
+ print_struct = config.get('print_struct', False)
220
+ validate_print_struct(print_struct)
221
+
222
+ ur_distribution = config.get('ur_distribution', False)
223
+ validate_ur_distribution(ur_distribution)
224
+
225
+ xy_distribution = config.get('xy_distribution', False)
226
+ validate_xy_distribution(xy_distribution)
227
+
228
+ wg_distribution = config.get('wg_distribution', False)
229
+ validate_wg_distribution(wg_distribution)
230
+
231
+ mg_distribution = config.get('mg_distribution', False)
232
+ validate_mg_distribution(mg_distribution)
233
+
234
+ param_distribution = config.get('param_distribution', False)
235
+ validate_param_distribution(param_distribution)
236
+
237
+ cc_distribution = config.get('cc_distribution', {})
238
+ validate_cc_distribution(cc_distribution)
239
+
240
+ alert = config.get('alert', {})
241
+ validate_alert(alert)
242
+
243
+ step_count_per_record = config.get('step_count_per_record', 1)
244
+ validate_step_count_per_record(step_count_per_record)
245
+
246
+ config["start_step"] = validate_int_arg(config.get("start_step"), "start_step",
247
+ MonitorConst.DEFAULT_START_STEP, MonitorConst.DEFAULT_START_STEP)
248
+ config["collect_times"] = validate_int_arg(config.get("collect_times"), "collect_times",
249
+ MonitorConst.DEFAULT_MIN_COLLECT_TIMES,
250
+ MonitorConst.DEFAULT_MAX_COLLECT_TIMES)
251
+ config["step_interval"] = validate_int_arg(config.get("step_interval"), "step_interval",
252
+ MonitorConst.DEFAULT_STEP_INTERVAL, MonitorConst.DEFAULT_STEP_INTERVAL)
253
+
254
+ squash_name = config.get('squash_name', True)
255
+ validate_squash_name(squash_name)
256
+
257
+ time_tags = config.get("append_output", [])
258
+ validate_append_output(time_tags)
259
+
260
+ config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False))
261
+
262
+ dynamic_on = config.get('dynamic_on', False)
263
+ validate_dynamic_on(dynamic_on)
264
+
265
+ if not targets:
266
+ if xy_distribution:
267
+ config["all_xy"] = True
268
+ config["targets"] = {"": {}}
269
+
270
+
271
+ def time_str2time_digit(time_str):
272
+ time_format = '%b%d_%H-%M-%S'
273
+ if not isinstance(time_str, str):
274
+ raise TypeError(f"time_str:{time_str} should be a str")
275
+ try:
276
+ time_digit = datetime.strptime(time_str, time_format)
277
+ except Exception as e:
278
+ raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \
279
+ of existing output dirpath, like 'Dec03_21-34-40'.") from e
280
+ return time_digit
281
+
282
+
283
+ def get_target_output_dir(monitor_path, time_start, time_end):
284
+ check_file_or_directory_path(monitor_path, isdir=True)
285
+ time_start = time_str2time_digit(time_start) if time_start is not None else time_start
286
+ time_end = time_str2time_digit(time_end) if time_end is not None else time_end
287
+ if time_start and time_end and time_start > time_end:
288
+ raise ValueError(f"time_start({time_start}) greater than time_end({time_end})")
289
+ result = {}
290
+ for dirname in os.listdir(monitor_path):
291
+ match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname)
292
+ if not match:
293
+ continue
294
+ time_tag = match.group(1)
295
+ rank = match.group(2)
296
+ target_time = time_str2time_digit(time_tag)
297
+ start_ok = time_start is None or target_time >= time_start
298
+ end_ok = time_end is None or target_time <= time_end
299
+ if start_ok and end_ok:
300
+ result[rank] = os.path.join(monitor_path, dirname)
301
+ return result
302
+
303
+
304
+ def chmod_tensorboard_dir(path):
305
+ """
306
+ format配置为tensorboard时,需要补充文件权限设置
307
+ """
308
+ try:
309
+ recursive_chmod(path)
310
+ except Exception as e:
311
+ logger.warning(f"chmod tensorboard dir wrong because {e}, not updated, please check!!!")
312
+
313
+
314
+ def validate_set_monitor(grad_acc_steps, start_iteration):
315
+ """
316
+ validate parameters of set_monitor.
317
+ """
318
+ grad_acc_steps = validate_int_arg(grad_acc_steps, "grad_acc_steps",
319
+ MonitorConst.DEFAULT_GRAD_ACC_STEPS, MonitorConst.DEFAULT_GRAD_ACC_STEPS)
320
+
321
+ start_iteration = validate_int_arg(start_iteration, "start_iteration",
322
+ MonitorConst.DEFAULT_START_ITERATION, MonitorConst.DEFAULT_START_ITERATION)
323
+ return grad_acc_steps, start_iteration
324
+
325
+
326
+ def validate_int_arg(value, name, minimum, default_value):
327
+ """Validate int args, if any exception occurs, use the default value."""
328
+ if value is None:
329
+ return default_value
330
+ try:
331
+ if not is_int(value):
332
+ raise TypeError(f"{name} must be int")
333
+ if value < minimum:
334
+ raise ValueError(f"{name} must greater than {minimum}")
335
+ except Exception as e:
336
+ value = default_value
337
+ logger.warning(f"Validate {name} failed, {e}, replaced with default value {value}.")
338
+ return value
msprobe/core/service.py CHANGED
@@ -303,7 +303,8 @@ class BaseService(ABC):
303
303
  self.logger.info(
304
304
  f"Current token id: {self.cur_token_id}, exceed token_range, early stop dump infer token.")
305
305
  self.cur_token_id += 1
306
- if isinstance(root_model, list):
306
+ # 此处root_model可以保证为 Module/Cell类型 或 [Module/Cell]类型
307
+ if root_model and isinstance(root_model, list):
307
308
  root_model = root_model[0]
308
309
  self.logger.warning("Infer model can only input one to support token_range, choose the first one.")
309
310
  if self._is_online_run_ut: