mindstudio-probe 8.2.0__py3-none-any.whl → 8.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/RECORD +90 -79
  3. msprobe/README.md +7 -5
  4. msprobe/core/common/const.py +6 -0
  5. msprobe/core/common/db_manager.py +35 -4
  6. msprobe/core/common/file_utils.py +105 -27
  7. msprobe/core/common/framework_adapter.py +7 -6
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/utils.py +14 -3
  10. msprobe/core/compare/find_first/analyzer.py +8 -7
  11. msprobe/core/compare/find_first/graph.py +11 -3
  12. msprobe/core/compare/find_first/utils.py +2 -1
  13. msprobe/core/compare/highlight.py +13 -6
  14. msprobe/core/compare/multiprocessing_compute.py +17 -10
  15. msprobe/core/compare/utils.py +14 -5
  16. msprobe/core/data_dump/data_collector.py +18 -21
  17. msprobe/core/data_dump/data_processor/pytorch_processor.py +43 -20
  18. msprobe/core/data_dump/json_writer.py +18 -8
  19. msprobe/core/data_dump/scope.py +4 -6
  20. msprobe/core/hook_manager.py +37 -3
  21. msprobe/core/service.py +18 -5
  22. msprobe/core/single_save/single_comparator.py +16 -3
  23. msprobe/docs/01.installation.md +7 -5
  24. msprobe/docs/02.config_introduction.md +14 -1
  25. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  26. msprobe/docs/06.data_dump_MindSpore.md +1 -1
  27. msprobe/docs/08.accuracy_checker_online_PyTorch.md +295 -0
  28. msprobe/docs/10.accuracy_compare_PyTorch.md +46 -5
  29. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  30. msprobe/docs/15.free_benchmarking_PyTorch.md +1 -1
  31. msprobe/docs/19.monitor.md +2 -0
  32. msprobe/docs/21.visualization_PyTorch.md +15 -80
  33. msprobe/docs/22.visualization_MindSpore.md +20 -104
  34. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  35. msprobe/docs/25.tool_function_introduction.md +1 -0
  36. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  37. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  38. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  39. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  40. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  41. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  42. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  43. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  44. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  45. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  46. msprobe/mindspore/cell_processor.py +33 -5
  47. msprobe/mindspore/compare/common_dir_compare.py +22 -26
  48. msprobe/mindspore/compare/utils.py +1 -2
  49. msprobe/mindspore/debugger/precision_debugger.py +1 -1
  50. msprobe/mindspore/dump/cell_dump_process.py +73 -62
  51. msprobe/mindspore/dump/graph_mode_cell_dump.py +21 -10
  52. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +2 -0
  53. msprobe/msprobe.py +6 -4
  54. msprobe/pytorch/api_accuracy_checker/common/config.py +36 -3
  55. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +24 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare.py +12 -2
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +6 -1
  58. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  59. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +132 -12
  60. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  61. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +205 -0
  62. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +378 -0
  63. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +239 -0
  64. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  65. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +250 -0
  66. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  67. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +198 -0
  68. msprobe/pytorch/attl_manager.py +65 -0
  69. msprobe/pytorch/common/utils.py +22 -2
  70. msprobe/pytorch/compare/utils.py +3 -3
  71. msprobe/pytorch/debugger/debugger_config.py +10 -0
  72. msprobe/pytorch/dump/module_dump/hook_wrapper.py +34 -7
  73. msprobe/pytorch/dump/module_dump/module_processer.py +23 -10
  74. msprobe/pytorch/hook_module/api_register.py +6 -1
  75. msprobe/pytorch/monitor/module_hook.py +28 -9
  76. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  77. msprobe/pytorch/pt_config.py +57 -2
  78. msprobe/pytorch/pytorch_service.py +11 -2
  79. msprobe/visualization/builder/graph_builder.py +170 -64
  80. msprobe/visualization/builder/graph_merger.py +0 -1
  81. msprobe/visualization/builder/msprobe_adapter.py +1 -1
  82. msprobe/visualization/db_utils.py +25 -2
  83. msprobe/visualization/graph/base_node.py +0 -24
  84. msprobe/visualization/graph/graph.py +5 -14
  85. msprobe/visualization/graph_service.py +29 -53
  86. msprobe/visualization/utils.py +11 -1
  87. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/LICENSE +0 -0
  88. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/WHEEL +0 -0
  89. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/entry_points.txt +0 -0
  90. {mindstudio_probe-8.2.0.dist-info → mindstudio_probe-8.3.0.dist-info}/top_level.txt +0 -0
@@ -23,6 +23,7 @@ from msprobe.core.data_dump.json_writer import DataWriter
23
23
  from msprobe.core.common.log import logger
24
24
  from msprobe.core.common.const import Const
25
25
  from msprobe.core.data_dump.data_processor.factory import DataProcessorFactory
