mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +14,11 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import os
17
- from msprobe.core.common.file_utils import load_yaml
17
+ import importlib
18
+ import inspect
19
+
20
+ from msprobe.core.common.file_utils import load_yaml, check_link
21
+ from msprobe.core.common.log import logger
18
22
 
19
23
 
20
24
  def get_ops():
@@ -26,3 +30,25 @@ def get_ops():
26
30
  wrap_torch = ops.get('torch')
27
31
  wrap_npu_ops = ops.get('torch_npu')
28
32
  return set(wrap_functional) | set(wrap_tensor) | set(wrap_torch) | set(wrap_npu_ops)
33
+
34
+
35
+ def dynamic_import_op(package, white_list):
36
+ package_name = package.__name__
37
+ ops = {}
38
+ ops_dir, _ = os.path.split(package.__file__)
39
+ check_link(ops_dir)
40
+ for file_name in os.listdir(ops_dir):
41
+ if file_name in white_list:
42
+ sub_module_name = file_name[:-3]
43
+ module_name = f"{package_name}.{sub_module_name}"
44
+ try:
45
+ module = importlib.import_module(module_name)
46
+ except Exception as e:
47
+ logger.warning(f"import {module_name} failed!")
48
+ continue
49
+
50
+ func_members = inspect.getmembers(module, inspect.isfunction)
51
+ for func_member in func_members:
52
+ func_name, func = func_member[0], func_member[1]
53
+ ops[f"{sub_module_name}.{func_name}"] = func
54
+ return ops
@@ -22,13 +22,18 @@ from torch.utils.tensorboard import SummaryWriter
22
22
  from tqdm import tqdm
23
23
 
24
24
  from msprobe.core.common.const import MonitorConst
25
- from msprobe.core.common.file_utils import read_csv, create_directory, remove_path
25
+ from msprobe.core.common.file_utils import read_csv, create_directory, remove_path, recursive_chmod
26
26
  from msprobe.core.common.utils import is_int
27
+ from msprobe.core.common.decorator import recursion_depth_decorator
27
28
  from msprobe.pytorch.common.log import logger
28
29
  from msprobe.pytorch.monitor.utils import get_target_output_dir
29
30
 
30
- all_data_type_list = ["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param"]
31
+ all_data_type_list = [
32
+ "actv", "actv_grad", "exp_avg", "exp_avg_sq",
33
+ "grad_unreduced", "grad_reduced", "param_origin", "param_updated"
34
+ ]
31
35
  CSV_FILE_SUFFIX = r"_\d+-\d+\.csv"
36
+ MAX_PROCESS_NUM = 128
32
37
 
33
38
 
34
39
  def parse_step_line(line, ops):
@@ -46,7 +51,7 @@ def parse_step_line(line, ops):
46
51
 
47
52
  def parse_step_fn(filepath):
48
53
  data = read_csv(filepath)
49
- ops = [k for k in data.keys() if k in MonitorConst.OP_LIST]
54
+ ops = [k for k in data.keys() if k in MonitorConst.OP_LIST[:-2]]
50
55
  parse_step_result = {}
51
56
 
52
57
  for _, line in data.iterrows():
@@ -74,8 +79,10 @@ def write_step(output_dirpath, parse_step_result, rank, data_type):
74
79
  for op, value in ops.items():
75
80
  tag = f"{vpp_name}/{op}"
76
81
  writer.add_scalar(tag, value, step)
82
+ writer.flush()
77
83
 
78
84
 
85
+ @recursion_depth_decorator("update_dict", max_depth=50)
79
86
  def update_dict(dict1, dict2):
80
87
  for key, value in dict2.items():
81
88
  if key in dict1:
@@ -115,11 +122,13 @@ def csv2tb_by_step_work(target_output_dirs, output_dirpath, data_type_list):
115
122
  def check_process_num(process_num):
116
123
  if not is_int(process_num) or process_num <= 0:
117
124
  raise ValueError(f"process_num({process_num}) is not a positive integer")
125
+ if process_num > MAX_PROCESS_NUM:
126
+ raise ValueError(f"The maximum supported process_num is {MAX_PROCESS_NUM}, current value: {process_num}.")
118
127
 
119
128
 
120
129
  def check_data_type_list(data_type_list):
121
130
  if data_type_list is None:
