mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  from msprobe.mindspore.common.const import Const
17
+ from msprobe.core.common.log import logger
17
18
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
18
19
  from msprobe.mindspore.free_benchmark.api_pynative_self_check import ApiPyNativeSelfCheck
19
20
 
@@ -41,8 +42,10 @@ class SelfCheckToolFactory:
41
42
  def create(config: DebuggerConfig):
42
43
  tool = SelfCheckToolFactory.tools.get(config.level)
43
44
  if not tool:
44
- raise Exception(f"{config.level} is not supported.")
45
+ logger.error(f"{config.level} is not supported.")
46
+ raise ValueError
45
47
  tool = tool.get(config.execution_mode)
46
48
  if not tool:
47
- raise Exception(f"Task free_benchmark is not supported in this mode: {config.execution_mode}.")
49
+ logger.error(f"Task free_benchmark is not supported in this mode: {config.execution_mode}.")
50
+ raise ValueError
48
51
  return tool(config)
@@ -16,6 +16,7 @@
16
16
  import os
17
17
  import threading
18
18
  from typing import Dict, Union, Tuple
19
+ import time
19
20
 
20
21
  from msprobe.core.common.utils import is_int
21
22
  from msprobe.core.common.file_utils import create_directory, check_path_before_create
@@ -40,8 +41,12 @@ class GlobalContext:
40
41
  def __new__(cls, *args, **kwargs):
41
42
  if cls._instance is None:
42
43
  cls._instance_lock.acquire()
43
- cls._instance = object.__new__(cls)
44
- cls._instance_lock.release()
44
+ try:
45
+ cls._instance = object.__new__(cls)
46
+ except Exception as e:
47
+ raise RuntimeError("grad_probe global context init failed") from e
48
+ finally:
49
+ cls._instance_lock.release()
45
50
  return cls._instance
46
51
 
47
52
  def init_context(self, config_dict: Dict):
@@ -69,6 +74,8 @@ class GlobalContext:
69
74
  else:
70
75
  logger.warning("The output_path exists, the data will be covered.")
71
76
 
77
+ self._setting[GradConst.TIME_STAMP] = str(int(time.time()))
78
+
72
79
  def get_context(self, key: str):
73
80
  if key not in self._setting:
74
81
  logger.warning(f"Unrecognized {key}.")
@@ -111,7 +111,8 @@ class CSVGenerator(Process):
111
111
  output_path = context.get_context(GradConst.OUTPUT_PATH)
112
112
  self.level = context.get_context(GradConst.LEVEL)
113
113
  self.bounds = context.get_context(GradConst.BOUNDS)
114
- self.dump_dir = f"{output_path}/rank{rank_id}/Dump/"
114
+ time_stamp = context.get_context(GradConst.TIME_STAMP)
115
+ self.dump_dir = f"{output_path}/rank{rank_id}/Dump{time_stamp}/"
115
116
  self.save_dir = f"{output_path}/rank{rank_id}/"
116
117
  self.current_step = None
117
118
  self.stop_event = multiprocessing.Event()
@@ -15,6 +15,7 @@
15
15
 
16
16
  import hashlib
17
17
  from abc import ABC, abstractmethod
18
+ import zlib
18
19
 
19
20
  import mindspore
20
21
  from mindspore import ops
@@ -76,8 +77,8 @@ class CsvMd5(CsvItem):
76
77
  def generate_csv_content(csv_input):
77
78
  grad = csv_input.grad
78
79
  tensor_bytes = grad.float().numpy().tobytes()
79
- md5_hash = hashlib.md5(tensor_bytes)
80
- return [md5_hash.hexdigest()]
80
+ md5_hash = f"{zlib.crc32(tensor_bytes):08x}"
81
+ return [md5_hash]
81
82
 
82
83
 
83
84
  @register_csv_item(GradConst.DISTRIBUTION)