26
+ from msprobe.core.common.megatron_utils import MegatronStepInfo, get_micro_step, is_megatron
26
27
 
27
28
 
28
29
  def build_data_collector(config):
@@ -270,15 +271,20 @@ class DataCollector:
270
271
  if self.config.level not in DataCollector.level_without_construct:
271
272
  if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
272
273
  if self.optimizer_status_first_start[self.optimizer_status]:
273
- self.data_writer.update_construct({self.optimizer_status: None})
274
+ self.data_writer.update_construct(
275
+ {self.optimizer_status: None if not is_megatron() else [None, get_micro_step()]})
274
276
  self.optimizer_status_first_start[self.optimizer_status] = False
275
- self.data_writer.update_construct({name: self.optimizer_status})
277
+ self.data_writer.update_construct(
278
+ {name: self.optimizer_status if not is_megatron() else [self.optimizer_status, get_micro_step()]})
276
279
  else:
277
280
  if self.config.level == Const.LEVEL_MIX and \
278
281
  not (name.startswith(Const.MODULE) or name.startswith(Const.CELL)):
279
282
  self.data_writer.update_construct(
280
283
  {name: self.module_processor.api_parent_node.get(threading.get_ident())}
281
284
  )
285
+ if MegatronStepInfo.is_megatron:
286
+ micro_step_number = max(MegatronStepInfo.forward_micro_step, MegatronStepInfo.backward_micro_step)
287
+ self.data_writer.update_construct({Const.MEGATRON_MICRO_STEP_NUMBER: micro_step_number})
282
288
 
283
289
  self.data_writer.update_construct(self.module_processor.module_node)
284
290
 
@@ -302,25 +308,16 @@ class DataCollector:
302
308
  self.data_processor.update_iter(current_iter)
303
309
 
304
310
  def params_data_collect(self, name, param_name, pid, data):
305
- try:
306
- grad_name = name + Const.SEP + Const.PARAMS_GRAD
307
- self.update_api_or_module_name(grad_name)
308
- if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
309
- if self.data_writer.cache_data.get("data"):
310
- self.data_writer.cache_data.get("data").pop(grad_name, None)
311
- self.params_grad_record[grad_name] = False
312
- return
313
- data_info = self.data_processor.analyze_params(grad_name, param_name, data)
314
- self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
315
- self.params_grad_record[grad_name] = False
316
- except Exception as e:
317
- error_type = type(e).__name__
318
- tb = traceback.format_exc()
319
- self.data_writer.write_error_log(
320
- f"[ERROR] params_data_collect failed: "
321
- f"name={name}, param_name={param_name}, pid={pid}\n{tb}",
322
- error_type=error_type
323
- )
311
+ grad_name = name + Const.SEP + Const.PARAMS_GRAD
312
+ self.update_api_or_module_name(grad_name)
313
+ if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
314
+ if self.data_writer.cache_data.get("data"):
315
+ self.data_writer.cache_data.get("data").pop(grad_name, None)
316
+ self.params_grad_record[grad_name] = False
317
+ return
318
+ data_info = self.data_processor.analyze_params(grad_name, param_name, data)
319
+ self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
320
+ self.params_grad_record[grad_name] = False
324
321
 
325
322
  def params_data_collect_in_bw_hook(self, params_dict, name):
326
323
  try:
@@ -13,13 +13,13 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import ctypes
16
17
  import os
17
18
  import zlib
18
- import ctypes
19
19
  from collections.abc import Iterable
20
+ from concurrent.futures import ThreadPoolExecutor
20
21
  from dataclasses import asdict
21
22
  from typing import List
22
- from concurrent.futures import ThreadPoolExecutor
23
23
 
24
24
  import numpy as np
25
25
  import torch
@@ -29,7 +29,6 @@ from torch.distributed.distributed_c10d import _get_default_group
29
29
  from msprobe.core.common.const import Const
30
30
  from msprobe.core.common.decorator import recursion_depth_decorator
31
31
  from msprobe.core.common.exceptions import MsprobeException
32
- from msprobe.core.common.file_utils import path_len_exceeds_limit
33
32
  from msprobe.core.common.log import logger
34
33
  from msprobe.core.common.utils import convert_tuple, is_int
35
34
  from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
@@ -48,15 +47,28 @@ class TensorHandler:
48
47
  def __init__(self):
49
48
  self.has_dtensor = hasattr(dist, "tensor") and hasattr(dist.tensor, "DTensor")
50
49
  self.has_fake_tensor = hasattr(torch, "_subclasses") and hasattr(torch._subclasses, "fake_tensor")
50
+ self.has_async_collective_tensor = hasattr(dist, "_functional_collectives") and \
51
+ hasattr(dist._functional_collectives, "AsyncCollectiveTensor")
52
+
53
+ @staticmethod
54
+ def free_tensor(tensor, tensor_name):
55
+ try:
56
+ tensor.untyped_storage().resize_(0)
57
+ except Exception as e:
58
+ logger.warning(f"Failed to free tensor: {tensor_name}, the detail info: {e}.")
51
59
 
