mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -13,7 +13,6 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import hashlib
17
16
  import zlib
18
17
  from dataclasses import asdict
19
18
  from typing import List
@@ -24,14 +23,15 @@ from torch import distributed as dist
24
23
  from torch.distributed.distributed_c10d import _get_default_group
25
24
 
26
25
  from msprobe.core.common.const import Const
26
+ from msprobe.core.common.exceptions import MsprobeException
27
27
  from msprobe.core.common.file_utils import path_len_exceeds_limit
28
28
  from msprobe.core.common.log import logger
29
29
  from msprobe.core.common.utils import convert_tuple
30
+ from msprobe.core.common.decorator import recursion_depth_decorator
30
31
  from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
31
32
  ModuleForwardInputsOutputs, TensorStatInfo
32
- from msprobe.pytorch.common.utils import save_pt, load_pt
33
+ from msprobe.pytorch.common.utils import Const as PtConst, save_pt, is_hifloat8_tensor, is_float8_tensor
33
34
  from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
34
- from msprobe.core.common.utils import recursion_depth_decorator
35
35
 
36
36
  is_gpu = False
37
37
  try:
@@ -78,14 +78,16 @@ class PytorchDataProcessor(BaseDataProcessor):
78
78
  def analyze_device_in_kwargs(element):
79
79
  single_arg = {}
80
80
  single_arg.update({'type': "torch.device"})
81
- if not isinstance(element, str):
81
+ if isinstance(element, (int, str)):
82
+ single_arg.update({"value": element})
83
+ elif isinstance(element, torch.device):
82
84
  if hasattr(element, "index"):
83
85
  device_value = element.type + ":" + str(element.index)
84
86
  else:
85
87
  device_value = element.type
86
88
  single_arg.update({"value": device_value})
87
89
  else:
88
- single_arg.update({"value": element})
90
+ logger.debug(f"Device type {type(element)} is not supported.")
89
91
  return single_arg
90
92
 
91
93
  @staticmethod
@@ -99,19 +101,17 @@ class PytorchDataProcessor(BaseDataProcessor):
99
101
  logger.warning("Async dump do not support complex data!")
100
102
  return tensor_stat
101
103
  elif data.dtype == torch.bool:
102
- tensor_stat.stack_tensor_stat = (["Max", "Min"], torch.stack(
103
- [torch.any(data), torch.all(data)]))
104
+ tensor_stat.max = torch.any(data)
105
+ tensor_stat.min = torch.all(data)
104
106
  elif not data.shape:
105
- tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([data, data, data, data]))
107
+ tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
106
108
  else:
107
- if not data.is_floating_point() or data.dtype == torch.float64:
109
+ if data.dtype == torch.float64 or not data.is_floating_point():
108
110
  data = data.float()
109
- tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([
110
- torch.max(data),
111
- torch.min(data),
112
- torch.mean(data),
113
- torch.norm(data)
114
- ]))
111
+ tensor_stat.max = torch.max(data)
112
+ tensor_stat.min = torch.min(data)
113
+ tensor_stat.mean = torch.mean(data)
114
+ tensor_stat.norm = torch.norm(data)
115
115
  return tensor_stat
116
116
 
117
117
  @staticmethod
@@ -124,17 +124,17 @@ class PytorchDataProcessor(BaseDataProcessor):
124
124
  tensor_stat.min = np.min(data_abs).item()
125
125
  tensor_stat.mean = np.mean(data_abs).item()
126
126
  elif data.dtype == torch.bool:
127
- tensor_stat.max = torch.any(data).item()
128
- tensor_stat.min = torch.all(data).item()
127
+ tensor_stat.max = torch.any(data)
128
+ tensor_stat.min = torch.all(data)
129
129
  elif not data.shape:
130
- tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
130
+ tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data
131
131
  else:
132
- if not data.is_floating_point() or data.dtype == torch.float64:
132
+ if data.dtype == torch.float64 or not data.is_floating_point():
133
133
  data = data.float()