@@ -49,12 +49,10 @@ class HookInput:
49
49
  self.param_list = grad_context.get_context(GradConst.PARAM_LIST)
50
50
  self.rank_id = get_rank_id()
51
51
  output_path = grad_context.get_context(GradConst.OUTPUT_PATH)
52
- self.dump_dir = os.path.join(output_path, f"rank{self.rank_id}", "Dump")
52
+ time_stamp = grad_context.get_context(GradConst.TIME_STAMP)
53
+ self.dump_dir = os.path.join(output_path, f"rank{self.rank_id}", f"Dump{time_stamp}")
53
54
  self.save_dir = os.path.join(output_path, f"rank{self.rank_id}")
54
55
  self.step_finish_flag = os.path.join(self.dump_dir, GradConst.STEP_FINISH)
55
- if os.path.exists(self.save_dir):
56
- logger.warning(f"Delete existing path {self.save_dir}.")
57
- remove_path(self.save_dir)
58
56
  self.level = grad_context.get_context(GradConst.LEVEL)
59
57
  self.bounds = grad_context.get_context(GradConst.BOUNDS)
60
58
  self.mode = mindspore.get_context("mode")
@@ -0,0 +1,111 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import defaultdict
17
+ import mindspore as ms
18
+ from mindspore.ops.primitive import Primitive
19
+
20
+ from msprobe.core.common.utils import Const
21
+ from msprobe.core.service import BaseService
22
+ from msprobe.mindspore.cell_processor import CellProcessor
23
+ from msprobe.mindspore.common.log import logger
24
+ from msprobe.mindspore.common.utils import (
25
+ get_rank_if_initialized,
26
+ is_mindtorch,
27
+ get_cells_and_names_with_index
28
+ )
29
+ from msprobe.mindspore.dump.hook_cell.api_register import get_api_register, ApiTemplate
30
+ from msprobe.mindspore.dump.hook_cell.ms_hook_manager import MindsproeHookManager
31
+ from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
32
+ from msprobe.mindspore.dump.jit_dump import JitDump
33
+
34
+ try:
35
+ from mindspore.common._pijit_context import PIJitCaptureContext
36
+ except ImportError:
37
+ pijit_label = False
38
+ else:
39
+ pijit_label = True
40
+
41
+
42
class MindsporeService(BaseService):
    """MindSpore / MindTorch specialisation of ``BaseService``.

    Supplies the MindSpore-specific collaborators: API register, primitive hook
    service, cell processor and hook manager. It also patches MindSpore's JIT
    executor with ``JitDump`` when the configured level is mix or L1.
    """

    @property
    def _get_framework_type(self):
        # Kept as a property: callers read it as an attribute, not a method.
        if is_mindtorch():
            return Const.MT_FRAMEWORK
        return Const.MS_FRAMEWORK

    @staticmethod
    def _get_current_rank():
        return get_rank_if_initialized()

    def empty(self, *args, **kwargs):
        """No-op; installed as ``PIJitCaptureContext.__enter__``/``__exit__``."""

    def _init_specific_components(self):
        """Create the MindSpore-specific collaborators used by the base service."""
        self.logger = logger
        self.api_register = get_api_register()
        self.primitive_hook_service = PrimitiveHookService(self)
        self.cell_processor = CellProcessor(self.data_collector.scope)
        self.hook_manager = MindsproeHookManager(self.data_collector, self.config)
        self._setup_jit_context()
        self.api_template = ApiTemplate

    def _setup_jit_context(self):
        """Route MindSpore JIT execution through ``JitDump`` for mix/L1 levels."""
        if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
            return
        JitDump.set_config(self.config)
        JitDump.set_data_collector(self.data_collector)
        ms_api = ms.common.api
        # Patch whichever executor class this MindSpore version exposes.
        if hasattr(ms_api, "_MindsporeFunctionExecutor"):
            ms_api._MindsporeFunctionExecutor = JitDump
        else:
            ms_api._JitExecutor = JitDump
        ms_api._PyNativeExecutor.grad = JitDump.grad
        if pijit_label:
            # Turn PIJit capture context entry/exit into no-ops.
            PIJitCaptureContext.__enter__ = self.empty
            PIJitCaptureContext.__exit__ = self.empty

    def _register_module_hook(self):
        self.cell_processor.register_cell_hook(self.model, self.build_hook, self.config)
        self.logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.")

    def _register_hook(self):
        self._register_primitive_hook()

    def _register_primitive_hook(self):
        """Wrap ``__call__`` of every Primitive stored as an attribute on the model's cells.

        Only active for mix/L1 levels, when a model is set and the task is a
        data-collection task. Each primitive gets a per-instance subclass named
        ``NewPrimitive`` whose ``__call__`` is produced by
        ``PrimitiveHookService.wrap_primitive``.
        """
        if (self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]
                or not self.model
                or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST):
            return

        cells_by_index, _ = get_cells_and_names_with_index(self.model)
        primitives = {
            (attr_name, attr_value)
            for named_cells in cells_by_index.values()
            for _, cell in named_cells
            for attr_name, attr_value in vars(cell).items()
            if isinstance(attr_value, Primitive)
        }

        for attr_name, primitive in primitives:
            base_cls = primitive.__class__
            combined_name = attr_name + Const.SEP + base_cls.__name__
            wrapped_call = self.primitive_hook_service.wrap_primitive(primitive.__call__, combined_name)
            primitive.__class__ = type('NewPrimitive', (base_cls,), {'__call__': wrapped_call})

    def _reset_status(self):
        super()._reset_status()
        self.primitive_hook_service.primitive_counters.clear()
        JitDump.jit_count = defaultdict(int)

    def _change_jit_switch(self, status):
        JitDump.jit_dump_switch = status