52
60
  def is_dtensor(self, tensor):
53
- return self.has_dtensor and isinstance(tensor, torch.distributed.tensor.DTensor)
61
+ return self.has_dtensor and isinstance(tensor, dist.tensor.DTensor)
54
62
 
55
63
  def is_fake_tensor(self, tensor):
56
64
  return self.has_fake_tensor and isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor)
57
65
 
66
+ def is_async_collective_tensor(self, tensor):
67
+ return self.has_async_collective_tensor and \
68
+ isinstance(tensor, dist._functional_collectives.AsyncCollectiveTensor)
69
+
58
70
  def is_empty_data(self, tensor):
59
- return tensor.is_meta or self.is_fake_tensor(tensor)
71
+ return tensor.is_meta or self.is_fake_tensor(tensor) or self.is_async_collective_tensor(tensor)
60
72
 
61
73
  def convert_common_tensor(self, tensor):
62
74
  if self.is_dtensor(tensor):
@@ -71,6 +83,8 @@ class TensorHandler:
71
83
  return Const.DTENSOR_TYPE
72
84
  if self.is_fake_tensor(tensor):
73
85
  return Const.FAKE_TENSOR_TYPE
86
+ if self.is_async_collective_tensor(tensor):
87
+ return Const.AC_TENSOR_TYPE
74
88
  return Const.TENSOR_TYPE
75
89
 
76
90
  def get_dtensor_info(self, tensor):
@@ -94,6 +108,18 @@ class TensorHandler:
94
108
  dtensor_info.update({"placements": placements})
95
109
  return dtensor_info
96
110
 
111
+ def save_tensor(self, tensor, file_path):
112
+ common_tensor = self.convert_common_tensor(tensor)
113
+ if self.is_empty_data(common_tensor):
114
+ logger.debug(f"Saving fake tensor or meta tensor is not supported, the current tensor is {file_path}.")
115
+ return
116
+ if common_tensor.untyped_storage().data_ptr() == 0:
117
+ logger.debug(f"Saving null-pointer tensor is not supported, the current tensor is {file_path}.")
118
+ return
119
+ saved_tensor = common_tensor.clone().contiguous().detach()
120
+ save_pt(saved_tensor, file_path)
121
+ self.free_tensor(saved_tensor, file_path)
122
+
97
123
 
98
124
  class PytorchDataProcessor(BaseDataProcessor):
99
125
  pytorch_special_type = (
@@ -288,7 +314,7 @@ class PytorchDataProcessor(BaseDataProcessor):
288
314
 
289
315
  def dump_async_data(self):
290
316
  for file_path, tensor in self._async_dump_cache.items():
291
- save_pt(tensor.contiguous(), file_path)
317
+ self.tensor_handler.save_tensor(tensor, file_path)
292
318
  self._async_dump_cache.clear()
293
319
 
294
320
  def analyze_single_element(self, element, suffix_stack):
@@ -385,24 +411,24 @@ class PytorchDataProcessor(BaseDataProcessor):
385
411
  def _analyze_and_save_tensor(self, tensor, suffix):
386
412
  dump_data_name, file_path = self.get_save_file_path(suffix)
387
413
  single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix)
388
- if self.tensor_handler.is_empty_data(tensor) or tensor.untyped_storage().data_ptr() == 0:
389
- logger.debug(
390
- "Collecting real data of fake tensor or meta tensor is not supported or data_ptr is 0, "
391
- f"the current api/module name is {self.current_api_or_module_name}."
392
- )
414
+ common_tensor = self.tensor_handler.convert_common_tensor(tensor)
415
+ if self.tensor_handler.is_empty_data(common_tensor):
416
+ logger.debug(f"Saving fake tensor or meta tensor is not supported, the current tensor is {file_path}.")
417
+ return single_arg
418
+ if common_tensor.untyped_storage().data_ptr() == 0:
419
+ logger.debug(f"Saving null-pointer tensor is not supported, the current tensor is {file_path}.")
393
420
  return single_arg
394
421
 
395
422
  single_arg.update({"data_name": dump_data_name})
396
423
  if self.config.async_dump:
397
- self._async_dump_cache[file_path] = tensor.clone().detach()
424
+ self._async_dump_cache[file_path] = common_tensor.clone().detach()
398
425
  else:
399
- saved_tensor = tensor.clone().contiguous().detach()
400
- save_pt(saved_tensor, file_path)
426
+ self.tensor_handler.save_tensor(common_tensor, file_path)
401
427
  return single_arg
402
428
 
403
429
  def _analyze_and_save_ndarray(self, ndarray, suffix):
404
430
  dump_data_name, file_path = self.get_save_file_path(suffix)