134
- tensor_stat.max = torch.max(data).item()
135
- tensor_stat.min = torch.min(data).item()
136
- tensor_stat.mean = torch.mean(data).item()
137
- tensor_stat.norm = torch.norm(data).item()
134
+ tensor_stat.max = torch.max(data)
135
+ tensor_stat.min = torch.min(data)
136
+ tensor_stat.mean = torch.mean(data)
137
+ tensor_stat.norm = torch.norm(data)
138
138
  return tensor_stat
139
139
 
140
140
  @staticmethod
@@ -143,7 +143,7 @@ class PytorchDataProcessor(BaseDataProcessor):
143
143
  if data.is_meta:
144
144
  return tensor_stat
145
145
  data_clone = data.detach()
146
- if data_clone.numel() == 0:
146
+ if not data_clone.numel() or not data_clone.data_ptr():
147
147
  return tensor_stat
148
148
  else:
149
149
  if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump:
@@ -171,12 +171,8 @@ class PytorchDataProcessor(BaseDataProcessor):
171
171
  @staticmethod
172
172
  def process_group_hash(arg):
173
173
  group_ranks = dist.get_process_group_ranks(arg)
174
- group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest()
175
- return group_ranks_hash
176
-
177
- @staticmethod
178
- def is_distributed_op(module):
179
- return getattr(module, "op_is_distributed", False)
174
+ group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8'))
175
+ return f"{group_ranks_hash:08x}"
180
176
 
181
177
  @staticmethod
182
178
  def is_hookable_element(element):
@@ -214,43 +210,52 @@ class PytorchDataProcessor(BaseDataProcessor):
214
210
  logger.warning(f"Failed to get value of torch.distributed.ReduceOp with error info: {e}.")
215
211
  return {"type": "torch.distributed.ReduceOp", "value": op_type}
216
212
 
213
+ @staticmethod
214
+ def _cast_to_float_if_fp8(tensor):
215
+ dtype = str(tensor.dtype)
216
+ if is_float8_tensor(tensor):
217
+ dtype = PtConst.HIFLOAT8_TYPE if is_hifloat8_tensor(tensor) else dtype
218
+ logger.debug(
219
+ f"The {dtype} tensor analyzing/saving is unsupported in dump function."
220
+ f"Casting to float for processing."
221
+ )
222
+ tensor = tensor.float()
223
+ return tensor, dtype
224
+
217
225
  @classmethod
218
226
  def get_special_types(cls):
219
227
  return super().get_special_types() + cls.pytorch_special_type
220
228
 
229
+ def dump_async_data(self):
230
+ for file_path, tensor in self._async_dump_cache.items():
231
+ save_pt(tensor.contiguous(), file_path)
232
+ self._async_dump_cache.clear()
233
+
221
234
  def analyze_single_element(self, element, suffix_stack):
222
235
  if suffix_stack and suffix_stack[-1] in self.torch_object_key:
223
236
  return self.torch_object_key[suffix_stack[-1]](element)
224
- if isinstance(element, torch.Size):
225
- return self._analyze_torch_size(element)
226
- if isinstance(element, torch.memory_format):
227
- return self._analyze_memory_format(element)
228
- if isinstance(element, dist.ProcessGroup):
229
- return self._analyze_process_group(element)
230
- if isinstance(element, dist.P2POp):
231
- return self._analyze_p2pop(element)
232
- if isinstance(element, dist.ReduceOp):
233
- return self._analyze_reduce_op(element)
234
- converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
235
- if converted_numpy is not element:
236
- return {"type": numpy_type, "value": converted_numpy}
237
- if isinstance(element, torch.Tensor):
238
- return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
239
- if isinstance(element, np.ndarray):
240
- return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
241
- if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
242
- return self._analyze_builtin(element)
243
- return {}
244
237
 
