mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,17 +13,17 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import copy
16
17
  import inspect
17
18
  import os
18
19
  from dataclasses import dataclass, is_dataclass
19
- from typing import Tuple, Dict, Optional, Any
20
20
  from functools import partial
21
- import copy
22
- from typing import Union
21
+ from typing import Tuple, Dict, Optional, Any, Union
23
22
 
24
23
  import numpy as np
25
24
 
26
25
  from msprobe.core.common.const import Const
26
+ from msprobe.core.common.file_utils import save_npy
27
27
  from msprobe.core.common.log import logger
28
28
  from msprobe.core.common.utils import convert_tuple, CompareException
29
29
 
@@ -79,21 +79,17 @@ class ModuleBackwardOutputs:
79
79
 
80
80
 
81
81
  class TensorStatInfo:
82
- def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None, stack_tensor_stat=None):
82
+ def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
83
83
  self.max = max_val
84
84
  self.min = min_val
85
85
  self.mean = mean_val
86
86
  self.norm = norm_val
87
- self.stack_tensor_stat = stack_tensor_stat
88
87
 
89
88
 
90
89
  class BaseDataProcessor:
91
90
  _recursive_key_stack = []
92
- special_type = (
93
- np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray,
94
- bool, int, float, str, slice,
95
- type(Ellipsis)
96
- )
91
+ builtin_type = (bool, int, float, str, slice, type(Ellipsis))
92
+ np_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray)
97
93
 
98
94
  def __init__(self, config, data_writer):
99
95
  self.data_writer = data_writer
@@ -120,7 +116,10 @@ class BaseDataProcessor:
120
116
  @staticmethod
121
117
  def analyze_api_call_stack(name):
122
118
  try:
123
- api_stack = inspect.stack()[5:]
119
+ if name.startswith("Primitive"):
120
+ api_stack = inspect.stack()[4:]
121
+ else:
122
+ api_stack = inspect.stack()[5:]
124
123
  except Exception as e:
125
124
  logger.warning(f"The call stack of <{name}> failed to retrieve, {e}.")
126
125
  api_stack = None
@@ -129,12 +128,14 @@ class BaseDataProcessor:
129
128
  for (_, path, line, func, code, _) in api_stack:
130
129
  if not code:
131
130
  continue
131
+ if any(filter_path in path for filter_path in Const.STACK_FILTER_KEYWORDS) and \
132
+ Const.CALL_STACK_FLAG not in path:
133
+ continue
132
134
  stack_line = f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()}"
133
135
  stack_str.append(stack_line)
134
136
  else:
135
137
  stack_str.append(Const.WITHOUT_CALL_STACK)
136
- stack_info_struct = {name: stack_str}
137
- return stack_info_struct
138
+ return tuple(stack_str)
138
139
 
139
140
  @staticmethod
140
141
  def transfer_type(data):
@@ -178,20 +179,8 @@ class BaseDataProcessor:
178
179
  "invalid data_structure type or invalid index")
179
180
 
180
181
  @staticmethod
181
- def _convert_numpy_to_builtin(arg):
182
- type_mapping = {
183
- np.integer: int,
184
- np.floating: float,
185
- np.bool_: bool,
186
- np.complexfloating: complex,
187
- np.str_: str,
188
- np.byte: bytes,
189
- np.unicode_: str
190
- }
191
- for numpy_type, builtin_type in type_mapping.items():
192
- if isinstance(arg, numpy_type):
193
- return builtin_type(arg), type(arg).__name__
194
- return arg, ''
182
+ def is_distributed_op(module):
183
+ return getattr(module, "op_is_distributed", False)
195
184
 
196
185
  @staticmethod
197
186
  def _analyze_builtin(arg):
@@ -217,21 +206,40 @@ class BaseDataProcessor:
217
206
  return single_arg
218
207
 
219
208
  @staticmethod
220
- def _analyze_numpy(ndarray, numpy_type):
209
+ def _analyze_numpy(arg):
210
+ return {"type": type(arg).__name__, "value": arg.item()}
211
+
212
+ @staticmethod
213
+ def _analyze_ndarray(ndarray, _):
221
214
  ndarray_json = {}