@@ -0,0 +1,52 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from mindspore import nn
18
+ from mindspore import communication
19
+ from msprobe.mindspore.monitor.utils import logger
20
+ from msprobe.mindspore.common.utils import is_mindtorch
21
+ if is_mindtorch():
22
+ import torch
23
+
24
+
25
def is_valid_instance(model):
    """Return True if ``model`` is a network object of the active frontend.

    Under MindTorch the model must be a ``torch.nn.Module``; otherwise it must
    be a MindSpore ``nn.Cell``.
    """
    expected_type = torch.nn.Module if is_mindtorch() else nn.Cell
    return isinstance(model, expected_type)
27
+
28
+
29
def get_submodules(model):
    """Return the named sub-networks of ``model``.

    Uses ``named_modules()`` under MindTorch and ``cells_and_names()`` under
    MindSpore. Logs and returns an empty dict when ``model`` is not a valid
    network object.
    """
    if is_valid_instance(model):
        if is_mindtorch():
            return model.named_modules()
        return model.cells_and_names()
    logger.info("Counter invalid model, nothing to hook")
    return {}
34
+
35
+
36
def get_parameters(model):
    """Return the named parameters of ``model``.

    Uses ``named_parameters()`` under MindTorch and ``parameters_and_names()``
    under MindSpore; returns an empty dict for an invalid model.
    """
    if not is_valid_instance(model):
        return {}
    return model.named_parameters() if is_mindtorch() else model.parameters_and_names()
43
+
44
+
45
def get_rank():
    """Return this process's communication rank, or 0 if communication is not initialised."""
    return communication.get_rank() if comm_is_initialized() else 0
49
+
50
+
51
def comm_is_initialized():
    """Report the value of MindSpore's ``communication.GlobalComm.INITED`` flag."""
    inited = communication.GlobalComm.INITED
    return inited