245
- def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
246
- if self.is_distributed_op(module):
247
- module_input_output.update_output_with_args_and_kwargs()
248
- return super().analyze_forward_output(name, module, module_input_output)
238
+ suffix_str = Const.SEP.join(str(s) for s in suffix_stack)
239
+ type_analyzer = [
240
+ (PytorchDataProcessor.builtin_type, self._analyze_builtin),
241
+ (torch.Size, self._analyze_torch_size),
242
+ (torch.Tensor, lambda e: self._analyze_tensor(e, suffix_str)),
243
+ (torch.memory_format, self._analyze_memory_format),
244
+ (dist.ProcessGroup, self._analyze_process_group),
245
+ (dist.P2POp, lambda e: self._analyze_p2pop(e, suffix_str)),
246
+ (dist.ReduceOp, self._analyze_reduce_op),
247
+ (PytorchDataProcessor.np_type[:-1], self._analyze_numpy),
248
+ (np.ndarray, lambda e: self._analyze_ndarray(e, suffix_str)),
249
+ ]
250
+ for type_key, analyze_fn in type_analyzer:
251
+ if isinstance(element, type_key):
252
+ return analyze_fn(element)
253
+ return {}
249
254
 
250
- def _analyze_p2pop(self, arg):
255
+ def _analyze_p2pop(self, arg, suffix):
251
256
  p2pop_info = {"class_type": "torch.distributed.P2POp"}
252
257
  try:
253
- tensor_info = self._analyze_tensor(arg.tensor, [])
258
+ tensor_info = self._analyze_tensor(arg.tensor, suffix)
254
259
  p2pop_info.update({"tensor": tensor_info})
255
260
  p2pop_info.update({"op": arg.op.__name__})
256
261
  p2pop_info.update({"peer": arg.peer})
@@ -263,63 +268,71 @@ class PytorchDataProcessor(BaseDataProcessor):
263
268
  return p2pop_info
264
269
 
265
270
  def _analyze_tensor(self, tensor, suffix):
271
+ tensor, dtype = self._cast_to_float_if_fp8(tensor)
266
272
  tensor_stat = self.get_stat_info(tensor, self.config.async_dump)
267
273
  tensor_json = {}
268
274
  tensor_json.update({'type': 'torch.Tensor'})
269
- tensor_json.update({'dtype': str(tensor.dtype)})
275
+ tensor_json.update({'dtype': dtype})
270
276
  tensor_json.update({"shape": tensor.shape})
271
- if tensor_stat.stack_tensor_stat is None:
272
- tensor_json.update({"Max": tensor_stat.max})
273
- tensor_json.update({"Min": tensor_stat.min})
274
- tensor_json.update({"Mean": tensor_stat.mean})
275
- tensor_json.update({"Norm": tensor_stat.norm})
276
- tensor_json.update({"requires_grad": tensor.requires_grad})
277
- if tensor_stat.max is not None:
278
- if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
279
- tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
280
- if tensor_stat.min is not None:
281
- if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
282
- tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")
283
277
 
284
- else:
285
- tensor_json.update({"requires_grad": tensor.requires_grad})
286
- tensor_json.update({"tensor_stat": tensor_stat.stack_tensor_stat})
278
+ stat_values = [
279
+ tensor_stat.max,
280
+ tensor_stat.min,
281
+ tensor_stat.mean,
282
+ tensor_stat.norm
283
+ ]
284
+ placeholder_index = self.data_writer.append_stat_to_buffer(stat_values)
285
+
286
+ tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index})
287
+ tensor_json.update({"requires_grad": tensor.requires_grad})
287
288
 
288
289
  if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
289
290
  tensor_md5 = self.get_md5_for_tensor(tensor)
290
291
  tensor_json.update({Const.MD5: tensor_md5})
291
292
  return tensor_json
292
293
 