122
- logger.info(f"data_type_list is None, use defualt all_data_type_list: {all_data_type_list}")
131
+ logger.info(f"data_type_list is None, use default all_data_type_list: {all_data_type_list}")
123
132
  return
124
133
  if not isinstance(data_type_list, list):
125
134
  raise ValueError(f"data_type_list({data_type_list}) is not a list")
@@ -161,4 +170,5 @@ def csv2tensorboard_by_step(
161
170
  p.start()
162
171
  for p in processes:
163
172
  p.join()
173
+ recursive_chmod(output_dirpath)
164
174
  logger.info(f"output has been saved to: {output_dirpath}")
@@ -0,0 +1,259 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import itertools
16
+ import os
17
+ from collections import defaultdict
18
+ from dataclasses import dataclass
19
+
20
+ import pandas as pd
21
+ import torch
22
+ from torch.utils.tensorboard import SummaryWriter
23
+
24
+ from msprobe.core.common.const import FileCheckConst, MonitorConst
25
+ from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv
26
+ from msprobe.core.monitor.anomaly_processor import AnomalyDataFactory, AnomalyTurbulence, AnomalyScanner
27
+ from msprobe.pytorch.common.log import logger
28
+
29
+
30
+ class BCOLORS:
31
+ HEADER = '\033[95m'
32
+ OKBLUE = '\033[94m'
33
+ OKCYAN = '\033[96m'
34
+ OKGREEN = '\033[92m'
35
+ WARNING = '\033[93m'
36
+ FAIL = '\033[91m'
37
+ ENDC = '\033[0m'
38
+ BOLD = '\033[1m'
39
+ UNDERLINE = '\033[4m'
40
+
41
+
42
+ @dataclass
43
+ class WriterInput:
44
+ path: str
45
+ ad_rules: list
46
+ job_id: str
47
+ anomaly_factory: AnomalyDataFactory = None
48
+ ndigits: int = 6
49
+ step_count_per_record: int = 1
50
+
51
+
52
+ class BaseWriterWithAD:
53
+ def __init__(self, writer_input: WriterInput):
54
+ self.tag2scalars = {}
55
+ self.ad_rules = writer_input.ad_rules
56
+ self.job_id = writer_input.job_id
57
+ self.anomaly_factory = writer_input.anomaly_factory
58
+ self.anomalies = []
59
+ self.ndigits = writer_input.ndigits
60
+ self.beta = 0.99
61
+
62
+ @staticmethod
63
+ def stack_tensors(tensor_list):
64
+ """
65
+ Torch not support stack cpu and xpu tensors. Group the tensors into cpu_group and xpu_group,
66
+ stack them separately, migrate xpu_group to cpu, and then restore in the order of input.
67
+
68
+ :param tensor_list: [tensor(-1.6165), tensor(-1.0985), tensor(-1.7777), tensor(-1.8408, device='npu:0')]
69
+ :return: result: list of float
70
+ """
71
+ cpu_tensors = []
72
+ xpu_tensors = []
73
+
74
+ for tensor in tensor_list:
75
+ if isinstance(tensor, torch.Tensor) and tensor.device.type != 'cpu':
76
+ # 将device上的tensor先stack后to cpu
77
+ xpu_tensors.append(tensor)
78
+ else:
79
+ cpu_tensors.append(tensor)
80
+
81
+ xpu_stack = torch.stack(xpu_tensors).cpu() if xpu_tensors else torch.tensor([])
82
+
83
+ # 按照输入的顺序恢复
84
+ result = []
85
+ cpu_tensors_idx, xpu_tensors_idx = 0, 0
86
+ for tensor in tensor_list:
87
+ if isinstance(tensor, torch.Tensor) and tensor.device.type != 'cpu':
88
+ result.append(xpu_stack[xpu_tensors_idx])
89
+ xpu_tensors_idx += 1
90
+ else:
91
+ result.append(cpu_tensors[cpu_tensors_idx])
92
+ cpu_tensors_idx += 1
93
+
94
+ return result
95
+
96
+ def get_anomalies(self):
97
+ """返回已检测到的异常列表
98
+ """
99
+ return self.anomalies
100
+
101
+ def clear_anomalies(self):
102
+ self.anomalies.clear()
103
+
104
+ def add_scalar(self, tag, scalar_value, global_step=None):
105
+ """If an anomaly is detected, the anomaly information is recorded and added to self.anomalies.
106
+ Args:
107
+ tag (tuple): tuple of tag_name and tag like ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min').
108
+ scalar_value (float): scalar_value.
109
+ global_step (int): global_step.
110
+ Returns:
111
+ None
112
+ """
113
+ if not self.ad_rules or tag[-1] in ["shape", "dtype"]:
114
+ return
115
+ if isinstance(scalar_value, torch.Tensor):
116
+ scalar_value = scalar_value.item()
117
+ avg = self._update_tag2scalars(tag, scalar_value)
118
+ detected, rule_name = self._ad(scalar_value, history=avg)
119
+ if detected:
120
+ if rule_name == AnomalyTurbulence.name and tag[-1] not in ["norm", "mean"]:
121
+ return
122
+ exception_message = (f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}, "
123
+ f"current value {scalar_value}, history mean {avg}.")
124
+ logger.info(f"{BCOLORS.WARNING}> {exception_message}{BCOLORS.ENDC}")
125
+ # append to self.anomalies for dump
126
+ if self.anomaly_factory:
127
+ self.anomalies.append(self.anomaly_factory.create(tag, exception_message, global_step))
128
+
129
+ def write_metrics(self, ops, metric_value, step, prefix=''):
130
+ if not metric_value:
131
+ return
132
+ tensors = []
133
+ tags = list(itertools.product(metric_value.keys(), ops))
134
+ for op2tensor in metric_value.values():
135
+ tensors.extend(op2tensor.values())
136
+ if not tensors:
137
+ return
138
+
139
+ n_slices = len(tensors) // MonitorConst.SLICE_SIZE
140
+ with torch.no_grad():
141
+ for i in range(n_slices + 1):
142
+ begin = i * MonitorConst.SLICE_SIZE
143
+ end = (i + 1) * MonitorConst.SLICE_SIZE
144
+ if begin == len(tensors):
145
+ continue
146
+ metric_list = self.stack_tensors(tensors[begin:end])
147
+ for tag, metric in zip(tags[begin:end], metric_list):
148
+ self.add_scalar(tag, metric, step)
149
+
150
+ def _ad(self, scalar_value, history):
151
+ return AnomalyScanner.scan(self.ad_rules, history, cur=scalar_value)
152
+
153
+ def _update_tag2scalars(self, tag, scalar_value):
154
+ """Update the average and count of a scalar value associated with a tag.
155
+
156
+ This method is used to maintain a running average of scalar values for each tag.
157
+
158
+
159
+ Args:
160
+ tag (str): The tag identifier.
161
+ scalar_value (float): The scalar value to be added.
162
+
163
+ Returns:
164
+ float: The average value before update.
165
+ """
166
+ abs_scalar_value = abs(scalar_value)
167
+ if tag not in self.tag2scalars:
168
+ self.tag2scalars[tag] = {'avg': abs_scalar_value, 'count': 0}
169
+ avg = self.tag2scalars[tag]['avg']
170
+ self.tag2scalars[tag]['avg'] = self.beta * avg + (1 - self.beta) * abs_scalar_value
171
+ self.tag2scalars[tag]['count'] += 1
172
+ return avg
173
+
174
+
175
+ class CSVWriterWithAD(BaseWriterWithAD):
176
+ def __init__(self, writer_input: WriterInput):
177
+ super().__init__(writer_input)
178
+
179
+ path = writer_input.path
180
+ self.log_dir = path
181
+ create_directory(path)
182
+ change_mode(path, FileCheckConst.DATA_DIR_AUTHORITY)
183
+ self.context_dict = defaultdict(list)
184
+ self.header = []
185
+ self.step_count_per_record = writer_input.step_count_per_record
186
+
187
+ def get_step_interval(self, step):
188
+ count = step // self.step_count_per_record
189
+ return count * self.step_count_per_record, (count + 1) * self.step_count_per_record - 1
190
+
191
+ def write_csv(self, prefix, step):
192
+ """
193
+ Args:
194
+ prefix[str]: prefix of output csv file e.g. grad_unreduced
195
+ step[int]
196
+ """
197
+ if len(self.context_dict) == 0:
198
+ return
199
+
200
+ ster_start, step_end = self.get_step_interval(step)
201
+ filepath = os.path.join(self.log_dir, f'{prefix}_{ster_start}-{step_end}.csv')
202
+ if not os.path.exists(filepath):
203
+ data_frame = pd.DataFrame(columns=self.header)
204
+ write_df_to_csv(data_frame, filepath)
205
+
206
+ new_data = []
207
+ for name, metric_value in self.context_dict.items():
208
+ new_line = name.split(MonitorConst.NAME_SEP) + metric_value
209
+ new_line.insert(2, step)
210
+ new_data.append(new_line)
211
+ new_data = pd.DataFrame(new_data).round(self.ndigits).fillna("nan")
212
+ write_df_to_csv(new_data, filepath, mode='a+', header=False)
213
+ self.context_dict = defaultdict(list)
214
+
215
+ def add_scalar(self, tag, scalar_value, global_step):
216
+ """
217
+ ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min')
218
+ """
219
+ super().add_scalar(tag, scalar_value, global_step)
220
+
221
+ name = tag[0].split('/')[0]
222
+ if isinstance(scalar_value, torch.Tensor):
223
+ value = scalar_value.item()
224
+ elif isinstance(scalar_value, torch.Size):
225
+ value = list(scalar_value)
226
+ else:
227
+ value = scalar_value
228
+ self.context_dict[name].append(value)
229
+
230
+ def write_metrics(self, ops, metric_value, step, prefix='', **kwargs):
231
+ super().write_metrics(ops, metric_value, step, prefix='')
232
+
233
+ if prefix in [MonitorConst.ACTV, MonitorConst.ACTVGRAD] or kwargs.get("use_micro_step"):
234
+ self.header = MonitorConst.CSV_HEADER_MICRO_STEP + ops
235
+ else:
236
+ self.header = MonitorConst.CSV_HEADER + ops
237
+ self.write_csv(prefix, step)
238
+
239
+ def close(self):
240
+ pass
241
+
242
+
243
+ class SummaryWriterWithAD(SummaryWriter, BaseWriterWithAD):
244
+ def __init__(self, writer_input: WriterInput):
245
+
246
+ path = writer_input.path
247
+ if not os.path.exists(path):
248
+ create_directory(path)
249
+ try:
250
+ super(SummaryWriter, self).__init__(writer_input)
251
+ super().__init__(path)
252
+ except Exception as e:
253
+ logger.error(f'error when init summary writer at {path}: {e}')
254
+ raise ValueError("Init summary writer error.") from e
255
+
256
+ def add_scalar(self, tag, scalar_value, global_step):
257
+ super(SummaryWriter, self).add_scalar(tag, scalar_value, global_step)
258
+ tag = f'{tag[0]}_{tag[1]}'
259
+ super().add_scalar(tag, scalar_value, global_step)
@@ -24,6 +24,7 @@ import torch.nn as nn
24
24
  from msprobe.core.common.const import MonitorConst
25
25
  from msprobe.core.common.file_utils import load_yaml
26
26
  from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name
27
+ from msprobe.pytorch.common.log import logger
27
28
 
28
29
  try:
29
30
  import torch_npu
@@ -37,6 +38,7 @@ WrapDistributedOps = load_yaml(OpsPath).get("distributed", [])
37
38
 
38
39
  StackBlackListPath = os.path.join(os.path.dirname(__file__), "stack_blacklist.yaml")
39
40
  StackBlackList = load_yaml(StackBlackListPath).get("stack", [])
41
+ MAX_STRING_LENGTH = 1000
40
42
 
41
43
  distributed_func = {}
42
44
  for f in dir(dist):
@@ -139,6 +141,8 @@ def get_process_group(process_group):
139
141
 
140
142
 
141
143
  def stack_filter(stack):
144
+ if len(stack) > MAX_STRING_LENGTH:
145
+ logger.warning(f'The character string contains more than {MAX_STRING_LENGTH}. re match is skipped.')
142
146
  for pattern in StackBlackList:
143
147
  if re.search(pattern, stack):
144
148
  return False
@@ -188,10 +192,12 @@ def update_data(old, new):
188
192
 
189
193
 
190
194
  def is_target_line(codeline):
191
- stack = get_callstack()
192
- whole_stack = ';'.join(stack)
193
195
  if codeline == []:
194
196
  return True
197
+ stack = get_callstack()
198
+ whole_stack = ';'.join(stack)
199
+ if len(whole_stack) > MAX_STRING_LENGTH:
200
+ logger.warning(f'The character string contains more than {MAX_STRING_LENGTH}. re match is skipped.')
195
201
  for pattern in codeline:
196
202
  if re.search(pattern, whole_stack):
197
203
  return True