405
- save_pt(torch.tensor(ndarray), file_path)
431
+ self.tensor_handler.save_tensor(torch.tensor(ndarray), file_path)
406
432
  ndarray_json = PytorchDataProcessor._analyze_ndarray(ndarray, suffix)
407
433
  ndarray_json.update({"data_name": dump_data_name})
408
434
  return ndarray_json
@@ -493,7 +519,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
493
519
  self._analyze_maybe_overflow_flag()
494
520
  if self.has_overflow:
495
521
  for file_path, tensor in self.cached_tensors_and_file_paths.items():
496
- save_pt(tensor.clone().contiguous().detach(), file_path)
522
+ self.tensor_handler.save_tensor(tensor, file_path)
497
523
  self.real_overflow_nums += 1
498
524
  if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
499
525
  logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, "
@@ -538,10 +564,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
538
564
 
539
565
  def _analyze_tensor(self, tensor, suffix):
540
566
  dump_data_name, file_path = self.get_save_file_path(suffix)
541
- if not path_len_exceeds_limit(file_path):
542
- self.cached_tensors_and_file_paths.update({file_path: tensor})
543
- else:
544
- logger.warning(f'The file path {file_path} length exceeds limit.')
567
+ self.cached_tensors_and_file_paths.update({file_path: tensor})
545
568
  single_arg = super()._analyze_tensor(tensor, suffix)
546
569
  single_arg.update({"data_name": dump_data_name})
547
570
  if not self.has_overflow and self.support_inf_nan:
@@ -13,18 +13,18 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import concurrent
17
+ import copy
16
18
  import csv
17
19
  import os
18
- import copy
19
20
  import threading
20
21
  import traceback
21
22
  from datetime import datetime, timezone, timedelta
22
23
 
23
- import concurrent
24
24
  from msprobe.core.common.const import Const, FileCheckConst
25
- from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json, check_path_before_create
26
- from msprobe.core.common.log import logger
27
25
  from msprobe.core.common.decorator import recursion_depth_decorator
26
+ from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, check_path_before_create
27
+ from msprobe.core.common.log import logger
28
28
 
29
29
  lock = threading.Lock()
30
30
 
@@ -40,6 +40,7 @@ class DataWriter:
40
40
  self.debug_file_path = None
41
41
  self.dump_error_info_path = None
42
42
  self.flush_size = 1000
43
+ self.md5_flush_size = 5000
43
44
  self.larger_flush_size = 20000
44
45
  self.cache_data = {}
45
46
  self.cache_stack = {}
@@ -49,6 +50,7 @@ class DataWriter:
49
50
  self._error_log_initialized = False
50
51
  self._cache_logged_error_types = set()
51
52
  self.crc32_stack_list = []
53
+ self.data_updated = False
52
54
 
53
55
  @staticmethod
54
56
  def write_data_to_csv(result: list, result_header: tuple, file_path: str):
@@ -60,7 +62,7 @@ class DataWriter:
60
62
  spawn_writer = csv.writer(csv_file)
61
63
  if not is_exists:
62
64
  spawn_writer.writerow(result_header)
63
- spawn_writer.writerows([result,])
65
+ spawn_writer.writerows([result, ])
64
66
  is_new_file = not is_exists
65
67
  if is_new_file:
66
68
  change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -190,7 +192,7 @@ class DataWriter:
190
192
  summary_mode = getattr(cfg, "summary_mode", None)
191
193
 
192
194
  if summary_mode == Const.MD5:
193
- threshold = self.flush_size
195
+ threshold = self.md5_flush_size
194
196
  else:
195
197
  threshold = self.flush_size if length < self.larger_flush_size else self.larger_flush_size
196
198
 
@@ -238,6 +240,7 @@ class DataWriter:
238
240
  logger.warning(f"The dump data({dump_data}) should be a dict.")
239
241
  return
240
242
 
243
+ self.data_updated = True
241
244
  key = next(iter(new_data.keys()))
242
245
  if key in dump_data:
243
246
  dump_data.get(key).update(new_data.get(key))
@@ -246,6 +249,7 @@ class DataWriter:
246
249
 
247
250
  def update_stack(self, name, stack_data):
248
251
  with lock:
252
+ self.data_updated = True
249
253
  api_list = self.cache_stack.get(stack_data)
250
254
  if api_list is None:
251
255
  self.cache_stack.update({stack_data: [name]})
@@ -254,10 +258,12 @@ class DataWriter:
254
258
 
255
259
  def update_construct(self, new_data):
256
260
  with lock:
261
+ self.data_updated = True
257
262
  self.cache_construct.update(new_data)
258
263
 
259
264
  def update_debug(self, new_data):
260
265
  with lock:
266
+ self.data_updated = True
261
267
  self.cache_debug['data'].update(new_data)
262
268
 