293
-
294
- class StatisticsDataProcessor(PytorchDataProcessor):
295
- pass
296
-
297
-
298
- class TensorDataProcessor(PytorchDataProcessor):
299
- def dump_async_data(self):
300
- for file_path, tensor in self._async_dump_cache.items():
301
- save_pt(tensor.contiguous(), file_path)
302
- self._async_dump_cache.clear()
303
-
304
- def _analyze_tensor(self, tensor, suffix):
294
+ def _analyze_and_save_tensor(self, tensor, suffix):
305
295
  dump_data_name, file_path = self.get_save_file_path(suffix)
306
- single_arg = super()._analyze_tensor(tensor, suffix)
296
+ single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix)
307
297
  single_arg.update({"data_name": dump_data_name})
298
+ tensor, _ = self._cast_to_float_if_fp8(tensor)
308
299
  if self.config.async_dump:
309
300
  self._async_dump_cache[file_path] = tensor.clone().detach()
310
301
  else:
311
302
  saved_tensor = tensor.clone().contiguous().detach()
312
303
  save_pt(saved_tensor, file_path)
313
304
  return single_arg
314
-
315
- def _analyze_numpy(self, ndarray, suffix):
305
+
306
+ def _analyze_and_save_ndarray(self, ndarray, suffix):
316
307
  dump_data_name, file_path = self.get_save_file_path(suffix)
317
308
  save_pt(torch.tensor(ndarray), file_path)
318
- ndarray_json = super()._analyze_numpy(ndarray, suffix)
309
+ ndarray_json = PytorchDataProcessor._analyze_ndarray(ndarray, suffix)
319
310
  ndarray_json.update({"data_name": dump_data_name})
320
311
  return ndarray_json
321
312
 
322
313
 
314
+ class StatisticsDataProcessor(PytorchDataProcessor):
315
+ def _analyze_tensor(self, tensor, suffix):
316
+ if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
317
+ return self._analyze_and_save_tensor(tensor, suffix)
318
+ else:
319
+ return super()._analyze_tensor(tensor, suffix)
320
+
321
+ def _analyze_ndarray(self, ndarray, suffix):
322
+ if any(item in self.current_api_or_module_name for item in self.config.tensor_list):
323
+ return self._analyze_and_save_ndarray(ndarray, suffix)
324
+ else:
325
+ return super()._analyze_ndarray(ndarray, suffix)
326
+
327
+
328
+ class TensorDataProcessor(PytorchDataProcessor):
329
+ def _analyze_tensor(self, tensor, suffix):
330
+ return self._analyze_and_save_tensor(tensor, suffix)
331
+
332
+ def _analyze_ndarray(self, ndarray, suffix):
333
+ return self._analyze_and_save_ndarray(ndarray, suffix)
334
+
335
+
323
336
  class OverflowCheckDataProcessor(PytorchDataProcessor):
324
337
  __slots__ = ["cached_tensors_and_file_paths"]
325
338
 
@@ -383,7 +396,8 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
383
396
  self._analyze_maybe_overflow_flag()
384
397
  if self.has_overflow:
385
398
  for file_path, tensor in self.cached_tensors_and_file_paths.items():
386
- save_pt(tensor, file_path)
399
+ tensor, _ = self._cast_to_float_if_fp8(tensor)
400
+ save_pt(tensor.clone().contiguous().detach(), file_path)
387
401
  self.real_overflow_nums += 1
388
402
  if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
389
403
  logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, "
@@ -409,10 +423,22 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
409
423
  raise RuntimeError(f"overflow check failed") from e
410
424
 
411
425
  def _analyze_maybe_overflow_tensor(self, tensor_json):