222
215
  ndarray_json.update({'type': 'numpy.ndarray'})
223
216
  ndarray_json.update({'dtype': str(ndarray.dtype)})
224
217
  ndarray_json.update({'shape': ndarray.shape})
225
- if ndarray.size > 0:
226
- ndarray_json.update({"Max": np.max(ndarray).item()})
227
- ndarray_json.update({"Min": np.min(ndarray).item()})
228
- ndarray_json.update({"Mean": np.mean(ndarray).item()})
229
- ndarray_json.update({"Norm": np.linalg.norm(ndarray).item()})
230
- else:
231
- ndarray_json.update({"Max": None})
232
- ndarray_json.update({"Min": None})
233
- ndarray_json.update({"Mean": None})
234
- ndarray_json.update({"Norm": None})
218
+
219
+ # 先初始化默认值
220
+ stats = {
221
+ "Max": None,
222
+ "Min": None,
223
+ "Mean": None,
224
+ "Norm": None
225
+ }
226
+
227
+ try:
228
+ # 只有非空时才尝试计算
229
+ if ndarray.size > 0:
230
+ stats = {
231
+ "Max": np.max(ndarray).item(),
232
+ "Min": np.min(ndarray).item(),
233
+ "Mean": np.mean(ndarray).item(),
234
+ "Norm": np.linalg.norm(ndarray).item()
235
+ }
236
+ except Exception as e:
237
+ # 决定打印内容或切片
238
+ logger.warning(f"Error analyzing ndarray stats: {e}")
239
+
240
+ # 最后一次性更新
241
+ ndarray_json.update(stats)
242
+
235
243
  return ndarray_json
236
244
 
237
245
  @staticmethod
@@ -248,12 +256,12 @@ class BaseDataProcessor:
248
256
 
249
257
  @classmethod
250
258
  def get_special_types(cls):
251
- return cls.special_type
259
+ return cls.builtin_type + cls.np_type
252
260
 
253
261
  @classmethod
254
262
  def recursive_apply_transform(cls, args, transform, depth=0) -> Union[dict, list, None]:
255
- if depth > Const.MAX_DEPTH:
256
- logger.error(f"The maximum depth of recursive transform, {Const.MAX_DEPTH} is reached.")
263
+ if depth > Const.DUMP_MAX_DEPTH:
264
+ logger.error(f"The maximum depth of recursive transform, {Const.DUMP_MAX_DEPTH} is reached.")
257
265
  raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
258
266
  if isinstance(args, cls.get_special_types()):
259
267
  arg_transform = transform(args, cls._recursive_key_stack)
@@ -303,6 +311,7 @@ class BaseDataProcessor:
303
311
 
304
312
  def real_hook_fn(grad):
305
313
  return wrap_hook_fn(grad)
314
+
306
315
  element.register_hook(real_hook_fn)
307
316
 
308
317
  def if_return_forward_new_output(self):
@@ -350,6 +359,8 @@ class BaseDataProcessor:
350
359
  return api_info_struct
351
360
 
352
361
  def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
362
+ if self.is_distributed_op(module):
363
+ module_input_output.update_output_with_args_and_kwargs()
353
364
  api_info_struct = {}
354
365
  # check whether data_mode contains forward or input
355
366
  if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
@@ -427,6 +438,7 @@ class BaseDataProcessor:
427
438
  api_info_struct = {}
428
439
  self.save_name = name + Const.SEP + param_name
429
440
  data_info = self.analyze_element(grad)
441
+ self.save_name = None
430
442
  grad_info_dict = {param_name: [data_info]}
431
443
  api_info_struct[name] = grad_info_dict
432
444
  return api_info_struct
@@ -435,10 +447,10 @@ class BaseDataProcessor:
435
447
  file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
436
448
  if self.save_name is not None:
437
449
  dump_data_name = (self.save_name + file_format)
438
- self.save_name = None
439
450
  else:
440
- dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
441
- suffix + file_format)
451
+ suffix_with_seq = (Const.SEP + suffix) if suffix else ""
452
+ dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + suffix_with_seq +
453
+ file_format)
442
454
  file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