263
269
  def write_data_json(self, file_path):
@@ -324,17 +330,21 @@ class DataWriter:
324
330
  stat_result = self.flush_stat_stack()
325
331
  # 遍历 cache_data,将占位符替换为最终统计值
326
332
  if stat_result:
333
+ self.data_updated = True
327
334
  self._replace_stat_placeholders(self.cache_data, stat_result)
328
335
  if self.cache_debug:
329
336
  self._replace_stat_placeholders(self.cache_debug, stat_result)
330
337
 
331
- # 2) 再 flush CRC32
332
338
  crc32_result = self.flush_crc32_stack()
333
339
  if crc32_result:
340
+ self.data_updated = True
334
341
  self._replace_crc32_placeholders(self.cache_data, crc32_result)
335
342
  if self.cache_debug:
336
343
  self._replace_crc32_placeholders(self.cache_debug, crc32_result)
337
344
 
345
+ if not self.data_updated:
346
+ return
347
+
338
348
  if self.cache_data:
339
349
  self.write_data_json(self.dump_file_path)
340
350
  if self.cache_stack:
@@ -343,4 +353,4 @@ class DataWriter:
343
353
  self.write_construct_info_json(self.construct_file_path)
344
354
  if self.cache_debug:
345
355
  self.write_debug_info_json(self.debug_file_path)
346
-
356
+ self.data_updated = False
@@ -69,8 +69,7 @@ class BaseScope(ABC):
69
69
  self.scope = scope
70
70
  self.api_list = api_list
71
71
 
72
- @staticmethod
73
- def rectify_args(scope, api_list):
72
+ def rectify_args(self, scope, api_list):
74
73
  if not isinstance(api_list, list):