412
- if tensor_json['Max'] is None or tensor_json['Min'] is None:
426
+ tensor_stat_index = tensor_json.get(Const.TENSOR_STAT_INDEX)
427
+ if tensor_stat_index is None:
428
+ logger.warning("tensor_stat_index does not exist in tensor_json.")
429
+ return
430
+ max_tensor = self.data_writer.get_buffer_values_max(tensor_stat_index)
431
+ min_tensor = self.data_writer.get_buffer_values_min(tensor_stat_index)
432
+
433
+ if max_tensor is None or min_tensor is None:
434
+ return
435
+
436
+ if torch.isinf(max_tensor) or torch.isnan(max_tensor):
437
+ self.has_overflow = True
413
438
  return
414
- self.has_overflow = np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']) or \
415
- np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min'])
439
+
440
+ if torch.isinf(min_tensor) or torch.isnan(min_tensor):
441
+ self.has_overflow = True
416
442
 
417
443
  def _analyze_tensor(self, tensor, suffix):
418
444
  dump_data_name, file_path = self.get_save_file_path(suffix)
@@ -508,11 +534,13 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
508
534
  return
509
535
 
510
536
  if self.config.is_backward_kernel_dump:
511
- self.forward_args = self.clone_and_detach_tensor(module_input_output.args)
512
- self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs)
513
537
  try:
538
+ self.forward_args = self.clone_and_detach_tensor(module_input_output.args)
539
+ self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs)
514
540
  output = module.forward(*self.forward_args, **self.forward_kwargs)
515
- except Exception:
541
+ except Exception as e:
542
+ if isinstance(e, MsprobeException):
543
+ logger.warning(str(e))
516
544
  self._print_unsupported_log(name)
517
545
  self.enable_kernel_dump = False
518
546
  return
@@ -554,9 +582,17 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
554
582
  self.stop_kernel_dump()
555
583
  logger.info(f"The kernel data of {name} is dumped successfully.")
556
584
 
557
- @recursion_depth_decorator("KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor")
585
+ @recursion_depth_decorator(
586
+ "KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor",
587
+ max_depth=Const.DUMP_MAX_DEPTH
588
+ )
558
589
  def clone_and_detach_tensor(self, input_params):
559
590
  if isinstance(input_params, torch.Tensor):
591
+ if is_float8_tensor(input_params):
592
+ raise MsprobeException(
593
+ MsprobeException.UNSUPPORTED_TYPE_ERROR,
594
+ f"L2 backward dump does not support float8 type."
595
+ )
560
596
  if input_params.requires_grad:
561
597
  return input_params.clone().detach().requires_grad_()
562
598
  return input_params.clone()
@@ -571,6 +607,8 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
571
607
 
572
608
  def analyze_single_element(self, element, suffix_stack):
573
609
  if isinstance(element, torch.Tensor):
610
+ if is_float8_tensor(element):
611
+ return {}
574
612
  if not self.is_found_output_tensor:
575
613
  if element.requires_grad:
576
614
  self.forward_output_tensor = element
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,12 +16,14 @@
16
16
  import csv
17
17
  import os
18
18
  import copy
19
- import numpy as np
19
+ import threading
20
20
 
21
21
  from msprobe.core.common.const import Const, FileCheckConst
22
22
  from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json
23
23
  from msprobe.core.common.log import logger
24
- from msprobe.core.common.exceptions import MsprobeException
24
+ from msprobe.core.common.decorator import recursion_depth_decorator
25
+
26
# Module-level lock shared by every DataWriter instance. The update_* methods and
# write_json acquire it around cache mutation and the JSON flush.
# threading.Lock is NOT reentrant: code already holding it must not call write_json.
lock = threading.Lock()
25
27
 
26
28
 
27
29
  class DataWriter:
@@ -34,10 +36,12 @@ class DataWriter:
34
36
  self.dump_tensor_data_dir = None
35
37
  self.debug_file_path = None
36
38
  self.flush_size = 1000
39
+ self.larger_flush_size = 20000
37
40
  self.cache_data = {}
38
41
  self.cache_stack = {}
39
42
  self.cache_construct = {}
40
43
  self.cache_debug = {}
44
+ self.stat_stack_list = []
41
45
 
42
46
  @staticmethod
