mindstudio-probe 8.1.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +46 -47
  3. msprobe/core/common/const.py +1 -0
  4. msprobe/core/common/file_utils.py +36 -18
  5. msprobe/core/common/utils.py +19 -8
  6. msprobe/core/compare/acc_compare.py +14 -5
  7. msprobe/core/compare/utils.py +7 -1
  8. msprobe/core/data_dump/data_collector.py +144 -90
  9. msprobe/core/data_dump/json_writer.py +31 -1
  10. msprobe/core/debugger/precision_debugger.py +19 -18
  11. msprobe/core/service.py +1 -0
  12. msprobe/core/single_save/single_comparator.py +25 -25
  13. msprobe/core/single_save/single_saver.py +5 -16
  14. msprobe/docs/01.installation.md +1 -0
  15. msprobe/docs/05.data_dump_PyTorch.md +3 -0
  16. msprobe/docs/06.data_dump_MindSpore.md +3 -0
  17. msprobe/docs/08.accuracy_checker_online_PyTorch.md +2 -2
  18. msprobe/docs/25.tool_function_introduction.md +19 -19
  19. msprobe/docs/33.generate_operator_MindSpore.md +10 -19
  20. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -0
  21. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  22. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +10 -1
  23. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  24. msprobe/mindspore/common/utils.py +1 -0
  25. msprobe/mindspore/debugger/precision_debugger.py +4 -4
  26. msprobe/mindspore/dump/cell_dump_process.py +13 -38
  27. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +1 -26
  28. msprobe/mindspore/dump/hook_cell/api_register.py +3 -3
  29. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +4 -4
  30. msprobe/mindspore/mindspore_service.py +3 -0
  31. msprobe/mindspore/monitor/features.py +10 -9
  32. msprobe/mindspore/monitor/optimizer_collect.py +4 -1
  33. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  34. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +7 -7
  35. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -0
  36. msprobe/pytorch/common/utils.py +1 -1
  37. msprobe/pytorch/debugger/precision_debugger.py +28 -25
  38. msprobe/pytorch/hook_module/api_register.py +3 -3
  39. msprobe/pytorch/monitor/optimizer_collect.py +4 -1
  40. msprobe/pytorch/pytorch_service.py +3 -0
  41. msprobe/visualization/compare/mode_adapter.py +9 -0
  42. msprobe/visualization/utils.py +3 -0
  43. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +0 -9
  44. {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  45. {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  46. {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  47. {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
msprobe/core/data_dump/data_collector.py CHANGED
@@ -15,6 +15,7 @@
 
 import atexit
 import os
+import traceback
 
 from msprobe.core.data_dump.scope import ScopeFactory
 from msprobe.core.data_dump.json_writer import DataWriter
@@ -99,100 +100,150 @@ class DataCollector:
         self.data_writer.update_stack(name, stack_info)
 
     def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
-        if self.config.task == Const.FREE_BENCHMARK:
-            backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
-            if self.check_scope_and_pid(self.scope, backward_name, pid):
-                self.data_processor.analyze_forward_input(backward_name, module, module_input_output)
-            return
-
-        if not self.check_scope_and_pid(self.scope, name, pid):
-            return
-
-        data_info = {}
-        if self.config.task != Const.STRUCTURE:
-            data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
-            self.set_is_recomputable(data_info, is_recompute)
-        if self.config.level == Const.LEVEL_L2:
-            return
-        self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
+        try:
+
+            if self.config.task == Const.FREE_BENCHMARK:
+                backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
+                if self.check_scope_and_pid(self.scope, backward_name, pid):
+                    self.data_processor.analyze_forward_input(backward_name, module, module_input_output)
+                return
+
+            if not self.check_scope_and_pid(self.scope, name, pid):
+                return
+
+            data_info = {}
+            if self.config.task != Const.STRUCTURE:
+                data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
+                self.set_is_recomputable(data_info, is_recompute)
+            if self.config.level == Const.LEVEL_L2:
+                return
+            self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
+
+        except Exception:
+            tb = traceback.format_exc()
+            self.data_writer.write_error_log(
+                f"[ERROR] forward_input_data_collect failed: name={name}, pid={pid}\n{tb}"
+            )
 
     def forward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
-        self.update_construct(name)
-        if not self.check_scope_and_pid(self.scope, name, pid):
-            return
-
-        data_info = {}
-        if self.config.task != Const.STRUCTURE:
-            data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
-            self.set_is_recomputable(data_info, is_recompute)
-        if self.config.level == Const.LEVEL_L2:
-            return
-        self.call_stack_collect(name)
-        self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
+        try:
+
+            self.update_construct(name)
+            if not self.check_scope_and_pid(self.scope, name, pid):
+                return
+
+            data_info = {}
+            if self.config.task != Const.STRUCTURE:
+                data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
+                self.set_is_recomputable(data_info, is_recompute)
+            if self.config.level == Const.LEVEL_L2:
+                return
+            self.call_stack_collect(name)
+            self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
+
+        except Exception:
+            tb = traceback.format_exc()
+            self.data_writer.write_error_log(
+                f"[ERROR] forward_output_data_collect failed: name={name}, pid={pid}\n{tb}"
+            )
 
     def forward_data_collect_only_tensor(self, name, module, pid, module_input_output):
-        if not self.check_scope_and_pid(self.scope, name, pid):
-            return
-
-        self.data_processor.analyze_forward(name, module, module_input_output)
+        try:
+            if not self.check_scope_and_pid(self.scope, name, pid):
+                return
+            self.data_processor.analyze_forward(name, module, module_input_output)
 
+        except Exception:
+            tb = traceback.format_exc()
+            self.data_writer.write_error_log(
+                f"[ERROR] forward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}"
+            )
 
     def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
-        self.update_construct(name)
-        if not self.check_scope_and_pid(self.scope, name, pid):
-            return
-
-        data_info = {}
-        if self.config.task != Const.STRUCTURE:
-            data_info = self.data_processor.analyze_forward(name, module, module_input_output)
-            self.set_is_recomputable(data_info, is_recompute)
-        self.call_stack_collect(name)
-        self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
+        try:
+
+            self.update_construct(name)
+            if not self.check_scope_and_pid(self.scope, name, pid):
+                return
+            data_info = {}
+            if self.config.task != Const.STRUCTURE:
+                data_info = self.data_processor.analyze_forward(name, module, module_input_output)
+                self.set_is_recomputable(data_info, is_recompute)
+            self.call_stack_collect(name)
+            self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
+
+        except Exception:
+            tb = traceback.format_exc()
+            self.data_writer.write_error_log(
+                f"[ERROR] forward_data_collect failed: name={name}, pid={pid}\n{tb}"
+            )
 
     def backward_data_collect_only_tensor(self, name, module, pid, module_input_output, is_recompute=None):
-        if not self.check_scope_and_pid(self.scope, name, pid):
-            return
+        try:
+            if not self.check_scope_and_pid(self.scope, name, pid):
+                return
+            self.data_processor.analyze_backward(name, module, module_input_output)
 
-        self.data_processor.analyze_backward(name, module, module_input_output)
+        except Exception:
+            tb = traceback.format_exc()
+            self.data_writer.write_error_log(
+                f"[ERROR] backward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}"
+            )
 
     def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
-        self.update_construct(name)
-        if not self.check_scope_and_pid(self.scope, name, pid):
-            return
-
-        data_info = {}
-        if self.config.task != Const.STRUCTURE:
-            data_info = self.data_processor.analyze_backward(name, module, module_input_output)
-        if self.config.level == Const.LEVEL_L2:
-            return
-        # Get the name of the module that ran the backward pass
-        if data_info and name.split(Const.SEP)[0] in Const.MODULE_PREFIX:
-            module_name = name.rsplit(Const.SEP, 2)[0]
-            # Add the module name to the backward-module set, used during gradient collection to decide whether gradients need to be collected
-            self.backward_module_names[module_name] = True
-        self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
+        try:
+            self.update_construct(name)
+            if not self.check_scope_and_pid(self.scope, name, pid):
+                return
+            data_info = {}
+            if self.config.task != Const.STRUCTURE:
+                data_info = self.data_processor.analyze_backward(name, module, module_input_output)
+            if self.config.level == Const.LEVEL_L2:
+                return
+            if data_info and name.split(Const.SEP)[0] in Const.MODULE_PREFIX:
+                module_name = name.rsplit(Const.SEP, 2)[0]
+                self.backward_module_names[module_name] = True
+            self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
+
+        except Exception:
+            tb = traceback.format_exc()
+            self.data_writer.write_error_log(
+                f"[ERROR] backward_data_collect failed: name={name}, pid={pid}\n{tb}"
+            )
 
     def backward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
-        self.update_construct(name)
-        if not self.check_scope_and_pid(self.scope, name, pid):
-            return
-
-        data_info = {}
-        if self.config.task != Const.STRUCTURE:
-            data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
-            self.set_is_recomputable(data_info, is_recompute)
-        self.handle_data(name, data_info)
+        try:
+            self.update_construct(name)
+            if not self.check_scope_and_pid(self.scope, name, pid):
+                return
+            data_info = {}
+            if self.config.task != Const.STRUCTURE:
+                data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
+                self.set_is_recomputable(data_info, is_recompute)
+            self.handle_data(name, data_info)
+
+        except Exception:
+            tb = traceback.format_exc()
+            self.data_writer.write_error_log(
+                f"[ERROR] backward_input_data_collect failed: name={name}, pid={pid}\n{tb}"
+            )
 
     def backward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
-        self.update_construct(name)
-        if not self.check_scope_and_pid(self.scope, name, pid):
-            return
-
-        data_info = {}
-        if self.config.task != Const.STRUCTURE:
-            data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
-            self.set_is_recomputable(data_info, is_recompute)
-        self.handle_data(name, data_info)
+        try:
+            self.update_construct(name)
+            if not self.check_scope_and_pid(self.scope, name, pid):
+                return
+            data_info = {}
+            if self.config.task != Const.STRUCTURE:
+                data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
+                self.set_is_recomputable(data_info, is_recompute)
+            self.handle_data(name, data_info)
+
+        except Exception:
+            tb = traceback.format_exc()
+            self.data_writer.write_error_log(
+                f"[ERROR] backward_output_data_collect failed: name={name}, pid={pid}\n{tb}"
+            )
 
     def update_construct(self, name):
         if self.config.level not in DataCollector.level_without_construct:
@@ -228,20 +279,23 @@
         self.data_processor.update_iter(current_iter)
 
     def params_data_collect(self, name, param_name, pid, data):
-        grad_name = name + Const.SEP + Const.PARAMS_GRAD
-        self.update_api_or_module_name(grad_name)
-        # Check scope and pid, and whether this name has gone through a backward pass
-        if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
-            # If there was no backward pass, the placeholder grad data written earlier has to be removed
-            if self.data_writer.cache_data.get("data"):
-                self.data_writer.cache_data.get("data").pop(grad_name, None)
-            return
-        data_info = self.data_processor.analyze_params(grad_name, param_name, data)
-        self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
-
+        try:
+            grad_name = name + Const.SEP + Const.PARAMS_GRAD
+            self.update_api_or_module_name(grad_name)
+            if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
+                if self.data_writer.cache_data.get("data"):
+                    self.data_writer.cache_data.get("data").pop(grad_name, None)
+                return
+            data_info = self.data_processor.analyze_params(grad_name, param_name, data)
+            self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
+        except Exception:
+            tb = traceback.format_exc()
+            self.data_writer.write_error_log(
+                f"[ERROR] params_data_collect failed: "
+                f"name={name}, param_name={param_name}, pid={pid}\n{tb}"
+            )
 
     def debug_data_collect_forward(self, variable, name_with_count):
-
         data_info = self.data_processor.analyze_debug_forward(variable, name_with_count)
         name_with_count_category = name_with_count + Const.SEP + Const.DEBUG
         self.data_writer.update_debug({name_with_count_category: data_info})
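
In 8.1.1 every collector entry point above repeats the same wrapper: run the original body in a `try`, and on any exception format the traceback and hand it to `DataWriter.write_error_log`, so a dump failure no longer interrupts the training run. A decorator could express that pattern once; the sketch below is a hypothetical refactoring (the name `log_collect_errors` is not part of msprobe) and only assumes the instance exposes `self.data_writer.write_error_log`:

```python
import functools
import traceback


def log_collect_errors(func):
    """Run a collect method; log any exception instead of raising it."""
    @functools.wraps(func)
    def wrapper(self, name, *args, **kwargs):
        try:
            return func(self, name, *args, **kwargs)
        except Exception:
            tb = traceback.format_exc()
            self.data_writer.write_error_log(
                f"[ERROR] {func.__name__} failed: name={name}\n{tb}"
            )
            return None
    return wrapper


# Hypothetical usage on one of the methods above:
# @log_collect_errors
# def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
#     ...
```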
msprobe/core/data_dump/json_writer.py CHANGED
@@ -17,9 +17,11 @@ import csv
 import os
 import copy
 import threading
+import traceback
+from datetime import datetime, timezone, timedelta
 
 from msprobe.core.common.const import Const, FileCheckConst
-from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json
+from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json, check_path_before_create
 from msprobe.core.common.log import logger
 from msprobe.core.common.decorator import recursion_depth_decorator
 
@@ -35,6 +37,7 @@ class DataWriter:
         self.free_benchmark_file_path = None
         self.dump_tensor_data_dir = None
         self.debug_file_path = None
+        self.dump_error_info_path = None
        self.flush_size = 1000
         self.larger_flush_size = 20000
         self.cache_data = {}
@@ -42,6 +45,7 @@
         self.cache_construct = {}
         self.cache_debug = {}
         self.stat_stack_list = []
+        self._error_log_initialized = False
 
     @staticmethod
     def write_data_to_csv(result: list, result_header: tuple, file_path: str):
@@ -128,6 +132,7 @@
         self.dump_tensor_data_dir = dump_path_aggregation.dump_tensor_data_dir
         self.free_benchmark_file_path = dump_path_aggregation.free_benchmark_file_path
         self.debug_file_path = dump_path_aggregation.debug_file_path
+        self.dump_error_info_path = dump_path_aggregation.dump_error_info_path
 
     def flush_data_periodically(self):
         dump_data = self.cache_data.get(Const.DATA)
@@ -142,6 +147,31 @@
         if length % threshold == 0:
             self.write_json()
 
+    def write_error_log(self, message: str):
+        """
+        Write an error log entry:
+        - The first call opens the file in 'w' mode to truncate it; all later calls append in 'a' mode
+        - A timestamp is added
+        - The current call stack is written after the message (to make the log's origin easy to trace)
+        """
+        try:
+            mode = "w" if not self._error_log_initialized else "a"
+            self._error_log_initialized = True
+
+            check_path_before_create(self.dump_error_info_path)
+
+            with FileOpen(self.dump_error_info_path, mode) as f:
+                cst_timezone = timezone(timedelta(hours=8), name="CST")
+                timestamp = datetime.now(cst_timezone).strftime("%Y-%m-%d %H:%M:%S %z")
+                f.write(f"[{timestamp}] {message}\n")
+                f.write("Call stack (most recent call last):\n")
+
+                f.write("".join(traceback.format_stack()[:-1]))  # drop this frame itself
+                f.write("\n")
+        except Exception as e:
+            # If even writing the log fails, fall back to a logger warning
+            logger.warning(f"[FallbackError] Failed to write error log: {e}")
+
     def update_data(self, new_data):
         with lock:
             if not isinstance(new_data, dict) or len(new_data.keys()) != 1:
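
Taken on its own, `write_error_log` implements first-call-truncate, then-append semantics: the first call in a process opens `dump_error_info.log` with mode `'w'`, every later call appends, and each entry carries a UTC+8 ("CST") timestamp. A minimal standalone sketch of that logic, using plain `open()` in place of msprobe's `FileOpen` wrapper and omitting the call-stack dump:

```python
from datetime import datetime, timezone, timedelta

_initialized = False  # module-level stand-in for DataWriter._error_log_initialized


def write_error_log(path: str, message: str) -> None:
    """First call truncates the log file; every later call appends."""
    global _initialized
    mode = "w" if not _initialized else "a"
    _initialized = True
    cst = timezone(timedelta(hours=8), name="CST")
    stamp = datetime.now(cst).strftime("%Y-%m-%d %H:%M:%S %z")
    with open(path, mode) as f:
        f.write(f"[{stamp}] {message}\n")


write_error_log("dump_error_info.log", "first entry")   # truncates any old log
write_error_log("dump_error_info.log", "second entry")  # appends
```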
msprobe/core/debugger/precision_debugger.py CHANGED
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import os
 
 from msprobe.core.common.const import Const, FileCheckConst, MsgConst
@@ -46,18 +47,14 @@ class BasePrecisionDebugger:
         if self.initialized:
             return
         self.initialized = True
-        self.check_input_params(config_path, task, dump_path, level)
-        self.common_config, self.task_config = self.parse_config_path(config_path, task)
+        self._check_input_params(config_path, task, dump_path, level)
+        self.common_config, self.task_config = self._parse_config_path(config_path, task)
         self.task = self.common_config.task
         if step is not None:
             self.common_config.step = get_real_step_or_rank(step, Const.STEP)
 
     @staticmethod
-    def get_task_config(task, json_config):
-        raise NotImplementedError("Subclass must implment get_task_config")
-
-    @staticmethod
-    def check_input_params(config_path, task, dump_path, level):
+    def _check_input_params(config_path, task, dump_path, level):
         if not config_path:
             config_path = os.path.join(os.path.dirname(__file__), "../../config.json")
         if config_path is not None:
@@ -81,14 +78,9 @@
             raise MsprobeException(
                 MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
 
-    @classmethod
-    def get_instance(cls):
-        instance = cls._instance
-        if not instance:
-            raise Exception(MsgConst.NOT_CREATED_INSTANCE)
-        if instance.task in BasePrecisionDebugger.tasks_not_need_debugger:
-            instance = None
-        return instance
+    @staticmethod
+    def _get_task_config(task, json_config):
+        raise NotImplementedError("Subclass must implement _get_task_config")
 
     @classmethod
     def forward_backward_dump_end(cls):
@@ -129,15 +121,24 @@
             raise Exception(MsgConst.NOT_CREATED_INSTANCE)
         instance.service.restore_custom_api(module, api)
 
-    def parse_config_path(self, json_file_path, task):
+    @classmethod
+    def _get_instance(cls):
+        instance = cls._instance
+        if not instance:
+            raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+        if instance.task in BasePrecisionDebugger.tasks_not_need_debugger:
+            instance = None
+        return instance
+
+    def _parse_config_path(self, json_file_path, task):
         if not json_file_path:
             json_file_path = os.path.join(os.path.dirname(__file__), "../../config.json")
         json_config = load_json(json_file_path)
         common_config = CommonConfig(json_config)
         if task:
-            task_config = self.get_task_config(task, json_config)
+            task_config = self._get_task_config(task, json_config)
         else:
             if not common_config.task:
                 common_config.task = Const.STATISTICS
-            task_config = self.get_task_config(common_config.task, json_config)
+            task_config = self._get_task_config(common_config.task, json_config)
         return common_config, task_config
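
The net effect in this file is a rename of the overridable hooks (`get_task_config` → `_get_task_config`, `check_input_params` → `_check_input_params`, `get_instance` → `_get_instance`, `parse_config_path` → `_parse_config_path`), so framework-specific debuggers must now override the underscore-prefixed names. A hedged sketch of a conforming subclass (the class name and the config lookup are illustrative, and it assumes msprobe is installed):

```python
from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger


class DemoPrecisionDebugger(BasePrecisionDebugger):
    @staticmethod
    def _get_task_config(task, json_config):
        # Illustrative only: return the task-specific sub-config,
        # falling back to an empty dict when the key is absent.
        return json_config.get(task, {})
```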
msprobe/core/service.py CHANGED
@@ -331,6 +331,7 @@ class BaseService(ABC):
         dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
         dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
         dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
+        dump_path_aggregation.dump_error_info_path = os.path.join(dump_dir, "dump_error_info.log")
         dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
         dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
         dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv")
msprobe/core/single_save/single_comparator.py CHANGED
@@ -181,32 +181,32 @@ class SingleComparator:
 
     @classmethod
     def compare_single_tag(cls, tag, array_paths1, array_paths2, output_dir):
-        try:
-            data = []
-            paths1 = array_paths1.get(tag, [])
-            paths2 = array_paths2.get(tag, [])
-            path_dict1 = {(step, rank, micro_step, array_id): path for step, rank, micro_step, array_id, path in paths1}
-            path_dict2 = {(step, rank, micro_step, array_id): path for step, rank, micro_step, array_id, path in paths2}
-            common_keys = set(path_dict1.keys()) & set(path_dict2.keys())
-            for key in common_keys:
-                try:
-                    array1 = np.load(path_dict1[key])
-                    array2 = np.load(path_dict2[key])
-                    result = cls.compare_arrays(array1, array2)
-                    step, rank, micro_step, array_id = key
-                    data.append([
-                        step, rank, micro_step, array_id,
-                        list(array1.shape), list(array2.shape),
-                        result.same_percentage,
-                        result.first_mismatch_index,
-                        result.max_abs_error,
-                        result.max_relative_error,
-                        result.percentage_within_thousandth,
-                        result.percentage_within_hundredth
-                    ])
-                except Exception as e:
-                    logger.error(f"Error comparing {path_dict1[key]} and {path_dict2[key]}: {e}")
+        data = []
+        paths1 = array_paths1.get(tag, [])
+        paths2 = array_paths2.get(tag, [])
+        path_dict1 = {(step, rank, micro_step, array_id): path for step, rank, micro_step, array_id, path in paths1}
+        path_dict2 = {(step, rank, micro_step, array_id): path for step, rank, micro_step, array_id, path in paths2}
+        common_keys = set(path_dict1.keys()) & set(path_dict2.keys())
+        for key in common_keys:
+            try:
+                array1 = np.load(path_dict1[key])
+                array2 = np.load(path_dict2[key])
+                result = cls.compare_arrays(array1, array2)
+                step, rank, micro_step, array_id = key
+                data.append([
+                    step, rank, micro_step, array_id,
+                    list(array1.shape), list(array2.shape),
+                    result.same_percentage,
+                    result.first_mismatch_index,
+                    result.max_abs_error,
+                    result.max_relative_error,
+                    result.percentage_within_thousandth,
+                    result.percentage_within_hundredth
+                ])
+            except Exception as e:
+                logger.error(f"Error comparing {path_dict1[key]} and {path_dict2[key]}: {e}")
 
+        try:
             df = pd.DataFrame(data, columns=SingleComparator.result_header)
             df = df.sort_values(by=['step', 'rank', 'micro_step', 'id'])
             # Build the full output file path
msprobe/core/single_save/single_saver.py CHANGED
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import os
+from collections import defaultdict
 
 from msprobe.core.common.file_utils import create_directory, save_json
 from msprobe.core.common.const import Const
@@ -36,7 +37,7 @@ class SingleSave:
             cls._instance.dump_path = dump_path
             cls._instance.rank = FmkAdp.get_rank_id()
             cls._instance.step_count = 0
-            cls._instance.cache_dict = {}
+            cls._instance.tag_count = defaultdict(int)
         return cls._instance
 
     @staticmethod
@@ -109,13 +110,7 @@
     @classmethod
     def step(cls):
         instance = cls._instance
-        for key, value in instance.cache_dict.items():
-            if not value["have_micro_batch"]:
-                cls.save_ex({key: value["data"][0]})
-            else:
-                for i, data in enumerate(value["data"]):
-                    cls.save_ex({key: data}, micro_batch=i)
-        instance.cache_dict = {}
+        instance.tag_count = defaultdict(int)
         instance.step_count += 1
 
     @classmethod
@@ -127,14 +122,8 @@
                 "Skip current save process.")
             return
         for key, value in data.items():
-            if key not in instance.cache_dict:
-                instance.cache_dict[key] = {
-                    "have_micro_batch": False,
-                    "data": [value]
-                }
-            else:
-                instance.cache_dict[key]["have_micro_batch"] = True
-                instance.cache_dict[key]["data"].append(value)
+            cls.save_ex({key: value}, micro_batch=instance.tag_count[key])
+            instance.tag_count[key] += 1
 
     @classmethod
     def _analyze_list_tuple_data(cls, data, data_name=None, save_dir=None):
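
This rewrite removes the `cache_dict` buffering entirely: `save` no longer accumulates values until `step()` replays them, it writes each tag immediately and derives the `micro_batch` index from a per-step `defaultdict(int)` counter that `step()` resets. A standalone sketch of the counting behavior (here `save_ex` is a stand-in that only prints what msprobe would write to disk):

```python
from collections import defaultdict

tag_count = defaultdict(int)  # per-step occurrence counter, keyed by tag


def save_ex(data, micro_batch=0):
    print(f"save tag={list(data)[0]} micro_batch={micro_batch}")


def save(data):
    for key, value in data.items():
        save_ex({key: value}, micro_batch=tag_count[key])
        tag_count[key] += 1


def step():
    tag_count.clear()  # msprobe re-binds a fresh defaultdict(int) instead


save({"loss": 0.5})  # -> micro_batch=0
save({"loss": 0.4})  # -> micro_batch=1 (same tag, same step)
step()
save({"loss": 0.3})  # -> micro_batch=0 again after step()
```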
msprobe/docs/01.installation.md CHANGED
@@ -16,6 +16,7 @@ pip install mindstudio-probe
 
 |Version|Release Date|Supported PyTorch Versions|Supported MindSpore Versions|Download Link|Checksum|
 |:--:|:--:|:--:|:--:|:--:|:--:|
+|8.1.0|2025.6.14|1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0|[mindstudio_probe-8.1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.1/mindstudio_probe-8.1.0-py3-none-any.whl)|d10c0a57d073bbe7c681042a11e93a0eaaaf5aa45e1cec997142ce2593d77afd|
 |8.0.0|2025.5.07|1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0|[mindstudio_probe-8.0.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.0/mindstudio_probe-8.0.0-py3-none-any.whl)|6810eade7ae99e3b24657d5cab251119882decd791aa76a7aeeb94dea767daec|
 |1.3.0|2025.4.17|1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0|[mindstudio_probe-1.3.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.3/mindstudio_probe-1.3.0-py3-none-any.whl)|85dbc5518b5c23d29c67d7b85d662517d0318352f372891f8d91e73e71b439c3|
 |1.2.2|2025.3.03|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.2-py3-none-any.whl)|961411bb460d327ea51d6ca4d0c8e8c5565f07c0852d7b8592b781ca35b87212|
msprobe/docs/05.data_dump_PyTorch.md CHANGED
@@ -471,12 +471,14 @@ debugger.step()
 | | | |      # When the model argument passed at dump time is List[torch.nn.Module] or Tuple[torch.nn.Module], module-level data names include the module's index in the list, i.e. {Module}.{index}.*, where * stands for the three module-level naming formats above, e.g. Module.0.conv1.Conv2d.forward.0.input.0.pt.
 │ | | ├── dump.json
 │ | | ├── stack.json
+│ | | ├── dump_error_info.log
 │ | | └── construct.json
 │ | ├── rank1
 | | | ├── dump_tensor_data
 | | | | └── ...
 │ | | ├── dump.json
 │ | | ├── stack.json
+│ | | ├── dump_error_info.log
 | | | └── construct.json
 │ | ├── ...
 │ | |
@@ -488,6 +490,7 @@
 * `rank`: device ID; each card's data is saved in the corresponding `rank{ID}` directory. In non-distributed scenarios there is no rank ID and the directory is simply named rank.
 * `dump_tensor_data`: holds the collected tensor data.
 * `dump.json`: stores statistics for the forward/backward data of APIs or Modules, including the API or Module name of each dumped item and its dtype, shape, max, min, mean and L2norm (L2 norm, square root) statistics, plus CRC-32 values when summary_mode="md5" is configured. See the [dump.json file description](./27.dump_json_instruction.md#1-PyTorch场景下的dump.json文件) for details.
+* `dump_error_info.log`: created only when the dump tool reports an error; records the dump error logs.
 * `stack.json`: call-stack information for the APIs/Modules.
 * `construct.json`: the hierarchical structure; when level is L1, construct.json is empty.
 
msprobe/docs/06.data_dump_MindSpore.md CHANGED
@@ -496,12 +496,14 @@ An example dump result directory structure is as follows:
 | | | |      # When the model argument passed at dump time is List[mindspore.nn.Cell] or Tuple[mindspore.nn.Cell], module-level data names include the module's index in the list, i.e. {Cell}.{index}.*, where * stands for the three module-level naming formats above, e.g. Cell.0.relu.ReLU.forward.0.input.0.npy.
 │ | | ├── dump.json
 │ | | ├── stack.json
+│ | | ├── dump_error_info.log
 │ | | └── construct.json
 │ | ├── rank1
 | | | ├── dump_tensor_data
 | | | | └── ...
 │ | | ├── dump.json
 │ | | ├── stack.json
+│ | | ├── dump_error_info.log
 | | | └── construct.json
 │ | ├── ...
 │ | |
@@ -514,6 +516,7 @@
 * `rank`: device ID; each card's data is saved in the corresponding `rank{ID}` directory. In non-distributed scenarios there is no rank ID and the directory is simply named rank.
 * `dump_tensor_data`: holds the collected tensor data.
 * `dump.json`: stores statistics for the forward/backward data of APIs or Cells, including the API or Cell name of each dumped item and its dtype, shape, max, min, mean and L2norm (L2 norm, square root) statistics, plus CRC-32 values when summary_mode="md5" is configured. See the [dump.json file description](./27.dump_json_instruction.md#2-mindspore-场景下的-dumpjson-文件) for details.
+* `dump_error_info.log`: created only when the dump tool reports an error; records the dump error logs.
 * `stack.json`: call-stack information for the APIs/Cells.
 * `construct.json`: the hierarchical structure; when level is L1, construct.json is empty.
 
msprobe/docs/08.accuracy_checker_online_PyTorch.md CHANGED
@@ -88,7 +88,7 @@ extendedKeyUsage = serverAuth
 EOF
 )
 
-# Generate the server key pair; server_password
+# Generate the server key pair; server_password is the private-key passphrase, shown for demonstration only and should be replaced before use
 openssl genrsa -aes256 -passout pass:server_password -out server.key 3072
 # Generate a certificate signing request from the server key pair
 openssl req -new -key server.key -passin pass:server_password -subj "/CN=*example.com/O=Test, Inc./C=CN/ST=Zhejiang/L=Hangzhou" -out server.csr
@@ -115,7 +115,7 @@ default_ca = CA_default
 database = ./index.txt
 default_md = sha256
 
-# Revoke the certificate client.crt
+# Revoke the certificate client.crt; ca_password is the CA private-key passphrase and must match the one used when the CA was created
 openssl ca -revoke client.crt -config ca.cnf -cert ca.crt -keyfile ca.key -passin pass:ca_password
 # Generate the CRL file
 openssl ca -gencrl -config ca.cnf -cert ca.crt -keyfile ca.key -passin pass:ca_password -out crl.pem -crldays 30