75
74
  raise ScopeException(ScopeException.InvalidApiStr,
76
75
  f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
@@ -104,12 +103,11 @@ class BaseScope(ABC):
104
103
 
105
104
 
106
105
  class ListScope(BaseScope):
107
- @staticmethod
108
- def rectify_args(scope, api_list):
106
+ def rectify_args(self, scope, api_list):
109
107
  if scope and api_list:
110
108
  raise ScopeException(ScopeException.ArgConflict,
111
109
  f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
112
- return super(ListScope, ListScope).rectify_args(scope, api_list)
110
+ return super().rectify_args(scope, api_list)
113
111
 
114
112
  def check(self, name):
115
113
  if not self.scope or name in self.scope:
@@ -147,7 +145,7 @@ class RangeScope(BaseScope, ABC):
147
145
  f"scope参数格式错误,要求格式为api或模块完整命名,实际为{name}.")
148
146
 
149
147
  def rectify_args(self, scope, api_list):
150
- scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
148
+ scope, api_list = super().rectify_args(scope, api_list)
151
149
  if scope and len(scope) != 2:
152
150
  raise ScopeException(ScopeException.InvalidScope,
153
151
  f"scope参数指定区间断点,须传入长度为2的列表,实际长度为{len(scope)}.")
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import gc
16
17
  import os
17
18
  import threading
18
19
  from abc import ABC, abstractmethod
@@ -45,9 +46,10 @@ class BaseHookManager(ABC):
45
46
  hook_handle_dict = {}
46
47
  params_grad_info = {}
47
48
 
48
- def __init__(self, data_collector, config):
49
+ def __init__(self, data_collector, config, attl_manager=None):
49
50
  self.data_collector = data_collector
50
51
  self.config = config
52
+ self.attl_manager = attl_manager
51
53
 
52
54
  @property
53
55
  def _pid(self):
@@ -62,9 +64,20 @@ class BaseHookManager(ABC):
62
64
  def reset_status():
63
65
  BaseHookManager.inner_switch = defaultdict(bool)
64
66
  BaseHookManager.inner_api_count = defaultdict(int)
65
- BaseHookManager.hook_handle_dict.clear()
66
67
  BaseHookManager.params_grad_info.clear()
67
68
 
69
+ @staticmethod
70
+ def ensure_gc_enabled():
71
+ is_gc_disabled = not gc.isenabled()
72
+ if is_gc_disabled:
73
+ gc.enable()
74
+ return is_gc_disabled
75
+
76
+ @staticmethod
77
+ def restore_gc_state(original_state):
78
+ if original_state:
79
+ gc.disable()
80
+
68
81
  @staticmethod
69
82
  def _clear_input_kwargs(module, tid):
70
83
  if hasattr(module, 'msprobe_input_kwargs') and tid in module.msprobe_input_kwargs:
@@ -168,9 +181,11 @@ class BaseHookManager(ABC):
168
181
  if not self._should_execute_hook(Const.MODULE, tid):
169
182
  return
170
183
  with ThreadSafe():
184
+ original_state = self.ensure_gc_enabled()
171
185
  BaseHookManager.inner_switch[tid] = True
172
186
  self.data_collector.params_data_collect(ori_name, param_name, self._pid, grad)
173
187
  BaseHookManager.inner_switch[tid] = False
188
+ self.restore_gc_state(original_state)
174
189
  return
175
190
 
176
191
  return hook_fn
@@ -185,6 +200,7 @@ class BaseHookManager(ABC):
185
200
  return None
186
201
 
187
202
  with ThreadSafe():
203
+ original_state = self.ensure_gc_enabled()
188
204
  self._register_forward_hook(module, api_name)
189
205
  BaseHookManager.inner_api_count[tid] += 1
190
206
  if BaseHookManager.inner_api_count[tid] != 1:
@@ -200,6 +216,10 @@ class BaseHookManager(ABC):
200
216
 
201
217
  args = self._register_backward_hook(module, full_backward_name, args)
202
218
  with self._no_grad_context():
219
+ if getattr(self.config, "online_run_ut", False):
220
+ BaseHookManager.inner_switch[tid] = False
221
+ ThreadSafe.release()
222
+ return
203
223
  self.data_collector.update_api_or_module_name(full_forward_name)
204
224
  self.data_collector.forward_input_data_collect(
205
225
  full_forward_name,
@@ -209,6 +229,7 @@ class BaseHookManager(ABC):
209
229
  self._is_recompute
210
230
  )
211
231
  BaseHookManager.inner_switch[tid] = False
232
+ self.restore_gc_state(original_state)
212
233
  return args
213
234
 
214
235
  return forward_pre_hook
@@ -221,6 +242,7 @@ class BaseHookManager(ABC):
221
242
  return None
222
243
 
223
244
  with ThreadSafe():
245
+ original_state = self.ensure_gc_enabled()
224
246
  if hook_type == Const.API:
225
247
  if BaseHookManager.inner_api_count[tid] != 1:
226
248
  if BaseHookManager.inner_api_count[tid] > 1:
@@ -243,6 +265,13 @@ class BaseHookManager(ABC):
243
265
  output = self._register_backward_pre_hook(module, full_backward_name, output)
244
266
 
245
267
  with self._no_grad_context():
268
+ if getattr(self.config, "online_run_ut", False):
269
+ if self.data_collector.scope and not self.data_collector.scope.check(full_name):
270
+ return None
271
+ if self.attl_manager:
272
+ self.attl_manager.attl_send(full_name, args, kwargs, output)
273
+ BaseHookManager.inner_switch[tid] = False
274
+ return None
246
275
  if hook_type == Const.MODULE:
247
276
  params_dict = self._get_params_dict(module)
248
277
  setattr(module_input_output, Const.PARAMS, params_dict)
@@ -276,6 +305,7 @@ class BaseHookManager(ABC):
276
305
  return forward_new_output
277
306
 
278
307
  BaseHookManager.inner_switch[tid] = False
308
+ self.restore_gc_state(original_state)
279
309
  return output
280
310
 
281
311
  return forward_hook
@@ -287,9 +317,12 @@ class BaseHookManager(ABC):
287
317
  return
288
318
 
289
319
  with ThreadSafe():
320
+ original_state = self.ensure_gc_enabled()
290
321
  BaseHookManager.inner_switch[tid] = True
291
322
  self.data_collector.update_api_or_module_name(full_name)
292
-
323
+ if getattr(self.config, "online_run_ut", False):
324
+ BaseHookManager.inner_switch[tid] = False
325
+ return
293
326
  need_exchange = self._need_exchange(module) if hook_type == Const.MODULE else True
294
327
  if need_exchange:
295
328
  module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
@@ -306,5 +339,6 @@ class BaseHookManager(ABC):
306
339
  params_dict = self._get_params_dict(module)
307
340
  self.data_collector.params_data_collect_in_bw_hook(params_dict, full_name)
308
341
  BaseHookManager.inner_switch[tid] = False
342
+ self.restore_gc_state(original_state)
309
343
 
310
344
  return backward_hook
msprobe/core/service.py CHANGED
@@ -26,6 +26,7 @@ from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggr
26
26
  from msprobe.core.data_dump.api_registry import ApiRegistry
27
27
  from msprobe.core.data_dump.data_collector import build_data_collector
28
28
  from msprobe.core.kernel_dump.kernel_config import create_kernel_config_json
29
+ from msprobe.core.common.megatron_utils import MegatronStepInfo
29
30
 
30
31
 
31
32
  class BaseService(ABC):
@@ -34,6 +35,7 @@ class BaseService(ABC):
34
35
  self.config.level = getattr(config, 'level_ori', config.level) # 兼容MindSpore配置
35
36
  self.model = None
36
37
  self.data_collector = build_data_collector(self.config)
38
+ self.attl_manager = None
37
39
  self.current_iter = 0
38
40
  self.loop = 0
39
41
  self.init_step = 0
@@ -89,6 +91,10 @@ class BaseService(ABC):
89
91
  self.config.task in self.data_collector.tasks_need_tensor_data or
90
92
  (self.config.task == Const.STATISTICS and self.config.tensor_list)
91
93
  )