43
47
  def write_data_to_csv(result: list, result_header: tuple, file_path: str):
@@ -54,13 +58,54 @@ class DataWriter:
54
58
  if is_new_file:
55
59
  change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
56
60
 
61
@recursion_depth_decorator("JsonWriter: DataWriter._replace_stat_placeholders")
def _replace_stat_placeholders(self, data, stat_result):
    """Replace ``Const.TENSOR_STAT_INDEX`` placeholders with flushed statistics, in place.

    Walks ``data`` recursively through dicts, lists and tuples. A dict holding an
    int ``Const.TENSOR_STAT_INDEX`` is rebuilt as follows:
    - It starts with type/dtype/shape, followed by Max/Min/Mean/Norm taken from
      ``stat_result[index]``.
    - The placeholder key is removed.
    - All other keys follow in their original relative order, which keeps the
      JSON field order stable.

    Missing information becomes ``None`` rather than raising. This covers an
    index past the end of ``stat_result``, a stat row with fewer than four
    values, and a missing type/dtype/shape key. A negative index stops
    processing of the rest of the current dict.

    Args:
        data: nested dict/list/tuple structure to rewrite (mutated in place).
        stat_result: list of stat rows as returned by ``flush_stat_stack``.
    """
    if isinstance(data, dict):
        stat_keys = (Const.MAX, Const.MIN, Const.MEAN, Const.NORM)
        # Iterate over a snapshot of the keys: the dict may be rebuilt in place below.
        for key in list(data.keys()):
            value = data[key]
            if key != Const.TENSOR_STAT_INDEX or not isinstance(value, int):
                self._replace_stat_placeholders(value, stat_result)
                continue
            if value < 0:
                return
            stat_values = list(stat_result[value]) if value < len(stat_result) else []
            # Pad missing/short rows so they produce nulls instead of IndexError.
            stat_values += [None] * (len(stat_keys) - len(stat_values))

            new_entries = {
                Const.TYPE: data.get("type"),
                Const.DTYPE: data.get("dtype"),
                Const.SHAPE: data.get("shape"),
            }
            new_entries.update(zip(stat_keys, stat_values))
            del data[key]

            # Rebuild so the type/shape/stat fields come first when written to JSON.
            remaining = {k: v for k, v in data.items() if k not in new_entries}
            data.clear()
            data.update(new_entries)
            data.update(remaining)
    elif isinstance(data, (list, tuple)):
        for item in data:
            self._replace_stat_placeholders(item, stat_result)
100
+
57
101
def reset_cache(self):
    """Rebind the data, stack, construct and debug caches to fresh empty dicts.

    The previous dict objects are not mutated. ``stat_stack_list`` is left untouched.
    """
    for cache_name in ("cache_data", "cache_stack", "cache_construct", "cache_debug"):
        setattr(self, cache_name, {})
61
106
 
62
107
  def initialize_json_file(self, **kwargs):
63
- if self.debug_file_path and not self.cache_debug:
108
+ if kwargs["level"] == Const.LEVEL_DEBUG and not self.cache_debug:
64
109
  # debug level case only create debug.json
65
110
  debug_dict = copy.deepcopy(kwargs)
66
111
  debug_dict.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
@@ -86,39 +131,59 @@ class DataWriter:
86
131
 
87
132
def flush_data_periodically(self):
    """Call ``write_json`` when the number of cached data entries hits a flush boundary.

    The entries counted are those under ``self.cache_data[Const.DATA]``; nothing
    happens if that is missing, empty or not a dict. The boundary depends on
    the count:
    - below ``larger_flush_size`` entries, flush at every multiple of ``flush_size``;
    - from ``larger_flush_size`` on, flush only at multiples of ``larger_flush_size``.

    NOTE(review): ``update_data`` does not grow the count when it merges into an
    existing key. If this is called after every update, the same full write can
    repeat while the count sits on a boundary.
    """
    dump_data = self.cache_data.get(Const.DATA)
    if dump_data and isinstance(dump_data, dict):
        entry_count = len(dump_data)
        if entry_count < self.larger_flush_size:
            interval = self.flush_size
        else:
            interval = self.larger_flush_size
        if entry_count % interval == 0:
            self.write_json()
144
+
145
def update_data(self, new_data):
    """Merge a single-key record into ``self.cache_data[Const.DATA]`` under the module lock.

    ``new_data`` must be a dict with exactly one top-level key. If that key is
    already cached, the cached value is ``update()``-d with the new value;
    otherwise the record is added as-is. An invalid ``new_data``, or a data
    section that is not a dict, only logs a warning.
    """
    with lock:
        if not isinstance(new_data, dict) or len(new_data) != 1:
            logger.warning(f"The data info({new_data}) should be a dict with only one outer key.")
            return
        dump_data = self.cache_data.get(Const.DATA)
        if not isinstance(dump_data, dict):
            logger.warning(f"The dump data({dump_data}) should be a dict.")
            return

        key, record = next(iter(new_data.items()))
        if key not in dump_data:
            dump_data[key] = record
        else:
            dump_data[key].update(record)
160
+
161
def update_stack(self, name, stack_data):
    """Record that API ``name`` produced ``stack_data``, under the module lock.

    ``self.cache_stack`` maps each distinct ``stack_data`` to the list of API
    names that share it, in insertion order.
    """
    with lock:
        names_for_stack = self.cache_stack.get(stack_data)
        if names_for_stack is not None:
            names_for_stack.append(name)
            return
        self.cache_stack[stack_data] = [name]
109
168
 
110
169
def update_construct(self, new_data):
    """Merge ``new_data`` into ``self.cache_construct`` while holding the module lock."""
    with lock:
        construct_cache = self.cache_construct
        construct_cache.update(new_data)
112
172
 
113
173
def update_debug(self, new_data):
    """Merge ``new_data`` into ``self.cache_debug['data']`` while holding the module lock.

    Raises KeyError if ``cache_debug`` has no ``'data'`` entry, e.g. after
    ``reset_cache`` has emptied it.
    """
    with lock:
        debug_records = self.cache_debug['data']
        debug_records.update(new_data)
115
176
 
116
177
def write_data_json(self, file_path):
    """Log the dump location and save ``self.cache_data`` to ``file_path`` (indent=1).

    The directory named in the log line is two levels above ``file_path``.
    """
    dump_root = os.path.dirname(os.path.dirname(file_path))
    logger.info(f"dump.json is at {dump_root}. ")
    save_json(file_path, self.cache_data, indent=1)
119
180
 
120
181
def write_stack_info_json(self, file_path):
    """Save the stack cache to ``file_path`` as JSON (indent=1).

    ``self.cache_stack`` maps a stack-info value to the list of API names that
    share it (see ``update_stack``). It is written as
    ``{index: [api_names, stack_info], ...}``, with indices assigned in the
    cache's insertion order starting from 0.
    """
    indexed_stack = {
        index: [api_names, stack_info]
        for index, (stack_info, api_names) in enumerate(self.cache_stack.items())
    }
    save_json(file_path, indexed_stack, indent=1)
122
187
 
123
188
def write_construct_info_json(self, file_path):
    """Save ``self.cache_construct`` to ``file_path`` as JSON (indent=1)."""
    construct_cache = self.cache_construct
    save_json(file_path, construct_cache, indent=1)
@@ -126,38 +191,62 @@ class DataWriter:
126
191
def write_debug_info_json(self, file_path):
    """Save ``self.cache_debug`` to ``file_path`` as JSON (indent=1)."""
    debug_cache = self.cache_debug
    save_json(file_path, debug_cache, indent=1)
128
193
 