@@ -0,0 +1,237 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import itertools
17
+ import os
18
+ from dataclasses import dataclass
19
+ from collections import defaultdict
20
+
21
+ import pandas as pd
22
+ from mindspore import ops
23
+ from mindspore import Tensor
24
+ from mindspore import _no_grad
25
+
26
+ from msprobe.core.common.log import logger
27
+ from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv
28
+ from msprobe.core.monitor.anomaly_processor import AnomalyDataFactory, AnomalyTurbulence, AnomalyScanner
29
+ from msprobe.core.common.const import FileCheckConst, MonitorConst
30
+
31
+
32
class BCOLORS:
    """ANSI terminal escape sequences used to colour log output."""

    # Foreground colours
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'

    # Reset and text attributes
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
42
+
43
+
44
@dataclass
class WriterInput:
    """Constructor arguments shared by the monitor metric writers."""

    # Output directory; CSVWriterWithAD creates it and writes its CSV files there.
    path: str
    # Anomaly-detection rules handed to AnomalyScanner.scan; empty disables detection.
    ad_rules: list
    # Stored on the writer as ``job_id``.
    job_id: str
    # Builds anomaly records; when None, detected anomalies are only logged.
    anomaly_factory: AnomalyDataFactory = None
    # Rounding precision applied to values written to CSV.
    ndigits: int = 6
    # Number of consecutive steps grouped into one CSV file.
    step_count_per_record: int = 1