94
+
95
+ @property
96
+ def _is_online_run_ut(self):
97
+ return getattr(self.config, "online_run_ut", False)
92
98
 
93
99
  @property
94
100
  @abstractmethod
@@ -140,9 +146,11 @@ class BaseService(ABC):
140
146
  self.primitive_switch = True
141
147
  self._change_jit_switch(True)
142
148
  self.logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
143
-
144
- self.create_dirs()
145
- self.logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
149
+ if self._is_online_run_ut:
150
+ self._run_ut_dispatch(True)
151
+ else:
152
+ self.create_dirs()
153
+ self.logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
146
154
 
147
155
  def stop(self):
148
156
  """通用stop模板"""
@@ -157,7 +165,8 @@ class BaseService(ABC):
157
165
  self._change_jit_switch(False)
158
166
  if self._is_l2_level:
159
167
  return
160
-
168
+ if self._is_online_run_ut:
169
+ self._run_ut_dispatch(False)
161
170
  self._process_async_dump()
162
171
  self.data_collector.write_json()
163
172
 
@@ -170,6 +179,7 @@ class BaseService(ABC):
170
179
  self.currrent_step_first_debug_save = True
171
180
  self.loop += 1
172
181
  self._reset_status()
182
+ MegatronStepInfo.reset()
173
183
 
174
184
  def save(self, variable, name, save_backward):