443
455
  return dump_data_name, file_path
444
456
 
@@ -447,23 +459,32 @@ class BaseDataProcessor:
447
459
 
448
460
  def analyze_debug_forward(self, variable, name_with_count):
449
461
  self.current_api_or_module_name = name_with_count
450
- self.api_data_category = Const.TENSOR
451
- # these two attributes are used to construct tensor file name {name_with_count}.tensor.{indexes}.npy/pt
462
+ self.api_data_category = Const.DEBUG
463
+ # these two attributes are used to construct tensor file name {name_with_count}.debug.{indexes}.npy/pt
452
464
  data_info = self.analyze_element(variable)
453
465
  return data_info
454
466
 
455
- def analyze_debug_backward(self, variable, grad_name_with_count, nested_data_structure):
467
+ def analyze_debug_backward(self, variable, grad_name_with_count_category, nested_data_structure):
456
468
  def hook_fn(grad, indexes):
457
469
  suffix = Const.SEP.join([str(index) for index in indexes])
458
- self.save_name = grad_name_with_count + Const.SEP + Const.TENSOR + Const.SEP + suffix
470
+ suffix_with_sep = (Const.SEP + suffix) if suffix else ""
471
+ self.save_name = grad_name_with_count_category + suffix_with_sep
459
472
  grad_data_info = self.analyze_element(grad)
460
473
  self.save_name = None
461
- full_index = [grad_name_with_count] + indexes
474
+ full_index = [grad_name_with_count_category] + indexes
462
475
  try:
463
476
  self.set_value_into_nested_structure(nested_data_structure, full_index, grad_data_info)
464
477
  except (ValueError, IndexError) as e:
465
- logger.warning(f"error occured while recording statistics of {grad_name_with_count} variable, "
466
- f"skip current recording, detailed infomation: {e}")
478
+ logger.warning(f"error occurred while recording statistics of {grad_name_with_count_category} variable,"
479
+ f"skip current recording, detailed information: {e}")
467
480
  return grad
481
+
468
482
  wrap_register_hook_single_element = partial(self.register_hook_single_element, hook_fn=hook_fn)
469
- self.recursive_apply_transform(variable, wrap_register_hook_single_element)
483
+ self.recursive_apply_transform(variable, wrap_register_hook_single_element)
484
+
485
+ def _analyze_and_save_ndarray(self, ndarray, suffix):
486
+ dump_data_name, file_path = self.get_save_file_path(suffix)
487
+ save_npy(ndarray, file_path)
488
+ ndarray_json = BaseDataProcessor._analyze_ndarray(ndarray, suffix)
489
+ ndarray_json.update({"data_name": dump_data_name})
490
+ return ndarray_json
@@ -17,16 +17,17 @@ import zlib
17
17
 
18
18
  import mindspore as ms
19
19
  from mindspore import mint, ops, hal
20
+ from mindspore.mint import distributed
20
21
  from mindspore._c_expression.typing import Number
21
22
  import numpy as np
22
23
 
23
24
  from msprobe.core.common.const import Const
24
25
  from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo,
25
26
  ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs)
26
- from msprobe.core.common.file_utils import path_len_exceeds_limit, save_npy
27
+ from msprobe.core.common.file_utils import path_len_exceeds_limit
27
28
  from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy
28
29
  from msprobe.mindspore.common.log import logger
29
- from msprobe.mindspore.dump.hook_cell.api_registry import api_register
30
+ from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
30
31
 
31
32
  has_adump = True
32
33
  try:
@@ -36,7 +37,7 @@ except ImportError:
36
37
 
37
38
 
38
39
  class MindsporeDataProcessor(BaseDataProcessor):
39
- mindspore_special_type = tuple([ms.Tensor, Number])
40
+ mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp])
40
41
 
41
42
  def __init__(self, config, data_writer):
42
43
  super().__init__(config, data_writer)
@@ -44,6 +45,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
44
45
  "dtype": self.analyze_dtype_in_kwargs
45
46
  }