194
def append_stat_to_buffer(self, stat_vector):
    """Store ``stat_vector`` as-is in the plain Python list ``self.stat_stack_list``.

    Returns the position of the new entry. ``get_buffer_values_max`` and
    ``get_buffer_values_min`` read entries back by that position, and
    presumably it is also the index used as the tensor-stat placeholder.

    NOTE(review): unlike the update_* methods, this does not acquire the
    module ``lock``. Meanwhile ``flush_stat_stack`` (called from ``write_json``)
    replaces the buffer. If appends can race with a flush, a vector could be
    lost; confirm callers are single-threaded here.
    """
    buffer = self.stat_stack_list
    buffer.append(stat_vector)
    return len(buffer) - 1
201
+
202
def get_buffer_values_max(self, index):
    """Return element 0 (the Max) of the buffered stat vector at ``index``.

    Logs a warning and returns None when ``index`` is out of range or the
    vector is empty.
    """
    buffer = self.stat_stack_list
    has_value = 0 <= index < len(buffer) and len(buffer[index]) >= 1
    if not has_value:
        logger.warning(f"stat_stack_list[{index}] The internal data is incomplete,"
                       f" and the maximum value cannot be obtained.")
        return None
    return buffer[index][0]
209
+
210
def get_buffer_values_min(self, index):
    """Return element 1 (the Min) of the buffered stat vector at ``index``.

    Logs a warning and returns None when ``index`` is out of range or the
    vector holds fewer than two values.
    """
    # The Min lives at position 1, so at least two values are required. The
    # previous ``>= 1`` guard let a 1-element vector through and raised IndexError.
    if 0 <= index < len(self.stat_stack_list) and len(self.stat_stack_list[index]) >= 2:
        return self.stat_stack_list[index][1]
    logger.warning(f"stat_stack_list[{index}] Internal data is incomplete"
                   f" and minimum values cannot be obtained.")
    return None
217
+
218
def flush_stat_stack(self):
    """Drain the stat buffer, moving every buffered statistic to the host.

    Returns one list per buffered vector, each in the order [Max, Min, Mean, Norm].
    Values exposing ``.item()`` (e.g. device tensors) are converted with it;
    anything else is kept unchanged. Afterwards the buffer is replaced with a
    new empty list. An empty buffer yields ``[]``.
    """
    pending = self.stat_stack_list
    if not pending:
        return []
    host_values = []
    for vector in pending:
        host_values.append([value.item() if hasattr(value, "item") else value for value in vector])
    self.stat_stack_list = []
    return host_values
234
+
129
235
def write_json(self):
    """Resolve stat placeholders, then save every non-empty cache to its JSON file.

    The whole operation runs under the module-level ``lock``. That is a
    non-reentrant ``threading.Lock``, so this must not be called while the
    lock is already held.

    Steps:
    1. Drain buffered statistics with ``flush_stat_stack``.
    2. If there were any, replace the placeholders in ``cache_data`` and, when
       non-empty, ``cache_debug``.
    3. Write each non-empty cache to its path, in this order:
       - data -> ``dump_file_path``
       - stack -> ``stack_file_path``
       - construct -> ``construct_file_path``
       - debug -> ``debug_file_path``
    """
    with lock:
        stat_result = self.flush_stat_stack()
        if stat_result:
            self._replace_stat_placeholders(self.cache_data, stat_result)
            if self.cache_debug:
                self._replace_stat_placeholders(self.cache_debug, stat_result)

        # (cache attribute, writer, output-path attribute); paths are looked up
        # lazily, only for caches that actually have content.
        outputs = (
            ("cache_data", self.write_data_json, "dump_file_path"),
            ("cache_stack", self.write_stack_info_json, "stack_file_path"),
            ("cache_construct", self.write_construct_info_json, "construct_file_path"),
            ("cache_debug", self.write_debug_info_json, "debug_file_path"),
        )
        for cache_attr, writer, path_attr in outputs:
            if getattr(self, cache_attr):
                writer(getattr(self, path_attr))
+