175
185
  '''
@@ -256,6 +266,8 @@ class BaseService(ABC):
256
266
  end_service = self.config.step and self.current_iter > max(self.config.step) or \
257
267
  self.data_collector and self.data_collector.data_processor.is_terminated
258
268
  if end_service:
269
+ if self._is_online_run_ut and self.attl_manager:
270
+ self.attl_manager.attl_stop()
259
271
  self.primitive_switch = False
260
272
  self._change_jit_switch(False)
261
273
  Runtime.is_running = False
@@ -298,7 +310,8 @@ class BaseService(ABC):
298
310
  if root_model and isinstance(root_model, list):
299
311
  root_model = root_model[0]
300
312
  self.logger.warning("Infer model can only input one to support token_range, choose the first one.")
301
-
313
+ if self._is_online_run_ut:
314
+ return
302
315
  root_model.register_forward_pre_hook(infer_hook)
303
316
 
304
317
  def _create_l2_dirs(self, cur_rank):
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import os
17
+ import re
17
18
  import multiprocessing
18
19
  from dataclasses import dataclass
19
20
 
@@ -70,6 +71,9 @@ class SingleComparator:
70
71
  比较两个NumPy数组,计算最大绝对误差、最大相对误差和相同元素的百分比
71
72
  """
72
73
  # 计算每个维度上的最小尺寸
74
+ if array1.ndim != array2.ndim:
75
+ array1 = array1.flatten()
76
+ array2 = array2.flatten()
73
77
  min_shape = [min(s1, s2) for s1, s2 in zip(array1.shape, array2.shape)]
74
78
  # 截取数组到相同的形状
75
79
  sliced_array1 = array1[tuple(slice(0, s) for s in min_shape)]
@@ -176,9 +180,18 @@ class SingleComparator:
176
180
  continue
177
181
  for step, step_path in cls.get_steps(tag_path):
178
182
  for rank, rank_path in cls.get_ranks(step_path):
179
- for micro_step, micro_step_path in cls.get_micro_steps(rank_path):
180
- for array_id, array_path in cls.get_arrays(micro_step_path):
181
- array_paths.setdefault(tag, []).append((step, rank, micro_step, array_id, array_path))
183
+ for item in os.listdir(rank_path):
184
+ next_path = os.path.join(rank_path, item)
185
+ if re.match(r"micro_step(\d+)", item):
186
+ micro_step = re.match(r"micro_step(\d+)", item).group(1)
187
+ for array_id, array_path in cls.get_arrays(next_path):
188
+ array_paths.setdefault(tag, []).append(
189
+ (step, rank, int(micro_step), array_id, array_path))
190
+ elif re.match(r"\w{1,100}_(\d{1,100})\.npy", item):
191
+ array_id = re.match(r"\w{1,100}_(\d{1,100})\.npy", item).group(1)
192
+ array_paths.setdefault(tag, []).append((step, rank, 0, int(array_id), next_path))
193
+ else:
194
+ array_paths.setdefault(tag, []).append((step, rank, 0, 0, next_path))
182
195
  return array_paths
183
196
 
184
197
  @classmethod
@@ -16,6 +16,7 @@ pip install mindstudio-probe
16
16
 
17
17
  | 版本 | 发布日期 |支持 PyTorch 版本|支持 MindSpore 版本| 下载链接 |校验码|
18
18
  |:-----:|:----------:|:--:|:--:|:----------------------------------------------------------------------------------------------------------------------------------:|:--:|
19
+ | 8.2.0 | 2025.9.03 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.2.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.2/mindstudio_probe-8.2.0-py3-none-any.whl) |bbc1577d76754adf987069308177d3e0a04e36de9c7f22e75c34cf4ad0ce1af2|
19
20
  | 8.1.2 | 2025.8.01 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.1.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.1/mindstudio_probe-8.1.2-py3-none-any.whl) |ff07bb81fddd3b8f3096d119ca1481bde8fdb24f10644def5250caad727448ab|
20
21
  | 8.1.1 | 2025.6.20 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.1.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.1/mindstudio_probe-8.1.1-py3-none-any.whl) |2aad10a243575544d7feef552caf4d06aa93028488ebd0bbc9aa350379da859d|
21
22
  | 8.1.0 | 2025.6.14 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.1/mindstudio_probe-8.1.0-py3-none-any.whl) |d10c0a57d073bbe7c681042a11e93a0eaaaf5aa45e1cec997142ce2593d77afd|
@@ -45,12 +46,12 @@ pip install ./mindstudio_probe-{version}-py3-none-any.whl # 安装whl包
45
46
  ## 3 从源码安装
46
47
 
47
48
  ```shell
48
- git clone https://gitee.com/ascend/mstt.git
49
+ git clone https://gitcode.com/Ascend/mstt.git
49
50
  cd mstt/debug/accuracy_tools
50
51
 
51
52
  pip install setuptools wheel
52
53
 
53
- python setup.py bdist_wheel [--include-mod=[adump]]
54
+ python setup.py bdist_wheel [--include-mod=[adump]] [--no-check]
54
55
  cd ./dist
55
56
  pip install ./mindstudio_probe*.whl
56
57
  ```
@@ -58,6 +59,7 @@ pip install ./mindstudio_probe*.whl
58
59
  |参数|说明|是否必选|
59
60
  |--|--|:--:|
60
61
  |--include-mod|指定可选模块,可取值`adump`,表示在编whl包时加入adump模块。默认未配置该参数,表示编基础包。<br>&#8226; adump模块用于MindSpore静态图场景L2级别的dump。<br>&#8226; 仅MindSpore 2.5.0及以上版本支持adump模块。<br>&#8226; 若使用源码安装,编译环境需支持GCC 7.5或以上版本,和CMake 3.14或以上版本。<br>&#8226; 生成的whl包仅限编译时使用的python版本和处理器架构可用。|否|
62
+ |--no-check|通过`--include-mod`指定可选模块`adump`后,编包时会下载其依赖的三方库包,下载过程会进行证书校验。配置`--no-check`可跳过该证书校验。|否|
61
63
 
62
64
  # 特性变更说明
63
65
 
@@ -212,7 +214,7 @@ pip show mindstudio-probe
212
214
  Name: mindstudio-probe
213
215
  Version: 1.0.x
214
216
  Summary: Pytorch Ascend Probe Utils
215
- Home-page: https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe
217
+ Home-page: https://gitcode.com/Ascend/mstt/tree/master/debug/accuracy_tools/msprobe
216
218
  Author: Ascend Team
217
219
  Author-email: pmail_mindstudio@huawei.com
218
220
  License: Apache License 2.0
@@ -225,7 +227,7 @@ Required-by:
225
227
 
226
228
  ## 1 安装 CANN 包
227
229
 
228
- 1.1 根据 CPU 架构和 NPU 型号选择 toolkit 和 kernel,可以参考 [CANN 软件安装指南](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fdocument%2Fdetail%2Fzh%2Fcanncommercial%2F700%2Fenvdeployment%2Finstg%2Finstg_0001.html)和[昇腾社区](https://www.hiascend.cn/developer/download/community/result?module=cann)。
230
+ 1.1 根据 CPU 架构和 NPU 型号选择 toolkit 和 kernel,可以参考 [CANN 软件安装指南](https://www.hiascend.com/document/detail/zh/canncommercial/700/envdeployment/instg/instg_0001.html)和[昇腾社区](https://www.hiascend.cn/developer/download/community/result?module=cann)。
229
231
 
230
232
  1.2 运行示例
231
233
  ```bash
@@ -239,7 +241,7 @@ source {cann_path}/ascend-toolkit/set_env.sh
239
241
  ```
240
242
  ## 2 安装 PyTorch_NPU
241
243
 
242
- 链接:[https://gitee.com/ascend/pytorch](https://gitee.com/ascend/pytorch)。
244
+ 链接:[https://gitcode.com/Ascend/pytorch](https://gitcode.com/Ascend/pytorch)。
243
245
 
244
246
  ## 3 安装 MindSpeed LLM
245
247