52
+
53
+
54
+ class BaseWriterWithAD:
55
+ def __init__(self, writer_input: WriterInput):
56
+ self.tag2scalars = {}
57
+ self.ad_rules = writer_input.ad_rules
58
+ self.job_id = writer_input.job_id
59
+ self.anomaly_factory = writer_input.anomaly_factory
60
+ self.anomalies = []
61
+ self.ndigits = writer_input.ndigits
62
+ self.beta = 0.99
63
+
64
+ @staticmethod
65
+ def stack_tensors(tensor_list):
66
+ """
67
+ Torch not support stack cpu and xpu tensors. Group the tensors into cpu_group and xpu_group,
68
+ stack them separately, migrate xpu_group to cpu, and then restore in the order of input.
69
+
70
+ :param tensor_list: [tensor(-1.6165), tensor(-1.0985), tensor(-1.7777), tensor(-1.8408, device='npu:0')]
71
+ :return: result: list of float
72
+ """
73
+ cpu_tensors = []
74
+ xpu_tensors = []
75
+
76
+ for tensor in tensor_list:
77
+ if isinstance(tensor, Tensor):
78
+ # 将device上的tensor先stack后to cpu
79
+ xpu_tensors.append(tensor)
80
+ else:
81
+ cpu_tensors.append(tensor)
82
+
83
+ xpu_stack = ops.stack(xpu_tensors).tolist() if xpu_tensors else ops.tensor([])
84
+
85
+ # 按照输入的顺序恢复
86
+ result = []
87
+ cpu_tensors_idx, xpu_tensors_idx = 0, 0
88
+ for tensor in tensor_list:
89
+ if isinstance(tensor, Tensor):
90
+ result.append(xpu_stack[xpu_tensors_idx])
91
+ xpu_tensors_idx += 1
92
+ else:
93
+ result.append(cpu_tensors[cpu_tensors_idx])
94
+ cpu_tensors_idx += 1
95
+
96
+ return result
97
+
98
+ def get_anomalies(self):
99
+ """返回已检测到的异常列表
100
+ """
101
+ return self.anomalies
102
+
103
+ def clear_anomalies(self):
104
+ self.anomalies.clear()
105
+
106
+ def add_scalar(self, tag, scalar_value, global_step=None, need_explain=False):
107
+ """If an anomaly is detected, the anomaly information is recorded and added to self.anomalies.
108
+ Args:
109
+ tag (tuple): tuple of tag_name and tag like ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min').
110
+ scalar_value (float): scalar_value.
111
+ global_step (int): global_step.
112
+ Returns:
113
+ None
114
+ """
115
+ if not self.ad_rules or tag[-1] in ["shape", "dtype"]:
116
+ return
117
+ if isinstance(scalar_value, Tensor):
118
+ scalar_value = scalar_value.item()
119
+ avg = self._update_tag2scalars(tag, scalar_value)
120
+ detected, rule_name = self._ad(scalar_value, history=avg)
121
+ if detected:
122
+ if rule_name == AnomalyTurbulence.name and tag[-1] not in ["norm", "mean"]:
123
+ return
124
+ exception_message = (f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}, "
125
+ f"current value {scalar_value}, history mean {avg}.")
126
+ logger.info(f"{BCOLORS.WARNING}> {exception_message}{BCOLORS.ENDC}")
127
+ # append to self.anomalies for dump
128
+ if self.anomaly_factory:
129
+ self.anomalies.append(self.anomaly_factory.create(tag, exception_message, global_step))
130
+
131
+ def write_metrics(self, op_list, metric_value, step, prefix='', need_explain=False):
132
+ if not metric_value:
133
+ return
134
+ tensors = []
135
+ tags = list(itertools.product(metric_value.keys(), op_list))
136
+ for op2tensor in metric_value.values():
137
+ tensors.extend(op2tensor.values())
138
+
139
+ if not tensors:
140
+ return
141
+
142
+ with _no_grad():
143
+ metric_list = self.stack_tensors(tensors)
144
+ for tag, metric in zip(tags, metric_list):
145
+ self.add_scalar(tag, metric, step, need_explain)
146
+
147
+ def _ad(self, scalar_value, history):
148
+ return AnomalyScanner.scan(self.ad_rules, history, cur=scalar_value)
149
+
150
+ def _update_tag2scalars(self, tag, scalar_value):
151
+ """Update the average and count of a scalar value associated with a tag.
152
+
153
+ This method is used to maintain a running average of scalar values for each tag.
154
+
155
+
156
+ Args:
157
+ tag (str): The tag identifier.
158
+ scalar_value (float): The scalar value to be added.
159
+
160
+ Returns:
161
+ float: The average value before update.
162
+ """
163
+ abs_scalar_value = abs(scalar_value)
164
+ if tag not in self.tag2scalars:
165
+ self.tag2scalars[tag] = {'avg': abs_scalar_value, 'count': 0}
166
+ avg = self.tag2scalars[tag]['avg']
167
+ self.tag2scalars[tag]['avg'] = self.beta * avg + (1 - self.beta) * abs_scalar_value
168
+ self.tag2scalars[tag]['count'] += 1
169
+ return avg
170
+
171
+
172
class CSVWriterWithAD(BaseWriterWithAD):
    """Anomaly-detecting writer that also buffers metric rows and appends them to CSV files.

    Files are named ``<prefix>_<first>-<last>.csv``, where the step window
    spans ``step_count_per_record`` steps.
    """

    def __init__(self, writer_input: WriterInput):
        super().__init__(writer_input)

        out_dir = writer_input.path
        self.log_dir = out_dir
        create_directory(out_dir)
        change_mode(out_dir, FileCheckConst.DATA_DIR_AUTHORITY)
        self.context_dict = defaultdict(list)  # row name -> buffered metric values
        self.header = []
        self.step_count_per_record = writer_input.step_count_per_record

    def get_step_interval(self, step):
        """Return the inclusive (first, last) step of the window containing ``step``."""
        window = self.step_count_per_record
        first = (step // window) * window
        return first, first + window - 1

    def write_csv(self, prefix, step):
        """Append buffered rows to the CSV file for ``prefix``/``step`` and clear the buffer.

        Args:
            prefix[str]: file-name prefix, e.g. grad_unreduced
            step[int]: current step; inserted as the third column of each row
        """
        if not self.context_dict:
            return

        step_start, step_end = self.get_step_interval(step)
        filepath = os.path.join(self.log_dir, f'{prefix}_{step_start}-{step_end}.csv')
        if not os.path.exists(filepath):
            # A new file starts with just the header row.
            write_df_to_csv(pd.DataFrame(columns=self.header), filepath)

        rows = []
        for row_name, values in self.context_dict.items():
            row = row_name.split(MonitorConst.NAME_SEP) + values
            row.insert(2, step)
            rows.append(row)
        frame = pd.DataFrame(rows).round(self.ndigits).fillna("nan")
        write_df_to_csv(frame, filepath, mode='a+', header=False)
        self.context_dict = defaultdict(list)

    def add_scalar(self, tag, scalar_value, global_step, need_explain=False):
        """Check the value for anomalies, then buffer it under the row derived from ``tag``.

        ``tag`` looks like ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min').
        With ``need_explain`` the row name gets '.input' / '.output' suffixes.
        The suffix depends on whether the last '/'-separated part contains 'pre' / 'post'.
        """
        super().add_scalar(tag, scalar_value, global_step, need_explain=False)
        name_parts = tag[0].split('/')
        row_name = name_parts[0]
        if need_explain:
            last_part = name_parts[-1]
            if 'pre' in last_part:
                row_name += '.input'
            if 'post' in last_part:
                row_name += '.output'
        self.context_dict[row_name].append(scalar_value)

    def write_metrics(self, op_list, metric_value, step, prefix='', need_explain=False, **kwargs):
        """Process ``metric_value`` and flush it to the CSV file for ``prefix``.

        The incoming ``need_explain`` is ignored; it is recomputed as ``prefix == 'other'``.
        """
        need_explain = prefix == 'other'
        super().write_metrics(op_list, metric_value, step, prefix='', need_explain=need_explain)

        uses_micro_step = prefix in [MonitorConst.ACTV, MonitorConst.ACTVGRAD] or kwargs.get("use_micro_step")
        base_header = MonitorConst.CSV_HEADER_MICRO_STEP if uses_micro_step else MonitorConst.CSV_HEADER
        self.header = base_header + op_list
        self.write_csv(prefix, step)

    def close(self):
        """No resources to release; present for writer-interface compatibility."""
@@ -281,7 +281,7 @@ def create_hooks(context, monitor):
281
281
  global RANK
282
282
  pre_hooks = []
283
283
  hooks = []
284
- RANK = str(get_rank())
284
+ RANK = get_rank()
285
285
  if communication.GlobalComm.INITED and RANK not in monitor.module_rank_list and monitor.module_rank_list != []:
286
286
  return [pre_hooks, hooks]
287
287
 
@@ -46,6 +46,8 @@ def get_max(x: Tensor):
46
46
 
47
47
@_no_grad()
def get_zeros(x: Tensor, eps: float):
    """Return the fraction of elements of ``x`` whose absolute value is below ``eps``.

    An empty ``x`` yields a NaN tensor instead of dividing by zero.
    """
    element_count = x.numel()
    if element_count == 0:
        return Tensor(float('nan'))
    near_zero = mint.abs(x) < eps
    return mint.sum(near_zero) / element_count
50
52
 
51
53
 
@@ -54,10 +56,20 @@ def get_nans(t):
54
56
  return ops.isnan(t.astype(mstype.float32)).sum()
55
57
 
56
58
 
59
def get_shape(t):
    """Return the ``shape`` attribute of ``t`` unchanged."""
    shape = t.shape
    return shape
61
+
62
+
63
def get_dtype(t):
    """Return the ``dtype`` attribute of ``t`` unchanged."""
    dtype = t.dtype
    return dtype
65
+
66
+
57
67
# Metric name -> function computing that metric for one tensor.
FUNC_MAP = dict(
    min=get_min,
    max=get_max,
    mean=get_mean,
    norm=get_norm,
    nans=get_nans,
    zeros=get_zeros,
    shape=get_shape,
    dtype=get_dtype,
)