46
47
  self._async_dump_cache = {}
48
+ self.api_register = get_api_register()
47
49
 
48
50
  @staticmethod
49
51
  def get_md5_for_tensor(x):
@@ -64,7 +66,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
64
66
  tensor_stat.max = np.max(data_np).item()
65
67
  tensor_stat.min = np.min(data_np).item()
66
68
  elif not data.shape:
67
- tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
69
+ tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
68
70
  elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
69
71
  data_abs = np.abs(data.asnumpy())
70
72
  tensor_stat.max = np.max(data_abs).item()
@@ -74,83 +76,98 @@ class MindsporeDataProcessor(BaseDataProcessor):
74
76
  else:
75
77
  if not ops.is_floating_point(data) or data.dtype == ms.float64:
76
78
  data = data.to(ms.float32)
77
- api_register.norm_inner_op_set_ori_func()
78
- get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
79
- get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min)
80
- get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean)
81
- if hasattr(mint, "norm"):
82
- get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
83
- else:
84
- get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
85
- tensor_stat.max = get_max_value(data).item()
86
- tensor_stat.min = get_min_value(data).item()
87
- tensor_stat.mean = get_mean_value(data).item()
88
- tensor_stat.norm = get_norm_value(data).item()
89
- api_register.norm_inner_op_set_hook_func()
79
+ get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
80
+ tensor_stat.max = mint.max(data)
81
+ tensor_stat.min = mint.min(data)
82
+ tensor_stat.mean = mint.mean(data)
83
+ tensor_stat.norm = get_norm_value(data)
90
84
  return tensor_stat
91
85
 
92
86
  @staticmethod
93
87
  def get_stat_info_async(data):
94
88
  tensor_stat = TensorStatInfo()
95
- stack_method = api_register.functional_ori_attr.get("stack", ms.ops.stack)
96
- if data.dtype == ms.complex64 or data.dtype == ms.complex128:
89
+ if data.dtype == ms.bool_:
90
+ tensor_stat.max = mint.any(data)
91
+ tensor_stat.min = mint.all(data)
92
+ elif not data.shape:
93
+ tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
94
+ elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
97
95
  logger.warning("Async dump do not support complex data!")
98
96
  return tensor_stat
99
- elif data.dtype == ms.bool_:
100
- tensor_stat.stack_tensor_stat = (["Max", "Min"], stack_method([data.any(), data.all()]))
101
- elif not data.shape:
102
- tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method([data, data, data, data]))
103
97
  else:
104
98
  if not ops.is_floating_point(data) or data.dtype == ms.float64:
105
99
  data = data.to(ms.float32)
106
- api_register.norm_inner_op_set_ori_func()
107
- get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
108
- get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min)
109
- get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean)
110
- if hasattr(mint, "norm"):
111
- get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
112
- else:
113
- get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
114
- tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method(
115
- [get_max_value(data), get_min_value(data), get_mean_value(data), get_norm_value(data)]))
116
- api_register.norm_inner_op_set_hook_func()
100
+ get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm
101
+ tensor_stat.max = mint.max(data)
102
+ tensor_stat.min = mint.min(data)
103
+ tensor_stat.mean = mint.mean(data)
104
+ tensor_stat.norm = get_norm_value(data)
117
105
  return tensor_stat
118
106
 
119
107
  @staticmethod
120
108
  def is_hookable_element(element):
121
109
  return hasattr(element, "register_hook") and callable(element.register_hook)
122
110
 
111
+ @staticmethod
112
+ def process_group_hash(arg):
113
+ group_ranks = distributed.get_process_group_ranks(arg)
114
+ group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8'))
115
+ return f"{group_ranks_hash:08x}"
116
+
123
117
  @classmethod
124
118
  def get_special_types(cls):
125
119
  return super().get_special_types() + cls.mindspore_special_type
126
120
 
121
+ def dump_async_data(self):
122
+ for file_path, tensor in self._async_dump_cache.items():
123
+ save_tensor_as_npy(tensor, file_path)
124
+ self._async_dump_cache.clear()
125
+
127
126
  def get_stat_info(self, data):
