mindstudio-probe 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/RECORD +85 -66
  3. msprobe/README.md +2 -2
  4. msprobe/core/common/const.py +34 -9
  5. msprobe/core/common/inplace_ops.yaml +1 -0
  6. msprobe/core/common/utils.py +14 -0
  7. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  8. msprobe/core/compare/merge_result/merge_result.py +8 -7
  9. msprobe/core/compare/merge_result/utils.py +81 -0
  10. msprobe/core/compare/utils.py +10 -0
  11. msprobe/core/data_dump/data_collector.py +58 -13
  12. msprobe/core/data_dump/data_processor/base.py +92 -8
  13. msprobe/core/data_dump/data_processor/factory.py +3 -0
  14. msprobe/core/data_dump/data_processor/mindspore_processor.py +17 -4
  15. msprobe/core/data_dump/data_processor/pytorch_processor.py +58 -7
  16. msprobe/core/data_dump/json_writer.py +26 -8
  17. msprobe/docs/01.installation.md +25 -0
  18. msprobe/docs/02.config_introduction.md +14 -12
  19. msprobe/docs/03.config_examples.md +24 -0
  20. msprobe/docs/05.data_dump_PyTorch.md +34 -15
  21. msprobe/docs/06.data_dump_MindSpore.md +45 -22
  22. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -2
  23. msprobe/docs/19.monitor.md +257 -260
  24. msprobe/docs/21.visualization_PyTorch.md +10 -0
  25. msprobe/docs/22.visualization_MindSpore.md +11 -0
  26. msprobe/docs/27.dump_json_instruction.md +24 -20
  27. msprobe/docs/28.debugger_save_instruction.md +94 -0
  28. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  29. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  30. msprobe/mindspore/__init__.py +1 -0
  31. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +26 -6
  32. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  33. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  34. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  35. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  36. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  37. msprobe/mindspore/common/utils.py +20 -2
  38. msprobe/mindspore/debugger/debugger_config.py +25 -2
  39. msprobe/mindspore/debugger/precision_debugger.py +25 -6
  40. msprobe/mindspore/dump/hook_cell/api_registry.py +2 -0
  41. msprobe/mindspore/dump/jit_dump.py +7 -6
  42. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  43. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  44. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  45. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  46. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  47. msprobe/mindspore/monitor/features.py +63 -0
  48. msprobe/mindspore/monitor/module_hook.py +821 -0
  49. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  50. msprobe/mindspore/monitor/utils.py +267 -0
  51. msprobe/mindspore/ms_config.py +8 -2
  52. msprobe/mindspore/service.py +95 -21
  53. msprobe/pytorch/__init__.py +0 -1
  54. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  55. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  56. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  57. msprobe/pytorch/bench_functions/mish.py +21 -0
  58. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  59. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  60. msprobe/pytorch/common/utils.py +71 -0
  61. msprobe/pytorch/debugger/debugger_config.py +19 -9
  62. msprobe/pytorch/debugger/precision_debugger.py +14 -0
  63. msprobe/pytorch/dump/module_dump/module_processer.py +10 -30
  64. msprobe/pytorch/function_factory.py +7 -1
  65. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  66. msprobe/pytorch/hook_module/wrap_distributed.py +4 -0
  67. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  68. msprobe/pytorch/monitor/csv2tb.py +10 -12
  69. msprobe/pytorch/monitor/module_hook.py +123 -104
  70. msprobe/pytorch/monitor/module_metric.py +6 -6
  71. msprobe/pytorch/monitor/optimizer_collect.py +45 -63
  72. msprobe/pytorch/monitor/utils.py +8 -43
  73. msprobe/pytorch/pt_config.py +19 -22
  74. msprobe/pytorch/service.py +103 -24
  75. msprobe/visualization/builder/graph_builder.py +31 -5
  76. msprobe/visualization/builder/msprobe_adapter.py +7 -5
  77. msprobe/visualization/graph/base_node.py +3 -2
  78. msprobe/visualization/graph/distributed_analyzer.py +80 -3
  79. msprobe/visualization/graph/node_op.py +4 -2
  80. msprobe/visualization/graph_service.py +3 -4
  81. msprobe/visualization/utils.py +10 -2
  82. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  83. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  84. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  85. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,94 @@
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ import abc
+ from mindspore import Tensor
+
+ from msprobe.core.common.log import logger
+
+
+ # Registry that stores every ConfigValidator implementation class
+ config_validator_registry = {}
+
+
+ def register_config_validator(cls):
+     """Decorator used to register ConfigValidator implementation classes."""
+     config_validator_registry[cls.__name__] = cls
+     return cls
+
+
+ class ConfigValidator(metaclass=abc.ABCMeta):
+     @abc.abstractmethod
+     def check_pattern_match(self, config_spec: str):
+         pass
+
+     @abc.abstractmethod
+     def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
+         pass
+
+
+ @register_config_validator
+ class TensorValidator(ConfigValidator):
+     def check_pattern_match(self, config_spec: str):
+         pattern = re.compile(r"tensor")
+         return pattern.match(config_spec)
+
+     def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
+         if not isinstance(actual_data, Tensor):
+             raise ValueError(
+                 f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.")
+
+
+ @register_config_validator
+ class TupleValidator(ConfigValidator):
+     def check_pattern_match(self, config_spec: str):
+         pattern = re.compile(r"tuple\[(\d+)\]:?(\d+)?")
+         return pattern.match(config_spec)
+
+     def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
+         length, index = pattern_match.groups()
+         if index is None:
+             index = 0
+         length, index = int(length), int(index)
+
+         if not (0 <= index < length):
+             raise ValueError(
+                 f"Format of {module_name} {data_type} in config.json does not match the required format 'tuple[x]:y'. "
+                 f"y must be greater than or equal to 0 and less than x.")
+         if not isinstance(actual_data, tuple):
+             raise ValueError(
+                 f"Type of {module_name} {data_type} does not match spec of config.json, should be tuple, please check.")
+         if len(actual_data) != length:
+             raise ValueError(
+                 f"Length of {module_name} {data_type} does not match spec of config.json, should be {length}, "
+                 f"actual is {len(actual_data)}, please check.")
+         return index
+
+
+ def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str):
+     focused_col = None
+     for _, validator_cls in config_validator_registry.items():
+         config_validator = validator_cls()
+         pattern_match = config_validator.check_pattern_match(config_spec)
+         if pattern_match:
+             try:
+                 focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match)
+             except ValueError as e:
+                 logger.warning(f"config spec validate failed: {str(e)}")
+             return focused_col
+     logger.warning(f"config spec in {module_name} {data_type} not supported, "
+                    f"expected spec: 'tuple\\[(\\d+)\\]:(\\d+)' or 'tensor', actual spec: {config_spec}.")
+     return focused_col
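The new file above is evidently msprobe/mindspore/monitor/module_spec_verifier.py (+94 in the file list). Its entry point, validate_config_spec, looks up a validator by pattern, logs a warning instead of raising on a mismatch, and returns the tuple index to focus on (None for a plain tensor spec). A minimal usage sketch, assuming a MindSpore environment with msprobe 1.2.2 installed; the module name, field name, and data below are hypothetical:

    import numpy as np
    from mindspore import Tensor
    from msprobe.mindspore.monitor.module_spec_verifier import validate_config_spec

    hidden = Tensor(np.ones((2, 4), dtype=np.float32))  # stand-in activation

    # "tensor" spec: type check only, nothing to focus on -> None
    col = validate_config_spec("tensor", hidden, "decoder.layer0", "input")

    # "tuple[2]:1" spec: expects a 2-tuple and selects element 1 -> 1
    col = validate_config_spec("tuple[2]:1", (hidden, hidden), "decoder.layer0", "output")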
@@ -0,0 +1,267 @@
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from mindspore import dtype as mstype, Tensor
+
+ from msprobe.mindspore.monitor.features import FUNC_MAP
+ from msprobe.core.common.const import MonitorConst
+ from msprobe.core.common.utils import is_int
+ from msprobe.core.common.log import logger
+
+
+ def get_single_metrics(op_list, tag, tensor, output=None):
+     if output is None:
+         output = {}
+     if tag not in output:
+         output[tag] = {}
+     for op in op_list:
+         func = FUNC_MAP.get(op)
+         statistic = func(tensor)
+         if hasattr(statistic, "dtype") and statistic.dtype == mstype.bfloat16:
+             statistic = float(statistic)
+             statistic = Tensor(statistic)
+         output[tag][op] = statistic.astype(mstype.float32)
+
+
+ def get_metrics(op_list, tag2tensor, eps, output=None):
+     if output is None:
+         output = {}
+     for tag, tensor in tag2tensor.items():
+         if tag not in output:
+             output[tag] = {}
+         get_single_metrics(op_list, tag, tensor, output)
+     return output
+
+
+ def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
+     if rank is None:
+         return f"{module_or_param_name}/{tag}"
+     else:
+         return f"{module_or_param_name}/rank{rank}/{tag}"
+
+
+ def step_accumulates_one(context, micro_batch_number):
+     """
+     :param context: ModuleHookContext
+     :param micro_batch_number: number of micro-batches of the training model.
+     :return:
+     """
+     context.micro_step += 1
+     if context.micro_step == micro_batch_number:
+         context.micro_step = 0
+         context.step += 1
+
+
+ def is_skip_step(step, start_step, step_interval, has_collect_times=0, collect_times=1e8):
+     """
+     Skip the current step if it is earlier than start_step, does not fall on step_interval,
+     or the collection budget is exhausted.
+     :param step: current training step, int
+     :param start_step: int
+     :param step_interval: int
+     :param has_collect_times: completed collection count, int
+     :param collect_times: maximum collection count, int
+     :return: whether to skip, bool
+     """
+     return step < start_step or (step - start_step) % step_interval != 0 or has_collect_times >= collect_times
+
+
+ def validate_ops(ops):
+     if not isinstance(ops, list):
+         raise TypeError("ops should be a list")
+     valid_ops = []
+     for op in ops:
+         if op not in MonitorConst.OP_LIST:
+             logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")
+             continue
+         valid_ops.append(op)
+     if not valid_ops:
+         default_op = MonitorConst.OP_LIST[0]
+         valid_ops.append(default_op)
+         logger.info(f"There are no valid ops, so the default op {default_op} is used")
+     return valid_ops
+
+
+ def validate_ranks(ranks):
+     if not isinstance(ranks, list):
+         raise TypeError("module_ranks should be a list")
+     for rank in ranks:
+         if not isinstance(rank, str):
+             raise TypeError(f"element in module_ranks should be a str, got {type(rank)}")
+
+
+ def validate_targets(targets):
+     if not isinstance(targets, dict):
+         raise TypeError('targets in config.json should be a dict')
+     for module_name, field in targets.items():
+         if not isinstance(module_name, str):
+             raise TypeError('key of targets should be module_name[str] in config.json')
+         if not isinstance(field, dict):
+             raise TypeError('values of targets should be the monitored fields, e.g. {"input": "tensor"}, in config.json')
+
+
+ def validate_print_struct(print_struct):
+     if not isinstance(print_struct, bool):
+         raise TypeError("print_struct should be a bool")
+
+
+ def validate_ur_distribution(ur_distribution):
+     if not isinstance(ur_distribution, bool):
+         raise TypeError('ur_distribution should be a bool')
+
+
+ def validate_xy_distribution(xy_distribution):
+     if not isinstance(xy_distribution, bool):
+         raise TypeError('xy_distribution should be a bool')
+
+
+ def validate_wg_distribution(wg_distribution):
+     if not isinstance(wg_distribution, bool):
+         raise TypeError('wg_distribution should be a bool')
+
+
+ def validate_mg_distribution(mg_distribution):
+     if not isinstance(mg_distribution, bool):
+         raise TypeError('mg_distribution should be a bool')
+
+
+ def validate_param_distribution(param_distribution):
+     if not isinstance(param_distribution, bool):
+         raise TypeError('param_distribution should be a bool')
+
+
+ def validate_cc_distribution(cc_distribution):
+     if not isinstance(cc_distribution, dict):
+         raise TypeError('cc_distribution should be a dictionary')
+     expected_keys = {
+         'enable': bool,
+         'cc_codeline': list,
+         'cc_pre_hook': bool,
+         'cc_log_only': bool
+     }
+     for key, value in cc_distribution.items():
+         if key in expected_keys:
+             if not isinstance(value, expected_keys[key]):
+                 raise TypeError(f'cc_distribution {key} should be a {expected_keys[key].__name__}')
+         else:
+             raise TypeError(f'{key} of cc_distribution is not supported.')
+
+
+ def validate_alert(alert):
+     if not isinstance(alert, dict):
+         raise TypeError('alert should be a dictionary')
+     rules = alert.get('rules')
+     if rules and isinstance(rules, list):
+         for rule in rules:
+             rule_name = rule.get("rule_name")
+             if rule_name and rule_name not in MonitorConst.RULE_NAME:
+                 raise TypeError(f"{rule_name} is not supported")
+             args = rule.get("args")
+             if args and isinstance(args, dict):
+                 threshold = args.get("threshold")
+                 if not isinstance(threshold, float) or threshold < 0:
+                     raise TypeError('threshold must be a float no less than 0')
+     dump = alert.get('dump')
+     if dump and not isinstance(dump, bool):
+         raise TypeError('dump must be bool.')
+
+
+ def validate_step_count_per_record(step_count_per_record):
+     if not is_int(step_count_per_record):
+         raise TypeError('step_count_per_record must be int.')
+     if step_count_per_record < 1:
+         raise ValueError("step_count_per_record must be greater than 0")
+     if step_count_per_record > 1e6:
+         raise ValueError("step_count_per_record must be smaller than 1e6")
+
+
+ def validate_start_step(start_step):
+     if not is_int(start_step):
+         raise TypeError('start_step must be int.')
+     if start_step < 0:
+         raise ValueError("start_step must be greater than or equal to 0")
+     if start_step > 1e8:
+         raise ValueError("start_step must be smaller than 1e8")
+
+
+ def validate_step_interval(step_interval):
+     if not is_int(step_interval):
+         raise TypeError('step_interval must be int.')
+     if step_interval < 1:
+         raise ValueError("step_interval must be greater than or equal to 1")
+     if step_interval > 1e8:
+         raise ValueError("step_interval must be smaller than 1e8")
+
+
+ def validate_collect_times(collect_times):
+     if not is_int(collect_times):
+         raise TypeError('collect_times must be int.')
+     if collect_times < 1:
+         raise ValueError("collect_times must be greater than or equal to 1")
+
+
+ def validate_config(config):
+     config['ops'] = validate_ops(config.get('ops', []))
+
+     eps = config.get('eps', 1e-8)
+     if not isinstance(eps, float):
+         raise TypeError("eps should be a float")
+
+     ranks = config.get("module_ranks", [])
+     validate_ranks(ranks)
+
+     targets = config.get("targets", {})
+     validate_targets(targets)
+
+     print_struct = config.get('print_struct', False)
+     validate_print_struct(print_struct)
+
+     ur_distribution = config.get('ur_distribution', False)
+     validate_ur_distribution(ur_distribution)
+
+     xy_distribution = config.get('xy_distribution', False)
+     validate_xy_distribution(xy_distribution)
+
+     wg_distribution = config.get('wg_distribution', False)
+     validate_wg_distribution(wg_distribution)
+
+     mg_distribution = config.get('mg_distribution', False)
+     validate_mg_distribution(mg_distribution)
+
+     param_distribution = config.get('param_distribution', False)
+     validate_param_distribution(param_distribution)
+
+     cc_distribution = config.get('cc_distribution', {})
+     validate_cc_distribution(cc_distribution)
+
+     alert = config.get('alert', {})
+     validate_alert(alert)
+
+     step_count_per_record = config.get('step_count_per_record', 1)
+     validate_step_count_per_record(step_count_per_record)
+
+     start_step = config.get('start_step', 0)
+     validate_start_step(start_step)
+
+     step_interval = config.get('step_interval', 1)
+     validate_step_interval(step_interval)
+
+     collect_times = config.get('collect_times', 1e8)
+     validate_collect_times(collect_times)
+
+     if not targets:
+         if xy_distribution:
+             config["all_xy"] = True
+         config["targets"] = {"": {}}
+         config["is_select"] = False
+     else:
+         config["is_select"] = True
@@ -1,4 +1,4 @@
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -106,12 +106,18 @@ class GradProbeConfig(BaseConfig):
          check_numeral_list_ascend(self.bounds)


+ class StructureConfig(BaseConfig):
+     def __init__(self, json_config):
+         super().__init__(json_config)
+
+
  TaskDict = {
      Const.TENSOR: TensorConfig,
      Const.STATISTICS: StatisticsConfig,
      Const.OVERFLOW_CHECK: OverflowCheckConfig,
      Const.FREE_BENCHMARK: FreeBenchmarkConfig,
-     Const.GRAD_PROBE: GradProbeConfig
+     Const.GRAD_PROBE: GradProbeConfig,
+     Const.STRUCTURE: StructureConfig
  }

@@ -22,6 +22,7 @@ import mindspore as ms
  from mindspore import nn
  from mindspore.common.api import _no_grad
  from mindspore.ops.primitive import Primitive
+
  try:
      from mindspore.common._pijit_context import PIJitCaptureContext
  except ImportError:
@@ -31,7 +32,7 @@ else:
  from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
  from msprobe.core.common.file_utils import create_directory
- from msprobe.core.common.utils import Const, print_tools_ends_info
+ from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
  from msprobe.core.data_dump.data_collector import build_data_collector
  from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,
                                                          ModuleBackwardInputs)
@@ -68,8 +69,10 @@ class Service:
          self.start_call = False
          self.should_stop_service = False
          self.params_grad_info = {}
+         self.hook_handle_dict = {}
          # Register early to hook as many APIs as possible
          self.register_api_hook()
+         self.init_for_debug_level()

      @staticmethod
      def check_model_valid(models):
@@ -138,7 +141,12 @@ class Service:
              if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
                  for param_name, param in params_dict.items():
                      if param.requires_grad:
-                         param.register_hook(grad_hook(cell, ori_name, param_name))
+                         name = ori_name + Const.SEP + param_name
+                         old_handle = self.hook_handle_dict.get(name)
+                         if old_handle and hasattr(old_handle, "remove"):
+                             old_handle.remove()
+                         handle = param.register_hook(grad_hook(cell, ori_name, param_name))
+                         self.hook_handle_dict[name] = handle

          def init_params_grad_info(cell, params_dict):
              '''
@@ -168,11 +176,15 @@ class Service:
              module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output)
              if target_type == BaseScope.Module_Type_Module:
                  api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
-                 params_dict = {key.split(Const.SEP)[-1]: value for key, value in cell.parameters_dict(
-                     recurse=False).items()}
-                 setattr(module_input_output, Const.PARAMS, params_dict)
+                 params_dict = {}
+                 if self.config.task != Const.STRUCTURE:
+                     params_dict = {
+                         key.split(Const.SEP)[-1]: value
+                         for key, value in cell.parameters_dict(recurse=False).items()
+                     }
+                 setattr(module_input_output, Const.PARAMS, params_dict)
                  # Decide whether parameter hooks need to be registered
-                 if not hasattr(cell, 'params_grad_name') and params_dict:
+                 if params_dict:
                      ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0]
                      grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
                      # On the first forward-hook run, add the params_grad_name attribute and register parameter hooks
@@ -257,15 +269,20 @@ class Service:
          self.primitive_counters[primitive_name] += 1

      def step(self):
+         if self.config.level == Const.LEVEL_DEBUG:
+             return
          if self.config.async_dump:
              self.data_collector.fill_stack_tensor_data()
-             self.data_collector.data_processor.dump_async_data()
+             if self.config.task == Const.TENSOR:
+                 self.data_collector.data_processor.dump_async_data()
          self.data_collector.write_json()
          self.current_iter += 1
          self.data_collector.update_iter(self.current_iter)
          self.reset_status()

      def start(self, model=None):
+         if self.config.level == Const.LEVEL_DEBUG:
+             return
          self.start_call = True
          if self.should_stop_service:
              return
@@ -294,7 +311,10 @@ class Service:
          if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
              JitDump.set_config(self.config)
              JitDump.set_data_collector(self.data_collector)
-             ms.common.api._MindsporeFunctionExecutor = JitDump
+             if hasattr(ms.common.api, "_MindsporeFunctionExecutor"):
+                 ms.common.api._MindsporeFunctionExecutor = JitDump
+             else:
+                 ms.common.api._JitExecutor = JitDump
              ms.common.api._PyNativeExecutor.grad = JitDump.grad
              if pijit_label:
                  PIJitCaptureContext.__enter__ = self.empty
@@ -310,6 +330,8 @@ class Service:
          JitDump.jit_dump_switch = True

      def stop(self):
+         if self.config.level == Const.LEVEL_DEBUG:
+             return
          if self.should_stop_service:
              return
          logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. "
@@ -326,7 +348,8 @@ class Service:
          self.start_call = False
          if self.config.async_dump:
              self.data_collector.fill_stack_tensor_data()
-             self.data_collector.data_processor.dump_async_data()
+             if self.config.task == Const.TENSOR:
+                 self.data_collector.data_processor.dump_async_data()
          self.data_collector.write_json()
          JitDump.jit_dump_switch = False
@@ -370,12 +393,13 @@ class Service:
          else:
              dump_data_dir = None

-         dump_file_path = os.path.join(dump_dir, "dump.json")
-         stack_file_path = os.path.join(dump_dir, "stack.json")
-         construct_file_path = os.path.join(dump_dir, "construct.json")
-         self.data_collector.update_dump_paths(
-             dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None
-         )
+         dump_path_aggregation = DumpPathAggregation()
+         dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
+         dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
+         dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
+         dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
+         self.data_collector.update_dump_paths(dump_path_aggregation)
+
          self.data_collector.initialize_json_file(
              framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
          )
@@ -394,13 +418,13 @@ class Service:
          def get_cell_or_module(model):
              return model.named_modules() if is_mindtorch() else model.cells_and_names()
-
+
          if isinstance(self.model, (list, tuple)):
              for index, model in enumerate(self.model):
                  cells_and_names_with_index[str(index)] = get_cell_or_module(model)
          else:
              cells_and_names_with_index["-1"] = get_cell_or_module(self.model)
-         return cells_and_names_with_index
+         return cells_and_names_with_index

      def register_primitive_hook(self):
          if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
@@ -430,7 +454,7 @@ class Service:
              if not self.model:
                  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                         f"The current level is {self.config.level}, the model cannot be None")
-             model_type = Const.MODULE if is_mindtorch() else Const.CELL
+             model_type = Const.MODULE if is_mindtorch() else Const.CELL
              cells_and_names_with_index = self.get_cells_and_names()

              for index, cells_and_names in cells_and_names_with_index.items():
@@ -439,7 +463,7 @@ class Service:
                  if cell == model:
                      continue
                  cell_index = (index + Const.SEP) if index != "-1" else ""
-                 prefix = (model_type + Const.SEP + cell_index + name +
+                 prefix = (model_type + Const.SEP + cell_index + name +
                            Const.SEP + cell.__class__.__name__ + Const.SEP)
                  _, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix)
                  cell.register_forward_hook(forward_hook)
@@ -456,10 +480,9 @@ class Service:
      def reset_status(self):
          self.primitive_hook_service.primitive_counters.clear()
-         self.data_collector.data_writer.reset_cache()
+         self.data_collector.reset_status()
          JitDump.jit_count = defaultdict(int)
          self.params_grad_info.clear()
-
          if self.config.level == Const.LEVEL_L2:
              self.data_collector.data_processor.reset_status()
              return
@@ -467,3 +490,54 @@ class Service:
              return
          if self.config.rank and self.current_rank not in self.config.rank:
              return
+
+     def init_for_debug_level(self):
+         if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]):
+             return
+         try:
+             self.current_rank = get_rank_if_initialized()
+         except DistributedNotInitializedError:
+             self.current_rank = None
+         # dir: dump_path -- rank{} -- debug.json
+         self.dump_iter_dir = self.config.dump_path
+         cur_rank = self.current_rank if self.current_rank is not None else ''
+         dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
+         create_directory(dump_dir)
+         if self.config.task in self.data_collector.tasks_need_tensor_data:
+             dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
+             create_directory(dump_data_dir)
+         else:
+             dump_data_dir = None
+
+         dump_path_aggregation = DumpPathAggregation()
+         dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
+         dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
+         self.data_collector.update_dump_paths(dump_path_aggregation)
+         self.data_collector.initialize_json_file(
+             framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
+         )
+         self.debug_variable_counter = defaultdict(int)
+
+     def save(self, variable, name, save_backward):
+         '''
+         Args:
+             variable: Union[List[variable], dict{str: variable}, mindspore.tensor, str, float, int]
+             name: str
+             save_backward: boolean
+         Return:
+             void
+         '''
+         if self.config.level != Const.LEVEL_DEBUG:
+             return
+         count = self.debug_variable_counter[name]
+         self.debug_variable_counter[name] += 1
+
+         name_with_count = f"{name}.{count}"
+         grad_name_with_count = f"{name}_grad.{count}"
+
+         # forward save
+         self.data_collector.debug_data_collect_forward(variable, name_with_count)
+
+         # backward save
+         if save_backward:
+             self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)
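The two methods above back the new debug-level save flow documented in the new 28.debugger_save_instruction.md: start/stop/step return immediately at that level, and each save call records the variable under an auto-incrementing name. A sketch of the intended call pattern, assuming the service is reached through msprobe.mindspore.PrecisionDebugger and that it forwards save() to Service.save:

    import numpy as np
    from mindspore import Tensor
    from msprobe.mindspore import PrecisionDebugger

    debugger = PrecisionDebugger(config_path="./config.json")   # config with "level": "debug"
    x = Tensor(np.ones((2, 3), dtype=np.float32))
    debugger.save(x, "attn_out", save_backward=True)   # recorded as attn_out.0 / attn_out_grad.0
    debugger.save(x, "attn_out", save_backward=True)   # counter advances: attn_out.1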
@@ -13,7 +13,6 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

-
  import torch
  from .compare.distributed_compare import compare_distributed
  from .compare.pt_compare import compare
@@ -399,7 +399,7 @@ class OperatorScriptGenerator:
      def generate_kwargs_dict(self, kwargs_info, flag_device):
          kwargs_dict_generator = ""
          for key, value in kwargs_info.items():
-             kwargs_dict_generator += '"' + key + '"' + MonitorConst.VPP_SEP
+             kwargs_dict_generator += '"' + key + '"' + MonitorConst.NAME_SEP
              if flag_device:
                  kwargs_dict_generator += self.recursive_kwargs_dict(value, flag_device=True) + Const.COMMA
              else: