mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -17,9 +17,11 @@ import csv
17
17
  import os
18
18
  import copy
19
19
  import threading
20
+ import traceback
21
+ from datetime import datetime, timezone, timedelta
20
22
 
21
23
  from msprobe.core.common.const import Const, FileCheckConst
22
- from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json
24
+ from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json, check_path_before_create
23
25
  from msprobe.core.common.log import logger
24
26
  from msprobe.core.common.decorator import recursion_depth_decorator
25
27
 
@@ -35,11 +37,15 @@ class DataWriter:
35
37
  self.free_benchmark_file_path = None
36
38
  self.dump_tensor_data_dir = None
37
39
  self.debug_file_path = None
40
+ self.dump_error_info_path = None
38
41
  self.flush_size = 1000
42
+ self.larger_flush_size = 20000
39
43
  self.cache_data = {}
40
44
  self.cache_stack = {}
41
45
  self.cache_construct = {}
42
46
  self.cache_debug = {}
47
+ self.stat_stack_list = []
48
+ self._error_log_initialized = False
43
49
 
44
50
  @staticmethod
45
51
  def write_data_to_csv(result: list, result_header: tuple, file_path: str):
@@ -56,13 +62,54 @@ class DataWriter:
56
62
  if is_new_file:
57
63
  change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
58
64
 
65
+ @recursion_depth_decorator("JsonWriter: DataWriter._replace_stat_placeholders")
66
+ def _replace_stat_placeholders(self, data, stat_result):
67
+ if isinstance(data, dict):
68
+ keys = list(data.keys()) # 获取当前所有键
69
+ for key in keys: # 递归所有变量
70
+ value = data[key]
71
+ if key == Const.TENSOR_STAT_INDEX and isinstance(value, int):
72
+ if value >= 0:
73
+ idx = value
74
+ else:
75
+ return
76
+ stat_values = stat_result[idx] if idx < len(stat_result) else [None] * 4
77
+
78
+ new_entries = {
79
+ Const.TYPE: data["type"],
80
+ Const.DTYPE: data["dtype"],
81
+ Const.SHAPE: data["shape"],
82
+ Const.MAX: stat_values[0],
83
+ Const.MIN: stat_values[1],
84
+ Const.MEAN: stat_values[2],
85
+ Const.NORM: stat_values[3],
86
+ }
87
+ del data[key]
88
+
89
+ # 重构字典顺序
90
+ updated_dict = {}
91
+ # 通过插入排序后字段保证字段写入json的有序
92
+ updated_dict.update(new_entries)
93
+ # 遍历原字典其他字段(排除已删除的tensor_stat_index)
94
+ for k in data:
95
+ if k not in new_entries:
96
+ updated_dict[k] = data[k]
97
+ data.clear()
98
+ data.update(updated_dict)
99
+ else:
100
+ self._replace_stat_placeholders(value, stat_result)
101
+ elif isinstance(data, (list, tuple)):
102
+ for item in data:
103
+ self._replace_stat_placeholders(item, stat_result)
104
+
59
105
  def reset_cache(self):
60
106
  self.cache_data = {}
61
107
  self.cache_stack = {}
62
108
  self.cache_construct = {}
109
+ self.cache_debug = {}
63
110
 
64
111
  def initialize_json_file(self, **kwargs):
65
- if self.debug_file_path and not self.cache_debug:
112
+ if kwargs["level"] == Const.LEVEL_DEBUG and not self.cache_debug:
66
113
  # debug level case only create debug.json
67
114
  debug_dict = copy.deepcopy(kwargs)
68
115
  debug_dict.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
@@ -85,12 +132,46 @@ class DataWriter:
85
132
  self.dump_tensor_data_dir = dump_path_aggregation.dump_tensor_data_dir
86
133
  self.free_benchmark_file_path = dump_path_aggregation.free_benchmark_file_path
87
134
  self.debug_file_path = dump_path_aggregation.debug_file_path
135
+ self.dump_error_info_path = dump_path_aggregation.dump_error_info_path
88
136
 
89
137
  def flush_data_periodically(self):
90
138
  dump_data = self.cache_data.get(Const.DATA)
91
- if dump_data and isinstance(dump_data, dict) and len(dump_data) % self.flush_size == 0:
139
+
140
+ if not dump_data or not isinstance(dump_data, dict):
141
+ return
142
+
143
+ length = len(dump_data)
144
+
145
+ threshold = self.flush_size if length < self.larger_flush_size else self.larger_flush_size
146
+
147
+ if length % threshold == 0:
92
148
  self.write_json()
93
149
 
150
+ def write_error_log(self, message: str):
151
+ """
152
+ 写错误日志:
153
+ - 第一次调用时以 'w' 模式清空文件,之后都用 'a' 模式追加
154
+ - 添加时间戳
155
+ - 在 message 后写入当前的调用栈(方便追踪日志来源)
156
+ """
157
+ try:
158
+ mode = "w" if not self._error_log_initialized else "a"
159
+ self._error_log_initialized = True
160
+
161
+ check_path_before_create(self.dump_error_info_path)
162
+
163
+ with FileOpen(self.dump_error_info_path, mode) as f:
164
+ cst_timezone = timezone(timedelta(hours=8), name="CST")
165
+ timestamp = datetime.now(cst_timezone).strftime("%Y-%m-%d %H:%M:%S %z")
166
+ f.write(f"[{timestamp}] {message}\n")
167
+ f.write("Call stack (most recent call last):\n")
168
+
169
+ f.write("".join(traceback.format_stack()[:-1])) # 去掉自己这一层
170
+ f.write("\n")
171
+ except Exception as e:
172
+ # 如果连写日志都失败了,就打印到 stderr
173
+ logger.warning(f"[FallbackError] Failed to write error log: {e}")
174
+
94
175
  def update_data(self, new_data):
95
176
  with lock:
96
177
  if not isinstance(new_data, dict) or len(new_data.keys()) != 1:
@@ -107,9 +188,13 @@ class DataWriter:
107
188
  else:
108
189
  dump_data.update(new_data)
109
190
 
110
- def update_stack(self, new_data):
191
+ def update_stack(self, name, stack_data):
111
192
  with lock:
112
- self.cache_stack.update(new_data)
193
+ api_list = self.cache_stack.get(stack_data)
194
+ if api_list is None:
195
+ self.cache_stack.update({stack_data: [name]})
196
+ else:
197
+ api_list.append(name)
113
198
 
114
199
  def update_construct(self, new_data):
115
200
  with lock:
@@ -124,7 +209,11 @@ class DataWriter:
124
209
  save_json(file_path, self.cache_data, indent=1)
125
210
 
126
211
  def write_stack_info_json(self, file_path):
127
- save_json(file_path, self.cache_stack, indent=1)
212
+ num, new_cache_stack = 0, {}
213
+ for key, value in self.cache_stack.items():
214
+ new_cache_stack[num] = [value, key]
215
+ num += 1
216
+ save_json(file_path, new_cache_stack, indent=1)
128
217
 
129
218
  def write_construct_info_json(self, file_path):
130
219
  save_json(file_path, self.cache_construct, indent=1)
@@ -132,8 +221,56 @@ class DataWriter:
132
221
  def write_debug_info_json(self, file_path):
133
222
  save_json(file_path, self.cache_debug, indent=1)
134
223
 
224
+ def append_stat_to_buffer(self, stat_vector):
225
+ """
226
+ 直接使用 Python list 存储 stat_vector,
227
+ 将 stat_vector 存入 self.stat_stack_list 的方式
228
+ """
229
+ self.stat_stack_list.append(stat_vector)
230
+ return len(self.stat_stack_list) - 1
231
+
232
+ def get_buffer_values_max(self, index):
233
+ if 0 <= index < len(self.stat_stack_list) and len(self.stat_stack_list[index]) >= 1:
234
+ return self.stat_stack_list[index][0]
235
+ else:
236
+ logger.warning(f"stat_stack_list[{index}] The internal data is incomplete,"
237
+ f" and the maximum value cannot be obtained.")
238
+ return None
239
+
240
+ def get_buffer_values_min(self, index):
241
+ if 0 <= index < len(self.stat_stack_list) and len(self.stat_stack_list[index]) >= 1:
242
+ return self.stat_stack_list[index][1]
243
+ else:
244
+ logger.warning(f"stat_stack_list[{index}] Internal data is incomplete"
245
+ f" and minimum values cannot be obtained.")
246
+ return None
247
+
248
+ def flush_stat_stack(self):
249
+ """
250
+ 在 flush 阶段,将所有存储的统计值从设备搬到 CPU,
251
+ 这里返回一个列表,每个元素是 [Max, Min, Mean, Norm] 的数值列表
252
+ """
253
+ if not self.stat_stack_list:
254
+ return []
255
+ result = [
256
+ [
257
+ x.item() if hasattr(x, "item") else x
258
+ for x in stat_values
259
+ ]
260
+ for stat_values in self.stat_stack_list
261
+ ]
262
+ self.stat_stack_list = []
263
+ return result
264
+
135
265
  def write_json(self):
136
266
  with lock:
267
+ # 在写 JSON 前,统一获取统计值
268
+ stat_result = self.flush_stat_stack()
269
+ # 遍历 cache_data,将占位符替换为最终统计值
270
+ if stat_result:
271
+ self._replace_stat_placeholders(self.cache_data, stat_result)
272
+ if self.cache_debug:
273
+ self._replace_stat_placeholders(self.cache_debug, stat_result)
137
274
  if self.cache_data:
138
275
  self.write_data_json(self.dump_file_path)
139
276
  if self.cache_stack:
@@ -143,24 +280,3 @@ class DataWriter:
143
280
  if self.cache_debug:
144
281
  self.write_debug_info_json(self.debug_file_path)
145
282
 
146
- def fill_stack_tensor_data(self):
147
- self.process_stat_data_recursive(self.cache_data)
148
-
149
- @recursion_depth_decorator("AsyncDump: DataWriter.process_stat_data_recursive", max_depth=Const.DUMP_MAX_DEPTH)
150
- def process_stat_data_recursive(self, data):
151
- if isinstance(data, dict):
152
- if "tensor_stat" in data.keys():
153
- tensor_stat = data["tensor_stat"]
154
- if len(tensor_stat) != Const.TENSOR_STAT_LEN or len(tensor_stat[0]) != len(tensor_stat[1]):
155
- logger.warning("Some bad data in async dump")
156
- else:
157
- tensor_stat_index, tensor_stat_data = tensor_stat[0], tensor_stat[1]
158
- for index, stat in zip(tensor_stat_index, tensor_stat_data):
159
- data.update({index: stat.item()})
160
- del data["tensor_stat"]
161
- else:
162
- for key in data.keys():
163
- self.process_stat_data_recursive(data[key])
164
- elif isinstance(data, (list, tuple)):
165
- for i in data:
166
- self.process_stat_data_recursive(i)
@@ -0,0 +1,144 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ from msprobe.core.common.const import Const, FileCheckConst, MsgConst
19
+ from msprobe.core.common.exceptions import MsprobeException
20
+ from msprobe.core.common.file_utils import FileChecker, load_json
21
+ from msprobe.core.common.utils import get_real_step_or_rank, check_init_step
22
+ from msprobe.core.common_config import CommonConfig
23
+
24
+
25
+ class BasePrecisionDebugger:
26
+ _instance = None
27
+ tasks_not_need_debugger = [Const.GRAD_PROBE]
28
+
29
+ def __new__(cls, *args, **kwargs):
30
+ if cls._instance is None:
31
+ cls._instance = super(BasePrecisionDebugger, cls).__new__(cls)
32
+ cls._instance.config = None
33
+ cls._instance.enable_dataloader = False
34
+ cls._instance.initialized = False
35
+ cls.service = None
36
+ cls.first_start = False
37
+ return cls._instance
38
+
39
+ def __init__(
40
+ self,
41
+ config_path=None,
42
+ task=None,
43
+ dump_path=None,
44
+ level=None,
45
+ step=None
46
+ ):
47
+ if self.initialized:
48
+ return
49
+ self.initialized = True
50
+ self._check_input_params(config_path, task, dump_path, level)
51
+ self.common_config, self.task_config = self._parse_config_path(config_path, task)
52
+ self.task = self.common_config.task
53
+ if step is not None:
54
+ self.common_config.step = get_real_step_or_rank(step, Const.STEP)
55
+
56
+ @staticmethod
57
+ def _check_input_params(config_path, task, dump_path, level):
58
+ if not config_path:
59
+ config_path = os.path.join(os.path.dirname(__file__), "../../config.json")
60
+ if config_path is not None:
61
+ if not isinstance(config_path, str):
62
+ raise MsprobeException(
63
+ MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string")
64
+ file_checker = FileChecker(
65
+ file_path=config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
66
+ file_checker.common_check()
67
+
68
+ if task is not None and task not in Const.TASK_LIST:
69
+ raise MsprobeException(
70
+ MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")
71
+
72
+ if dump_path is not None:
73
+ if not isinstance(dump_path, str):
74
+ raise MsprobeException(
75
+ MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string")
76
+
77
+ if level is not None and level not in Const.LEVEL_LIST:
78
+ raise MsprobeException(
79
+ MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
80
+
81
+ @staticmethod
82
+ def _get_task_config(task, json_config):
83
+ raise NotImplementedError("Subclass must implement _get_task_config")
84
+
85
+ @classmethod
86
+ def forward_backward_dump_end(cls):
87
+ instance = cls._instance
88
+ instance.stop()
89
+
90
+ @classmethod
91
+ def set_init_step(cls, step):
92
+ instance = cls._instance
93
+ if not instance:
94
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
95
+ check_init_step(step)
96
+ instance.service.init_step = step
97
+ instance.service.loop = 0
98
+
99
+ @classmethod
100
+ def register_custom_api(cls, module, api, api_prefix=None):
101
+ if not api_prefix:
102
+ api_prefix = getattr(module, "__name__", "Custom")
103
+ if not isinstance(api_prefix, str):
104
+ raise MsprobeException(
105
+ MsprobeException.INVALID_PARAM_ERROR, "api_prefix must be string")
106
+ if not hasattr(module, api):
107
+ raise MsprobeException(
108
+ MsprobeException.INVALID_PARAM_ERROR, f"module {str(module)} does not have {api}")
109
+ instance = cls._instance
110
+ if not instance:
111
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
112
+ instance.service.register_custom_api(module, api, api_prefix)
113
+
114
+ @classmethod
115
+ def restore_custom_api(cls, module, api):
116
+ if not hasattr(module, api):
117
+ raise MsprobeException(
118
+ MsprobeException.INVALID_PARAM_ERROR, f"module {str(module)} does not have {api}")
119
+ instance = cls._instance
120
+ if not instance:
121
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
122
+ instance.service.restore_custom_api(module, api)
123
+
124
+ @classmethod
125
+ def _get_instance(cls):
126
+ instance = cls._instance
127
+ if not instance:
128
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
129
+ if instance.task in BasePrecisionDebugger.tasks_not_need_debugger:
130
+ instance = None
131
+ return instance
132
+
133
+ def _parse_config_path(self, json_file_path, task):
134
+ if not json_file_path:
135
+ json_file_path = os.path.join(os.path.dirname(__file__), "../../config.json")
136
+ json_config = load_json(json_file_path)
137
+ common_config = CommonConfig(json_config)
138
+ if task:
139
+ task_config = self._get_task_config(task, json_config)
140
+ else:
141
+ if not common_config.task:
142
+ common_config.task = Const.STATISTICS
143
+ task_config = self._get_task_config(common_config.task, json_config)
144
+ return common_config, task_config
@@ -52,7 +52,7 @@ class GradConst:
52
52
  BOUNDS_MINIMUM = -2**63
53
53
  BOUNDS_MAXIMUM = 2**63 - 1
54
54
 
55
- # file safty
55
+ # file safety
56
56
  DATA_DIR_AUTHORITY = 0o750
57
57
  DATA_FILE_AUTHORITY = 0o640
58
58
  DIRECTORY_LENGTH = 4096
@@ -121,7 +121,7 @@ class GradComparator:
121
121
  similarities = {}
122
122
  logger.info(f"{len(steps)} steps will be compared")
123
123
  grad_weight_order = cls._get_grad_weight_order(path1, path2)
124
- for step in tqdm(steps, desc="culculate similarities (by step)"):
124
+ for step in tqdm(steps, desc="calculate similarities (by step)"):
125
125
  grad_files = cls._get_matched_grad_files(path1, path2, step)
126
126
  same_count_summary = 0
127
127
  total_count_summary = 0
@@ -82,7 +82,7 @@ class ListCache(list):
82
82
  if len(self) == 0:
83
83
  return
84
84
  if not self._output_file:
85
- logger.warning("dumpfile path is not setted")
85
+ logger.warning("dumpfile path is not set.")
86
86
  write_csv(self, self._output_file)
87
87
  logger.info(f"write {len(self)} items to {self._output_file}.")
88
88
  self.clear()
@@ -0,0 +1,242 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from abc import ABC, abstractmethod
18
+ import os
19
+
20
+ from msprobe.core.common.runtime import Runtime
21
+ from msprobe.core.common.utils import Const
22
+ from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs)
23
+
24
+
25
class HookSet:
    """Plain container bundling the four hook callables built for one module or API.

    Each slot is stored exactly as passed and defaults to ``None``.
    """

    def __init__(self, forward_hook=None, forward_pre_hook=None, backward_hook=None, backward_pre_hook=None):
        # Forward-direction hooks.
        self.forward_hook = forward_hook
        self.forward_pre_hook = forward_pre_hook
        # Backward-direction hooks.
        self.backward_hook = backward_hook
        self.backward_pre_hook = backward_pre_hook
31
+
32
+
33
class BaseHookManager(ABC):
    """Framework-agnostic builder of the forward, backward and parameter-gradient dump hooks.

    Subclasses provide the framework-specific pieces: the recompute flag, the
    context manager used while collecting, per-API call counting, kwargs/output
    unpacking, parameter discovery and the backward in/out exchange rule.

    Class-level state shared by every instance:
        inner_switch: re-entrancy guard. It is True while a hook is calling into
            the data collector, and ``_should_execute_hook`` rejects every hook
            while it is True, so hooks fired by that work are skipped. It is
            always reset in a ``finally`` clause so an exception cannot leave it stuck.
        hook_handle_dict: parameter-gradient hook handles keyed by
            "<module base name><SEP><param name>"; an existing handle is removed
            before a new one is registered under the same key.
        params_grad_info: grad names whose placeholder entry has already been
            handed to the data collector.
    """
    inner_switch = False
    hook_handle_dict = {}
    params_grad_info = {}

    def __init__(self, data_collector, config, attl_manager=None):
        self.data_collector = data_collector
        self.config = config
        self.attl_manager = attl_manager

    @property
    def _pid(self):
        """Id of the current process (re-read on every access)."""
        return os.getpid()

    @property
    @abstractmethod
    def _is_recompute(self):
        """Recompute flag forwarded to every data-collector call (framework-specific)."""

    @staticmethod
    @abstractmethod
    def _no_grad_context():
        """Return the context manager under which forward data collection runs."""

    @staticmethod
    @abstractmethod
    def _add_count(name):
        """Called once per executed forward pre-hook with the API name."""

    @staticmethod
    @abstractmethod
    def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs):
        """Return a ``(kwargs, output)`` pair from the raw forward-hook arguments.

        The order of the last two hook arguments depends on the framework/hook type.
        """

    @staticmethod
    def _clear_input_kwargs(module):
        """Remove the ``msprobe_input_kwargs`` attribute from ``module`` if present."""
        if hasattr(module, 'msprobe_input_kwargs'):
            del module.msprobe_input_kwargs

    @abstractmethod
    def build_hook(self):
        """Build the hooks for a module or API (framework-specific)."""

    @abstractmethod
    def _get_params_dict(self, module):
        """Return a mapping of parameter name -> parameter for ``module``."""

    @abstractmethod
    def _need_exchange(self, module):
        """Return True when grad_input/grad_output must be swapped for this module."""

    def _need_param_grad_hook(self):
        """Return False when data_mode selects forward data only, True otherwise."""
        data_mode = self.config.data_mode
        return not (Const.FORWARD in data_mode and Const.BACKWARD not in data_mode)

    def _register_param_hook(self, name, module, params_dict):
        """Tag ``module`` with its params-grad name and hook every parameter that requires grad.

        Args:
            name: full dump name of the module; everything before its last two
                ``Const.SEP``-separated fields is used as the module's base name.
            module: module owning the parameters.
            params_dict: mapping of parameter name -> parameter.
        """
        ori_name = name.rsplit(Const.SEP, 2)[0]
        # Set on the first forward hook; _init_params_grad_info reads it back.
        module.params_grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
        # No parameter hooks are registered when only forward data is dumped.
        if not self._need_param_grad_hook():
            return
        for param_name, param in params_dict.items():
            if not param.requires_grad:
                continue
            handle_key = ori_name + Const.SEP + param_name
            old_handle = BaseHookManager.hook_handle_dict.get(handle_key)
            # Remove the handle from a previous registration so the parameter is hooked only once.
            if old_handle and hasattr(old_handle, "remove"):
                old_handle.remove()
            handle = param.register_hook(self._build_grad_hook(module, ori_name, param_name))
            BaseHookManager.hook_handle_dict[handle_key] = handle

    def _init_params_grad_info(self, module, params_dict):
        """Write a placeholder entry for the module's parameter gradients after the forward hook.

        A ``{grad_name: {param_name: [None], ...}}`` record covering the parameters
        with ``requires_grad`` is handed to the data collector once per grad name;
        it is updated after the gradients are computed.
        """
        if not params_dict or not self._need_param_grad_hook():
            return
        grad_name = getattr(module, 'params_grad_name', None)
        # Placeholder already written for this grad name.
        if BaseHookManager.params_grad_info.get(grad_name):
            return
        data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
        # Only parameters with requires_grad receive gradients; nothing to reserve otherwise.
        if not data_info.get(grad_name):
            return
        self.data_collector.handle_data(grad_name, data_info,
                                        flush=self.data_collector.data_processor.is_terminated)
        BaseHookManager.params_grad_info[grad_name] = True

    def _should_execute_hook(self, hook_type, module, is_forward):
        """Decide whether a hook invocation should collect data.

        Module hooks and API forward hooks require ``Runtime.is_running``.
        API backward hooks require that the forward pre-hook marked the module
        with ``forward_data_collected`` (a missing attribute counts as not collected).
        All hooks are rejected while ``inner_switch`` is set, when there is no
        data collector, or when its data processor is terminated.
        """
        is_module_hook = hook_type == Const.MODULE
        if is_module_hook or is_forward:
            if not Runtime.is_running:
                return False
        elif not getattr(module, 'forward_data_collected', False):
            return False
        if BaseHookManager.inner_switch:
            return False
        if not self.data_collector or self.data_collector.data_processor.is_terminated:
            return False
        return True

    def _build_grad_hook(self, module, ori_name, param_name):
        """Return a parameter-gradient hook that records ``grad`` and leaves it unchanged."""
        def hook_fn(grad):
            if not self._should_execute_hook(Const.MODULE, module, False):
                return
            BaseHookManager.inner_switch = True
            try:
                self.data_collector.params_data_collect(ori_name, param_name, self._pid, grad)
            finally:
                BaseHookManager.inner_switch = False
        return hook_fn

    def _build_forward_pre_hook(self, hook_type, full_name, api_name):
        """Return a forward pre-hook that records API inputs (module inputs are recorded by the forward hook)."""
        def forward_pre_hook(module, args, kwargs=None):
            if hook_type == Const.MODULE:
                return
            if not self._should_execute_hook(hook_type, module, True):
                return
            if kwargs is None:
                kwargs = getattr(module, 'msprobe_input_kwargs', {})
            with self._no_grad_context():
                # Guard against nested hooks while the collector processes the inputs.
                BaseHookManager.inner_switch = True
                try:
                    module.forward_data_collected = True
                    self._add_count(api_name)
                    module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
                    self.data_collector.update_api_or_module_name(full_name)
                    if getattr(self.config, "online_run_ut", False):
                        return
                    self.data_collector.forward_input_data_collect(
                        full_name,
                        module,
                        self._pid,
                        module_input_output,
                        self._is_recompute
                    )
                finally:
                    BaseHookManager.inner_switch = False
        return forward_pre_hook

    def _build_forward_hook(self, hook_type, full_name):
        """Return a forward hook that records outputs (and, for modules, inputs and parameters).

        The hook returns the collector's replacement output when
        ``if_return_forward_new_output()`` is true, otherwise the original output;
        it returns None when it does not collect or in online run_ut mode.
        """
        def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
            if not self._should_execute_hook(hook_type, module, True):
                self._clear_input_kwargs(module)
                return None
            kwargs, output = self._process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs)
            BaseHookManager.inner_switch = True
            try:
                self.data_collector.update_api_or_module_name(full_name)
                module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
                with self._no_grad_context():
                    if getattr(self.config, "online_run_ut", False):
                        # Online run_ut: send in-scope calls to the attl manager instead of dumping.
                        if self.data_collector.scope and not self.data_collector.scope.check(full_name):
                            return None
                        if self.attl_manager:
                            self.attl_manager.attl_send(full_name, args, kwargs, output)
                        return None
                    if hook_type == Const.MODULE:
                        params_dict = self._get_params_dict(module)
                        setattr(module_input_output, Const.PARAMS, params_dict)
                        if params_dict:
                            self._register_param_hook(full_name, module, params_dict)
                        self.data_collector.update_api_or_module_name(full_name)
                        self.data_collector.forward_data_collect(
                            full_name,
                            module,
                            self._pid,
                            module_input_output,
                            self._is_recompute
                        )
                        self._init_params_grad_info(module, params_dict)
                    else:
                        self.data_collector.forward_output_data_collect(
                            full_name,
                            module,
                            self._pid,
                            module_input_output,
                            self._is_recompute
                        )
                    self._clear_input_kwargs(module)

                if self.data_collector.if_return_forward_new_output():
                    return self.data_collector.get_forward_new_output()
                return output
            finally:
                # Reset on every exit path, including the online run_ut early returns.
                BaseHookManager.inner_switch = False
        return forward_hook

    def _build_backward_hook(self, hook_type, full_name):
        """Return a backward hook that records gradients and leaves them unchanged.

        For API hooks, and for modules where ``_need_exchange`` is true,
        grad_input and grad_output are swapped before being recorded.
        """
        def backward_hook(module, grad_input, grad_output):
            if not self._should_execute_hook(hook_type, module, False):
                return
            BaseHookManager.inner_switch = True
            try:
                self.data_collector.update_api_or_module_name(full_name)
                if getattr(self.config, "online_run_ut", False):
                    return
                need_exchange = self._need_exchange(module) if hook_type == Const.MODULE else True
                if need_exchange:
                    module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
                else:
                    module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
                self.data_collector.backward_data_collect(
                    full_name,
                    module,
                    self._pid,
                    module_input_output,
                    self._is_recompute
                )
            finally:
                BaseHookManager.inner_switch = False
        return backward_hook