mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213) hide show
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -15,6 +15,7 @@
15
15
 
16
16
  import atexit
17
17
  import os
18
+ import traceback
18
19
 
19
20
  from msprobe.core.data_dump.scope import ScopeFactory
20
21
  from msprobe.core.data_dump.json_writer import DataWriter
@@ -41,7 +42,7 @@ class DataCollector:
41
42
  self.backward_module_names = {}
42
43
  self.optimizer_status = ""
43
44
  self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
44
- atexit.register(self.write_json)
45
+ atexit.register(self.write_json_at_exit)
45
46
 
46
47
  @property
47
48
  def dump_data_dir(self):
@@ -78,6 +79,11 @@ class DataCollector:
78
79
  def write_json(self):
79
80
  self.data_writer.write_json()
80
81
 
82
+ def write_json_at_exit(self):
83
+ if self.config.async_dump and self.config.task == Const.TENSOR:
84
+ self.data_processor.dump_async_data()
85
+ self.data_writer.write_json()
86
+
81
87
  def update_data(self, name, data_info):
82
88
  msg = f"msprobe is collecting data on {name}."
83
89
  if self.config.task == Const.OVERFLOW_CHECK:
@@ -89,88 +95,155 @@ class DataCollector:
89
95
  logger.debug(msg)
90
96
  self.data_writer.update_data(data_info)
91
97
 
92
- def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
93
- if self.config.task == Const.FREE_BENCHMARK:
94
- backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
95
- if self.check_scope_and_pid(self.scope, backward_name, pid):
96
- self.data_processor.analyze_forward_input(backward_name, module, module_input_output)
97
- return
98
-
99
- if not self.check_scope_and_pid(self.scope, name, pid):
100
- return
98
+ def call_stack_collect(self, name):
99
+ stack_info = self.data_processor.analyze_api_call_stack(name)
100
+ self.data_writer.update_stack(name, stack_info)
101
101
 
102
- data_info = {}
103
- if self.config.task != Const.STRUCTURE:
104
- data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
105
- self.set_is_recomputable(data_info, is_recompute)
106
- if self.config.level == Const.LEVEL_L2:
107
- return
108
- self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
102
+ def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
103
+ try:
104
+
105
+ if self.config.task == Const.FREE_BENCHMARK:
106
+ backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
107
+ if self.check_scope_and_pid(self.scope, backward_name, pid):
108
+ self.data_processor.analyze_forward_input(backward_name, module, module_input_output)
109
+ return
110
+
111
+ if not self.check_scope_and_pid(self.scope, name, pid):
112
+ return
113
+
114
+ data_info = {}
115
+ if self.config.task != Const.STRUCTURE:
116
+ data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
117
+ self.set_is_recomputable(data_info, is_recompute)
118
+ if self.config.level == Const.LEVEL_L2:
119
+ return
120
+ self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
121
+
122
+ except Exception:
123
+ tb = traceback.format_exc()
124
+ self.data_writer.write_error_log(
125
+ f"[ERROR] forward_input_data_collect failed: name={name}, pid={pid}\n{tb}"
126
+ )
109
127
 
110
128
  def forward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
111
- self.update_construct(name)
112
- if not self.check_scope_and_pid(self.scope, name, pid):
113
- return
114
-
115
- data_info = {}
116
- if self.config.task != Const.STRUCTURE:
117
- data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
118
- self.set_is_recomputable(data_info, is_recompute)
119
- if self.config.level == Const.LEVEL_L2:
120
- return
121
- self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
122
- self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
129
+ try:
130
+
131
+ self.update_construct(name)
132
+ if not self.check_scope_and_pid(self.scope, name, pid):
133
+ return
134
+
135
+ data_info = {}
136
+ if self.config.task != Const.STRUCTURE:
137
+ data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
138
+ self.set_is_recomputable(data_info, is_recompute)
139
+ if self.config.level == Const.LEVEL_L2:
140
+ return
141
+ self.call_stack_collect(name)
142
+ self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
143
+
144
+ except Exception:
145
+ tb = traceback.format_exc()
146
+ self.data_writer.write_error_log(
147
+ f"[ERROR] forward_output_data_collect failed: name={name}, pid={pid}\n{tb}"
148
+ )
149
+
150
+ def forward_data_collect_only_tensor(self, name, module, pid, module_input_output):
151
+ try:
152
+ if not self.check_scope_and_pid(self.scope, name, pid):
153
+ return
154
+ self.data_processor.analyze_forward(name, module, module_input_output)
155
+
156
+ except Exception:
157
+ tb = traceback.format_exc()
158
+ self.data_writer.write_error_log(
159
+ f"[ERROR] forward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}"
160
+ )
123
161
 
124
162
  def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
125
- self.update_construct(name)
126
- if not self.check_scope_and_pid(self.scope, name, pid):
127
- return
128
-
129
- data_info = {}
130
- if self.config.task != Const.STRUCTURE:
131
- data_info = self.data_processor.analyze_forward(name, module, module_input_output)
132
- self.set_is_recomputable(data_info, is_recompute)
133
- self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
134
- self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
163
+ try:
164
+
165
+ self.update_construct(name)
166
+ if not self.check_scope_and_pid(self.scope, name, pid):
167
+ return
168
+ data_info = {}
169
+ if self.config.task != Const.STRUCTURE:
170
+ data_info = self.data_processor.analyze_forward(name, module, module_input_output)
171
+ self.set_is_recomputable(data_info, is_recompute)
172
+ self.call_stack_collect(name)
173
+ self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
174
+
175
+ except Exception:
176
+ tb = traceback.format_exc()
177
+ self.data_writer.write_error_log(
178
+ f"[ERROR] forward_data_collect failed: name={name}, pid={pid}\n{tb}"
179
+ )
180
+
181
+ def backward_data_collect_only_tensor(self, name, module, pid, module_input_output, is_recompute=None):
182
+ try:
183
+ if not self.check_scope_and_pid(self.scope, name, pid):
184
+ return
185
+ self.data_processor.analyze_backward(name, module, module_input_output)
186
+
187
+ except Exception:
188
+ tb = traceback.format_exc()
189
+ self.data_writer.write_error_log(
190
+ f"[ERROR] backward_data_collect_only_tensor failed: name={name}, pid={pid}\n{tb}"
191
+ )
135
192
 
136
193
  def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
137
- self.update_construct(name)
138
- if not self.check_scope_and_pid(self.scope, name, pid):
139
- return
140
-
141
- data_info = {}
142
- if self.config.task != Const.STRUCTURE:
143
- data_info = self.data_processor.analyze_backward(name, module, module_input_output)
144
- if self.config.level == Const.LEVEL_L2:
145
- return
146
- # 获取执行反向的模块名称
147
- if data_info and name.split(Const.SEP)[0] in Const.MODULE_PREFIX:
148
- module_name = name.rsplit(Const.SEP, 2)[0]
149
- # 将模块名称加入到反向模块名称集合中,用于梯度收集时判断是否需要收集梯度
150
- self.backward_module_names[module_name] = True
151
- self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
194
+ try:
195
+ self.update_construct(name)
196
+ if not self.check_scope_and_pid(self.scope, name, pid):
197
+ return
198
+ data_info = {}
199
+ if self.config.task != Const.STRUCTURE:
200
+ data_info = self.data_processor.analyze_backward(name, module, module_input_output)
201
+ if self.config.level == Const.LEVEL_L2:
202
+ return
203
+ if data_info and name.split(Const.SEP)[0] in Const.MODULE_PREFIX:
204
+ module_name = name.rsplit(Const.SEP, 2)[0]
205
+ self.backward_module_names[module_name] = True
206
+ self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
207
+
208
+ except Exception:
209
+ tb = traceback.format_exc()
210
+ self.data_writer.write_error_log(
211
+ f"[ERROR] backward_data_collect failed: name={name}, pid={pid}\n{tb}"
212
+ )
152
213
 
153
214
  def backward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
154
- self.update_construct(name)
155
- if not self.check_scope_and_pid(self.scope, name, pid):
156
- return
157
-
158
- data_info = {}
159
- if self.config.task != Const.STRUCTURE:
160
- data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
161
- self.set_is_recomputable(data_info, is_recompute)
162
- self.handle_data(name, data_info)
215
+ try:
216
+ self.update_construct(name)
217
+ if not self.check_scope_and_pid(self.scope, name, pid):
218
+ return
219
+ data_info = {}
220
+ if self.config.task != Const.STRUCTURE:
221
+ data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
222
+ self.set_is_recomputable(data_info, is_recompute)
223
+ self.handle_data(name, data_info)
224
+
225
+ except Exception:
226
+ tb = traceback.format_exc()
227
+ self.data_writer.write_error_log(
228
+ f"[ERROR] backward_input_data_collect failed: name={name}, pid={pid}\n{tb}"
229
+ )
163
230
 
164
231
  def backward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
165
- self.update_construct(name)
166
- if not self.check_scope_and_pid(self.scope, name, pid):
167
- return
168
-
169
- data_info = {}
170
- if self.config.task != Const.STRUCTURE:
171
- data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
172
- self.set_is_recomputable(data_info, is_recompute)
173
- self.handle_data(name, data_info)
232
+ try:
233
+ self.update_construct(name)
234
+ if not self.check_scope_and_pid(self.scope, name, pid):
235
+ return
236
+ data_info = {}
237
+ if self.config.task != Const.STRUCTURE:
238
+ data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
239
+ self.set_is_recomputable(data_info, is_recompute)
240
+ self.handle_data(name, data_info)
241
+
242
+ except Exception:
243
+ tb = traceback.format_exc()
244
+ self.data_writer.write_error_log(
245
+ f"[ERROR] backward_output_data_collect failed: name={name}, pid={pid}\n{tb}"
246
+ )
174
247
 
175
248
  def update_construct(self, name):
176
249
  if self.config.level not in DataCollector.level_without_construct:
@@ -180,7 +253,10 @@ class DataCollector:
180
253
  self.optimizer_status_first_start[self.optimizer_status] = False
181
254
  self.data_writer.update_construct({name: self.optimizer_status})
182
255
  else:
183
- self.data_writer.update_construct({name: self.module_processor.api_parent_node})
256
+ if self.config.level == Const.LEVEL_MIX and \
257
+ not (name.startswith(Const.MODULE) or name.startswith(Const.CELL)):
258
+ self.data_writer.update_construct({name: self.module_processor.api_parent_node})
259
+
184
260
  self.data_writer.update_construct(self.module_processor.module_node)
185
261
 
186
262
  def handle_data(self, name, data_info, flush=False):
@@ -203,28 +279,33 @@ class DataCollector:
203
279
  self.data_processor.update_iter(current_iter)
204
280
 
205
281
  def params_data_collect(self, name, param_name, pid, data):
206
- grad_name = name + Const.SEP + Const.PARAMS_GRAD
207
- # 校验scope和pid,以及当前name是否有过反向计算
208
- if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
209
- # 如果没有反向计算,则需要清除之前占位写入的grad数据
210
- if self.data_writer.cache_data.get("data"):
211
- self.data_writer.cache_data.get("data").pop(grad_name, None)
212
- return
213
- data_info = self.data_processor.analyze_params(grad_name, param_name, data)
214
- self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
215
-
216
- def fill_stack_tensor_data(self):
217
- self.data_writer.fill_stack_tensor_data()
282
+ try:
283
+ grad_name = name + Const.SEP + Const.PARAMS_GRAD
284
+ self.update_api_or_module_name(grad_name)
285
+ if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
286
+ if self.data_writer.cache_data.get("data"):
287
+ self.data_writer.cache_data.get("data").pop(grad_name, None)
288
+ return
289
+ data_info = self.data_processor.analyze_params(grad_name, param_name, data)
290
+ self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
291
+ except Exception:
292
+ tb = traceback.format_exc()
293
+ self.data_writer.write_error_log(
294
+ f"[ERROR] params_data_collect failed: "
295
+ f"name={name}, param_name={param_name}, pid={pid}\n{tb}"
296
+ )
218
297
 
219
298
  def debug_data_collect_forward(self, variable, name_with_count):
220
-
221
299
  data_info = self.data_processor.analyze_debug_forward(variable, name_with_count)
222
- self.data_writer.update_debug({name_with_count: data_info})
300
+ name_with_count_category = name_with_count + Const.SEP + Const.DEBUG
301
+ self.data_writer.update_debug({name_with_count_category: data_info})
223
302
 
224
303
  def debug_data_collect_backward(self, variable, grad_name_with_count):
225
304
  # prepare all None nested data structure
226
305
  all_none_data_info = self.data_processor.analyze_element_to_all_none(variable)
227
- self.data_writer.update_debug({grad_name_with_count: all_none_data_info})
306
+ grad_name_with_count_category = grad_name_with_count + Const.SEP + Const.DEBUG
307
+ self.data_writer.update_debug({grad_name_with_count_category: all_none_data_info})
228
308
 
229
309
  # register tensor backward hook
230
- self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug['data'])
310
+ self.data_processor.analyze_debug_backward(variable, grad_name_with_count_category,
311
+ self.data_writer.cache_debug['data'])
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,17 +13,17 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import copy
16
17
  import inspect
17
18
  import os
18
19
  from dataclasses import dataclass, is_dataclass
19
- from typing import Tuple, Dict, Optional, Any
20
20
  from functools import partial
21
- import copy
22
- from typing import Union
21
+ from typing import Tuple, Dict, Optional, Any, Union
23
22
 
24
23
  import numpy as np
25
24
 
26
25
  from msprobe.core.common.const import Const
26
+ from msprobe.core.common.file_utils import save_npy
27
27
  from msprobe.core.common.log import logger
28
28
  from msprobe.core.common.utils import convert_tuple, CompareException
29
29
 
@@ -79,21 +79,17 @@ class ModuleBackwardOutputs:
79
79
 
80
80
 
81
81
  class TensorStatInfo:
82
- def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None, stack_tensor_stat=None):
82
+ def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
83
83
  self.max = max_val
84
84
  self.min = min_val
85
85
  self.mean = mean_val
86
86
  self.norm = norm_val
87
- self.stack_tensor_stat = stack_tensor_stat
88
87
 
89
88
 
90
89
  class BaseDataProcessor:
91
90
  _recursive_key_stack = []
92
- special_type = (
93
- np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray,
94
- bool, int, float, str, slice,
95
- type(Ellipsis)
96
- )
91
+ builtin_type = (bool, int, float, str, slice, type(Ellipsis))
92
+ np_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray)
97
93
 
98
94
  def __init__(self, config, data_writer):
99
95
  self.data_writer = data_writer
@@ -120,7 +116,10 @@ class BaseDataProcessor:
120
116
  @staticmethod
121
117
  def analyze_api_call_stack(name):
122
118
  try:
123
- api_stack = inspect.stack()[5:]
119
+ if name.startswith("Primitive"):
120
+ api_stack = inspect.stack()[4:]
121
+ else:
122
+ api_stack = inspect.stack()[5:]
124
123
  except Exception as e:
125
124
  logger.warning(f"The call stack of <{name}> failed to retrieve, {e}.")
126
125
  api_stack = None
@@ -129,12 +128,14 @@ class BaseDataProcessor:
129
128
  for (_, path, line, func, code, _) in api_stack:
130
129
  if not code:
131
130
  continue
131
+ if any(filter_path in path for filter_path in Const.STACK_FILTER_KEYWORDS) and \
132
+ Const.CALL_STACK_FLAG not in path:
133
+ continue
132
134
  stack_line = f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()}"
133
135
  stack_str.append(stack_line)
134
136
  else:
135
137
  stack_str.append(Const.WITHOUT_CALL_STACK)
136
- stack_info_struct = {name: stack_str}
137
- return stack_info_struct
138
+ return tuple(stack_str)
138
139
 
139
140
  @staticmethod
140
141
  def transfer_type(data):
@@ -178,20 +179,8 @@ class BaseDataProcessor:
178
179
  "invalid data_structure type or invalid index")
179
180
 
180
181
  @staticmethod
181
- def _convert_numpy_to_builtin(arg):
182
- type_mapping = {
183
- np.integer: int,
184
- np.floating: float,
185
- np.bool_: bool,
186
- np.complexfloating: complex,
187
- np.str_: str,
188
- np.byte: bytes,
189
- np.unicode_: str
190
- }
191
- for numpy_type, builtin_type in type_mapping.items():
192
- if isinstance(arg, numpy_type):
193
- return builtin_type(arg), type(arg).__name__
194
- return arg, ''
182
+ def is_distributed_op(module):
183
+ return getattr(module, "op_is_distributed", False)
195
184
 
196
185
  @staticmethod
197
186
  def _analyze_builtin(arg):
@@ -217,21 +206,40 @@ class BaseDataProcessor:
217
206
  return single_arg
218
207
 
219
208
  @staticmethod
220
- def _analyze_numpy(ndarray, numpy_type):
209
+ def _analyze_numpy(arg):
210
+ return {"type": type(arg).__name__, "value": arg.item()}
211
+
212
+ @staticmethod
213
+ def _analyze_ndarray(ndarray, _):
221
214
  ndarray_json = {}
222
215
  ndarray_json.update({'type': 'numpy.ndarray'})
223
216
  ndarray_json.update({'dtype': str(ndarray.dtype)})
224
217
  ndarray_json.update({'shape': ndarray.shape})
225
- if ndarray.size > 0:
226
- ndarray_json.update({"Max": np.max(ndarray).item()})
227
- ndarray_json.update({"Min": np.min(ndarray).item()})
228
- ndarray_json.update({"Mean": np.mean(ndarray).item()})
229
- ndarray_json.update({"Norm": np.linalg.norm(ndarray).item()})
230
- else:
231
- ndarray_json.update({"Max": None})
232
- ndarray_json.update({"Min": None})
233
- ndarray_json.update({"Mean": None})
234
- ndarray_json.update({"Norm": None})
218
+
219
+ # 先初始化默认值
220
+ stats = {
221
+ "Max": None,
222
+ "Min": None,
223
+ "Mean": None,
224
+ "Norm": None
225
+ }
226
+
227
+ try:
228
+ # 只有非空时才尝试计算
229
+ if ndarray.size > 0:
230
+ stats = {
231
+ "Max": np.max(ndarray).item(),
232
+ "Min": np.min(ndarray).item(),
233
+ "Mean": np.mean(ndarray).item(),
234
+ "Norm": np.linalg.norm(ndarray).item()
235
+ }
236
+ except Exception as e:
237
+ # 决定打印内容或切片
238
+ logger.warning(f"Error analyzing ndarray stats: {e}")
239
+
240
+ # 最后一次性更新
241
+ ndarray_json.update(stats)
242
+
235
243
  return ndarray_json
236
244
 
237
245
  @staticmethod
@@ -248,7 +256,7 @@ class BaseDataProcessor:
248
256
 
249
257
  @classmethod
250
258
  def get_special_types(cls):
251
- return cls.special_type
259
+ return cls.builtin_type + cls.np_type
252
260
 
253
261
  @classmethod
254
262
  def recursive_apply_transform(cls, args, transform, depth=0) -> Union[dict, list, None]:
@@ -303,6 +311,7 @@ class BaseDataProcessor:
303
311
 
304
312
  def real_hook_fn(grad):
305
313
  return wrap_hook_fn(grad)
314
+
306
315
  element.register_hook(real_hook_fn)
307
316
 
308
317
  def if_return_forward_new_output(self):
@@ -350,6 +359,8 @@ class BaseDataProcessor:
350
359
  return api_info_struct
351
360
 
352
361
  def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
362
+ if self.is_distributed_op(module):
363
+ module_input_output.update_output_with_args_and_kwargs()
353
364
  api_info_struct = {}
354
365
  # check whether data_mode contains forward or input
355
366
  if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
@@ -427,6 +438,7 @@ class BaseDataProcessor:
427
438
  api_info_struct = {}
428
439
  self.save_name = name + Const.SEP + param_name
429
440
  data_info = self.analyze_element(grad)
441
+ self.save_name = None
430
442
  grad_info_dict = {param_name: [data_info]}
431
443
  api_info_struct[name] = grad_info_dict
432
444
  return api_info_struct
@@ -435,10 +447,10 @@ class BaseDataProcessor:
435
447
  file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
436
448
  if self.save_name is not None:
437
449
  dump_data_name = (self.save_name + file_format)
438
- self.save_name = None
439
450
  else:
440
- dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
441
- suffix + file_format)
451
+ suffix_with_seq = (Const.SEP + suffix) if suffix else ""
452
+ dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + suffix_with_seq +
453
+ file_format)
442
454
  file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
443
455
  return dump_data_name, file_path
444
456
 
@@ -447,23 +459,32 @@ class BaseDataProcessor:
447
459
 
448
460
  def analyze_debug_forward(self, variable, name_with_count):
449
461
  self.current_api_or_module_name = name_with_count
450
- self.api_data_category = Const.TENSOR
451
- # these two attributes are used to construct tensor file name {name_with_count}.tensor.{indexes}.npy/pt
462
+ self.api_data_category = Const.DEBUG
463
+ # these two attributes are used to construct tensor file name {name_with_count}.debug.{indexes}.npy/pt
452
464
  data_info = self.analyze_element(variable)
453
465
  return data_info
454
466
 
455
- def analyze_debug_backward(self, variable, grad_name_with_count, nested_data_structure):
467
+ def analyze_debug_backward(self, variable, grad_name_with_count_category, nested_data_structure):
456
468
  def hook_fn(grad, indexes):
457
469
  suffix = Const.SEP.join([str(index) for index in indexes])
458
- self.save_name = grad_name_with_count + Const.SEP + Const.TENSOR + Const.SEP + suffix
470
+ suffix_with_sep = (Const.SEP + suffix) if suffix else ""
471
+ self.save_name = grad_name_with_count_category + suffix_with_sep
459
472
  grad_data_info = self.analyze_element(grad)
460
473
  self.save_name = None
461
- full_index = [grad_name_with_count] + indexes
474
+ full_index = [grad_name_with_count_category] + indexes
462
475
  try:
463
476
  self.set_value_into_nested_structure(nested_data_structure, full_index, grad_data_info)
464
477
  except (ValueError, IndexError) as e:
465
- logger.warning(f"error occured while recording statistics of {grad_name_with_count} variable, "
466
- f"skip current recording, detailed infomation: {e}")
478
+ logger.warning(f"error occurred while recording statistics of {grad_name_with_count_category} variable,"
479
+ f"skip current recording, detailed information: {e}")
467
480
  return grad
481
+
468
482
  wrap_register_hook_single_element = partial(self.register_hook_single_element, hook_fn=hook_fn)
469
- self.recursive_apply_transform(variable, wrap_register_hook_single_element)
483
+ self.recursive_apply_transform(variable, wrap_register_hook_single_element)
484
+
485
+ def _analyze_and_save_ndarray(self, ndarray, suffix):
486
+ dump_data_name, file_path = self.get_save_file_path(suffix)
487
+ save_npy(ndarray, file_path)
488
+ ndarray_json = BaseDataProcessor._analyze_ndarray(ndarray, suffix)
489
+ ndarray_json.update({"data_name": dump_data_name})
490
+ return ndarray_json