127
+ self.api_register.restore_inner_used_api()
128
128
  tensor_stat = TensorStatInfo()
129
129
  if data.numel() == 0:
130
- return tensor_stat
130
+ stat_info = tensor_stat
131
131
  else:
132
132
  if self.config.async_dump:
133
- return MindsporeDataProcessor.get_stat_info_async(data)
133
+ stat_info = MindsporeDataProcessor.get_stat_info_async(data)
134
134
  else:
135
- return MindsporeDataProcessor.get_stat_info_sync(data)
135
+ stat_info = MindsporeDataProcessor.get_stat_info_sync(data)
136
+ self.api_register.register_inner_used_api()
137
+ return stat_info
136
138
 
137
139
  def analyze_single_element(self, element, suffix_stack):
138
140
  if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
139
141
  return self.mindspore_object_key[suffix_stack[-1]](element)
140
142
 
141
- converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
142
- if converted_numpy is not element:
143
- return {"type": numpy_type, "value": converted_numpy}
144
- if isinstance(element, Number):
145
- return self.analyze_dtype_in_kwargs(element)
146
- if isinstance(element, ms.Tensor):
147
- return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
148
- if isinstance(element, np.ndarray):
149
- return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
150
- if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
151
- return self._analyze_builtin(element)
143
+ suffix_str = Const.SEP.join(str(s) for s in suffix_stack)
144
+ type_analyzer = [
145
+ (MindsporeDataProcessor.builtin_type, self._analyze_builtin),
146
+ (ms.Tensor, lambda e: self._analyze_tensor(e, suffix_str)),
147
+ (Number, self.analyze_dtype_in_kwargs),
148
+ (MindsporeDataProcessor.np_type[:-1], self._analyze_numpy),
149
+ (np.ndarray, lambda e: self._analyze_ndarray(e, suffix_str)),
150
+ (distributed.P2POp, lambda e: self._analyze_p2pop(e, suffix_str))
151
+ ]
152
+ for type_key, analyze_fn in type_analyzer:
153
+ if isinstance(element, type_key):
154
+ return analyze_fn(element)
152
155
  return {}
153
156
 
157
+ def _analyze_p2pop(self, arg, suffix):
158
+ p2pop_info = {"class_type": "mindspore.mint.distributed.P2POp"}
159
+ try:
160
+ tensor_info = self._analyze_tensor(arg.tensor, suffix)
161
+ p2pop_info.update({"tensor": tensor_info})
162
+ p2pop_info.update({"op": arg.op})
163
+ p2pop_info.update({"peer": arg.peer})
164
+ p2pop_info.update({"tag": arg.tag})
165
+ group_id = self.process_group_hash(arg.group) if arg.group else None
166
+ p2pop_info.update({"group_id": group_id})
167
+ except Exception as e:
168
+ logger.warning(f"Failed to parse the P2POp content with error info: {e}.")
169
+ return p2pop_info
170
+
154
171
  def _analyze_tensor(self, tensor, suffix):
155
172
  tensor_stat = self.get_stat_info(tensor)
156
173
  tensor_json = {
@@ -159,45 +176,54 @@ class MindsporeDataProcessor(BaseDataProcessor):
159
176
  'shape': tensor.shape
160
177
  }
161
178
 
162
- if tensor_stat.stack_tensor_stat is None:
163
- tensor_json.update({'Max': self.transfer_type(tensor_stat.max)})
164
- tensor_json.update({'Min': self.transfer_type(tensor_stat.min)})
165
- tensor_json.update({'Mean': self.transfer_type(tensor_stat.mean)})
166
- tensor_json.update({'Norm': self.transfer_type(tensor_stat.norm)})
167
- else:
168
- tensor_json.update({'tensor_stat': tensor_stat.stack_tensor_stat})
179
+ # 将统计值存入全局 buffer,并返回占位索引
180
+ stat_values = [
181
+ tensor_stat.max,
182
+ tensor_stat.min,
183
+ tensor_stat.mean,
184
+ tensor_stat.norm
185
+ ]
186
+
187
+ placeholder_index = self.data_writer.append_stat_to_buffer(stat_values)
188
+
189
+ tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index})
190
+
169
191
  if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
170
192
  tensor_md5 = self.get_md5_for_tensor(tensor)
171
193
  tensor_json.update({Const.MD5: tensor_md5})
172
194
  return tensor_json
173
195
 
174
-
175
- class StatisticsDataProcessor(MindsporeDataProcessor):
176
- pass
177
-
178
-
179
- class TensorDataProcessor(MindsporeDataProcessor):
180
- def dump_async_data(self):
181
- for file_path, tensor in self._async_dump_cache.items():
182
- save_tensor_as_npy(tensor, file_path)
183
- self._async_dump_cache.clear()
184
-
185
- def _analyze_tensor(self, tensor, suffix):
196
+ def _analyze_and_save_tensor(self, tensor, suffix):
186
197
  dump_data_name, file_path = self.get_save_file_path(suffix)
187
- single_arg = super()._analyze_tensor(tensor, suffix)
198
+ single_arg = MindsporeDataProcessor._analyze_tensor(self, tensor, suffix)
188
199
  single_arg.update({"data_name": dump_data_name})
189
200
  if self.config.async_dump:
190
201
  self._async_dump_cache[file_path] = tensor.copy()
191
202
  else:
192
203
  save_tensor_as_npy(tensor, file_path)
193
204
  return single_arg
194
-
195
- def _analyze_numpy(self, ndarray, suffix):
196
- dump_data_name, file_path = self.get_save_file_path(suffix)
197
- save_npy(ndarray, file_path)
198
- ndarray_json = super()._analyze_numpy(ndarray, suffix)
199
- ndarray_json.update({"data_name": dump_data_name})
200
- return ndarray_json
205
+
206
+
207
+ class StatisticsDataProcessor(MindsporeDataProcessor):
208
+ def _analyze_tensor(self, tensor, suffix):
209
+ if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
210
+ return self._analyze_and_save_tensor(tensor, suffix)
211
+ else:
212
+ return super()._analyze_tensor(tensor, suffix)
213
+
214
+ def _analyze_ndarray(self, ndarray, suffix):
215
+ if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
216
+ return self._analyze_and_save_ndarray(ndarray, suffix)
217
+ else:
218
+ return super()._analyze_ndarray(ndarray, suffix)
219
+
220
+
221
+ class TensorDataProcessor(MindsporeDataProcessor):
222
+ def _analyze_tensor(self, tensor, suffix):
223
+ return self._analyze_and_save_tensor(tensor, suffix)
224
+
225
+ def _analyze_ndarray(self, ndarray, suffix):
226
+ return self._analyze_and_save_ndarray(ndarray, suffix)
201
227
 
202
228
 
203
229
  class OverflowCheckDataProcessor(MindsporeDataProcessor):
@@ -262,11 +288,26 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
262
288
  self.cached_tensors_and_file_paths = {}
263
289
 
264
290
  def _analyze_maybe_overflow_tensor(self, tensor_json):
265
- if tensor_json['Max'] is None:
291
+ tensor_stat_index = tensor_json.get(Const.TENSOR_STAT_INDEX)
292
+ if tensor_stat_index is None:
293
+ logger.warning("tensor_stat_index does not exist in tensor_json.")
294
+ return
295
+ max_tensor = self.data_writer.get_buffer_values_max(tensor_stat_index)
296
+ min_tensor = self.data_writer.get_buffer_values_min(tensor_stat_index)
297
+ if max_tensor is None or min_tensor is None:
266
298
  return
267
- if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']):
299
+
300
+ def check_inf_nan(value):
301
+ # Use .item() if it's a tensor-like structure
302
+ if hasattr(value, "item"):
303
+ value = value.item()
304
+ return np.isinf(value) or np.isnan(value)
305
+
306
+ if check_inf_nan(max_tensor):
268
307
  self.has_overflow = True
269
- if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']):
308
+ return
309
+
310
+ if check_inf_nan(min_tensor):
270
311
  self.has_overflow = True
271
312
 
272
313
  def _analyze_tensor(self, tensor, suffix):