mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,65 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from msprobe.core.common.runtime import Runtime
18
+ from msprobe.core.common.utils import Const
19
+ from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
20
+ from msprobe.pytorch.common.log import logger
21
+
22
+
23
+ class ATTLManager:
24
+ def __init__(self, config):
25
+ self.config = config
26
+ self.attl = None
27
+
28
+ def attl_init(self):
29
+ if self.config.online_run_ut:
30
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL
31
+ attl_config = ATTLConfig(is_benchmark_device=False,
32
+ connect_ip=self.config.host,
33
+ connect_port=self.config.port,
34
+ nfs_path=self.config.nfs_path,
35
+ tls_path=self.config.tls_path)
36
+ need_dump = len(self.config.rank) == 0 or Runtime.current_rank in self.config.rank
37
+ self.attl = ATTL('npu', attl_config, need_dump=need_dump)
38
+ if self.config.nfs_path:
39
+ self.attl.upload("start")
40
+
41
+ def attl_send(self, name, args, kwargs, output):
42
+ api_data = ApiData(
43
+ name[:-len(Const.FORWARD_NAME_SUFFIX)],
44
+ args,
45
+ kwargs,
46
+ output,
47
+ Runtime.current_iter,
48
+ Runtime.current_rank
49
+ )
50
+ logger.info(f"tools is dumping api: {api_data.name}, rank: {Runtime.current_rank}")
51
+ api_type, _, _ = api_data.name.split(Const.SEP)
52
+ if api_type in [Const.DISTRIBUTED]:
53
+ logger.info(f"api {api_data.name} is not supported, skip")
54
+ return
55
+ if self.config.nfs_path:
56
+ self.attl.upload(api_data)
57
+ else:
58
+ self.attl.send(api_data)
59
+
60
+ def attl_stop(self):
61
+ if self.config.nfs_path:
62
+ self.attl.upload("end")
63
+ elif self.attl.socket_manager is not None:
64
+ logger.info(f"pid: {os.getpid()} finished, start sends STOP signal.")
65
+ self.attl.socket_manager.send_stop_signal()
@@ -29,6 +29,8 @@ def softmax_func(x, axis=None):
29
29
 
30
30
  def npu_moe_gating_top_k_softmax(x, finished_optional, k):
31
31
  input_dtype = x.dtype
32
+ if x.dim() < 1:
33
+ raise ValueError("Input x must have at least 1 dimensions.")
32
34
  num_expert = x.shape[-1]
33
35
  softmax = softmax_func(x, -1)
34
36
  softmax = softmax.to(input_dtype)
@@ -36,9 +38,13 @@ def npu_moe_gating_top_k_softmax(x, finished_optional, k):
36
38
  expert_idx = expert_idx[:, :k]
37
39
  y = torch.gather(softmax, index=expert_idx, dim=-1)
38
40
  if finished_optional is not None:
41
+ if finished_optional.dim() < 1:
42
+ raise ValueError("Finished_optional must have at least 1 dimensions.")
39
43
  finished_optional = finished_optional.view(finished_optional.shape[0], 1)
40
44
  finished_optional = finished_optional.expand(-1, k)
41
45
  expert_idx = torch.where(finished_optional, num_expert, expert_idx)
46
+ if y.dim() < 2:
47
+ raise ValueError("Variable y must have at least 2 dimensions.")
42
48
  row_idx = torch.arange(y.shape[0] * y.shape[1]).reshape(y.shape[1], y.shape[0]).t()
43
49
 
44
50
  return y, expert_idx, row_idx
@@ -117,6 +117,12 @@ def fusion_attention_forward(forward_params):
117
117
  pse = forward_params.pse
118
118
  scale = forward_params.scale
119
119
  keep_prob = forward_params.keep_prob
120
+
121
+ # 除零风险拦截:keep_prob 为 0 时会导致除零错误
122
+ if keep_prob == 0:
123
+ raise ValueError("fusion_attention_forward: keep_prob cannot be zero to avoid division by zero.")
124
+
125
+
120
126
  qk = calculate_qk(q, k, atten_mask, pse, scale)
121
127
  softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
122
128
  if drop_mask is None or len(drop_mask.shape) == 0:
@@ -137,6 +143,11 @@ def fusion_attention_backward(backward_params):
137
143
  pse = backward_params.pse
138
144
  scale = backward_params.scale
139
145
  keep_prob = backward_params.keep_prob
146
+
147
+ # 除零风险拦截:keep_prob 为 0 时会导致除零错误
148
+ if keep_prob == 0:
149
+ raise ValueError("fusion_attention_backward: keep_prob cannot be zero to avoid division by zero.")
150
+
140
151
  dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
141
152
  if drop_mask is None or len(drop_mask.shape) == 0:
142
153
  drop_res = softmax_res.permute(0, 1, 3, 2)
@@ -164,23 +175,35 @@ def parse_bsnd_args(query, key, head_num, input_layout):
164
175
  if input_layout == "BSH":
165
176
  b, s1, h1 = query.shape
166
177
  _, s2, h2 = key.shape
178
+ if n1 == 0:
179
+ raise ValueError("parse_bsnd_args: head_num (n1) cannot be zero to avoid division by zero.")
167
180
  d = h1 // n1
181
+ if d == 0:
182
+ raise ValueError("parse_bsnd_args: computed head dimension (d) is zero, division by zero risk.")
168
183
  n2 = h2 // d
169
184
  elif input_layout == "SBH":
170
185
  s1, b, h1 = query.shape
171
186
  s2, _, h2 = key.shape
187
+ if n1 == 0:
188
+ raise ValueError("parse_bsnd_args: head_num (n1) cannot be zero to avoid division by zero.")
172
189
  d = h1 // n1
190
+ if d == 0:
191
+ raise ValueError("parse_bsnd_args: computed head dimension (d) is zero, division by zero risk.")
173
192
  n2 = h2 // d
174
193
  elif input_layout == "BSND":
175
194
  b, s1, n1, d = query.shape
176
195
  _, s2, n2, _ = key.shape
177
196
  h1 = n1 * d
178
197
  h2 = n2 * d
198
+ if d == 0:
199
+ raise ValueError("parse_bsnd_args: head dimension (d) is zero, division by zero risk.")
179
200
  elif input_layout == "BNSD":
180
201
  b, n1, s1, d = query.shape
181
202
  _, n2, s2, _ = key.shape
182
203
  h1 = n1 * d
183
204
  h2 = n2 * d
205
+ if d == 0:
206
+ raise ValueError("parse_bsnd_args: head dimension (d) is zero, division by zero risk.")
184
207
  except Exception as e:
185
208
  raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
186
209
 
@@ -446,6 +469,8 @@ def npu_fusion_attention_forward_patch(*args, **kwargs):
446
469
  input_layout = get_input_layout(*args, **kwargs)
447
470
 
448
471
  b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout)
472
+ if d == 0:
473
+ raise ValueError("npu_fusion_attention_forward_patch: head dimension (d) is zero, division by zero risk.")
449
474
  if n1 == n2 and s1 == s2:
450
475
  logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
451
476
  else:
@@ -478,6 +503,8 @@ def npu_fusion_attention_backward_patch(*args, **kwargs):
478
503
  raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
479
504
 
480
505
  b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
506
+ if d == 0:
507
+ raise ValueError("npu_fusion_attention_backward_patch: head dimension (d) is zero, division by zero risk.")
481
508
  if n1 == n2 and s1 == s2:
482
509
  logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
483
510
  else:
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,11 +24,12 @@ from functools import wraps
24
24
  import numpy as np
25
25
  import torch
26
26
  import torch.distributed as dist
27
+
27
28
  from msprobe.core.common.exceptions import DistributedNotInitializedError
28
29
  from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
29
30
  check_file_or_directory_path, check_path_before_create, FileOpen)
30
31
  from msprobe.core.common.log import logger
31
- from msprobe.core.common.utils import check_seed_all
32
+ from msprobe.core.common.utils import check_seed_all, is_save_variable_valid
32
33
  from packaging import version
33
34
 
34
35
  try:
@@ -38,7 +39,9 @@ except ImportError:
38
39
  else:
39
40
  is_gpu = False
40
41
 
42
+
41
43
  torch_without_guard_version = torch.__version__ >= '2.1'
44
+ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
42
45
 
43
46
  if not is_gpu and not torch_without_guard_version:
44
47
  from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard
@@ -57,7 +60,7 @@ def parameter_adapter(func):
57
60
 
58
61
  @wraps(func)
59
62
  def inner(self, *args, **kwargs):
60
- if self.op_name_ == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor):
63
+ if self.api_name == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor):
61
64
  input_tensor = args[0]
62
65
  indices = args[1]
63
66
  if indices.dtype == torch.uint8:
@@ -77,7 +80,7 @@ def parameter_adapter(func):
77
80
  else:
78
81
  res = [input_tensor[tensor_index] for tensor_index in indices]
79
82
  return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0)
80
- if self.op_name_ == "__eq__" and len(args) > 1 and args[1] is None:
83
+ if self.api_name == "__eq__" and len(args) > 1 and args[1] is None:
81
84
  return False
82
85
  return func(self, *args, **kwargs)
83
86
 
@@ -261,6 +264,10 @@ class Const:
261
264
  NPU = 'NPU'
262
265
  DISTRIBUTED = 'Distributed'
263
266
 
267
+ HIFLOAT8_TYPE = "torch_npu.HiFloat8Tensor"
268
+ FLOAT8_E5M2_TYPE = "torch.float8_e5m2"
269
+ FLOAT8_E4M3FN_TYPE = "torch.float8_e4m3fn"
270
+
264
271
  RAISE_PRECISION = {
265
272
  torch.float16: torch.float32,
266
273
  torch.bfloat16: torch.float32,
@@ -309,14 +316,14 @@ def print_rank_0(message):
309
316
  logger.info(message)
310
317
 
311
318
 
312
- def load_pt(pt_path, to_cpu=False):
319
+ def load_pt(pt_path, to_cpu=False, weights_only=True):
313
320
  pt_path = os.path.realpath(pt_path)
314
321
  check_file_or_directory_path(pt_path)
315
322
  try:
316
323
  if to_cpu:
317
- pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True)
324
+ pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=weights_only)
318
325
  else:
319
- pt = torch.load(pt_path, weights_only=True)
326
+ pt = torch.load(pt_path, weights_only=weights_only)
320
327
  except Exception as e:
321
328
  raise RuntimeError(f"load pt file {pt_path} failed") from e
322
329
  return pt
@@ -391,7 +398,7 @@ def save_api_data(api_data):
391
398
  io_buff = io.BytesIO()
392
399
  torch.save(api_data, io_buff)
393
400
  except Exception as e:
394
- raise RuntimeError(f"save api_data to io_buff failed") from e
401
+ raise RuntimeError("save api_data to io_buff failed") from e
395
402
  return io_buff
396
403
 
397
404
 
@@ -401,7 +408,7 @@ def load_api_data(api_data_bytes):
401
408
  buffer = io.BytesIO(api_data_bytes)
402
409
  buffer = torch.load(buffer, map_location="cpu")
403
410
  except Exception as e:
404
- raise RuntimeError(f"load api_data from bytes failed") from e
411
+ raise RuntimeError("load api_data from bytes failed") from e
405
412
  return buffer
406
413
 
407
414
 
@@ -419,7 +426,11 @@ def is_recomputation():
419
426
  bool: True if in the re-computation phase, False otherwise.
420
427
  """
421
428
  backward_function_indices = []
422
- call_stack = inspect.stack()
429
+ try:
430
+ call_stack = inspect.stack()
431
+ except Exception as e:
432
+ logger.warning(f"Failed to capture stack trace, recomputation validation may be incorrect, error info: {e}.")
433
+ return False
423
434
 
424
435
  # Identify the function 'backward' is being executed within the 'torch/_tensor.py' file.
425
436
  for frame_info in call_stack:
@@ -449,9 +460,11 @@ def is_recomputation():
449
460
 
450
461
  def check_save_param(variable, name, save_backward):
451
462
  # try catch this api to skip invalid call
452
- if not isinstance(variable, (list, dict, torch.Tensor, int, float, str)):
463
+ valid_data_types = (torch.Tensor, int, float, str)
464
+ if not is_save_variable_valid(variable, valid_data_types):
465
+ valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
453
466
  logger.warning("PrecisionDebugger.save variable type not valid, "
454
- "should be one of list, dict, torch.Tensor, int, float or string. "
467
+ f"should be one of {valid_data_types_with_nested_types}"
455
468
  "Skip current save process.")
456
469
  raise ValueError
457
470
  if not isinstance(name, str):
@@ -466,10 +479,31 @@ def check_save_param(variable, name, save_backward):
466
479
  raise ValueError
467
480
 
468
481
 
469
- def replace_last_occurrence(text, old, new):
470
- if text is None:
471
- return text
472
- index = text.rfind(old)
473
- if index != -1:
474
- return text[:index] + text[index:].replace(old, new, 1)
475
- return text
482
+ def is_torch_nn_module(variable):
483
+ return isinstance(variable, torch.nn.Module) and not isinstance(variable, torch.jit.ScriptModule)
484
+
485
+
486
+ def is_hifloat8_tensor(tensor):
487
+ if not is_gpu and hasattr(torch_npu, "HiFloat8Tensor") and isinstance(tensor, torch_npu.HiFloat8Tensor):
488
+ return True
489
+ return False
490
+
491
+
492
+ def is_float8_tensor(tensor):
493
+ if str(tensor.dtype) in [Const.FLOAT8_E5M2_TYPE, Const.FLOAT8_E4M3FN_TYPE]:
494
+ return True
495
+ return is_hifloat8_tensor(tensor)
496
+
497
+
498
+ def register_forward_pre_hook(module, forward_pre_hook):
499
+ if torch_version_above_or_equal_2:
500
+ module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
501
+ else:
502
+ module.register_forward_pre_hook(forward_pre_hook)
503
+
504
+
505
+ def register_forward_hook(module, forward_hook):
506
+ if torch_version_above_or_equal_2:
507
+ module.register_forward_hook(forward_hook, with_kwargs=True)
508
+ else:
509
+ module.register_forward_hook(forward_hook)
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2019-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,41 +13,9 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import os
17
-
18
- from msprobe.core.common.exceptions import FileCheckException
19
- from msprobe.core.common.file_utils import create_directory
20
- from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
21
- set_dump_path
22
- from msprobe.core.compare.acc_compare import ModeConfig
23
- from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json, set_stack_json_path
24
- from msprobe.pytorch.common.log import logger
25
- from msprobe.pytorch.compare.pt_compare import PTComparator, compare
16
+ from msprobe.core.compare.utils import compare_distributed_inner
17
+ from msprobe.pytorch.compare.pt_compare import compare
26
18
 
27
19
 
28
20
  def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
29
- if kwargs.get("suffix"):
30
- logger.error("Argument 'suffix' is not supported for compare_distributed.")
31
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
32
- is_print_compare_log = kwargs.get("is_print_compare_log", True)
33
- # get the ranks and match by order
34
- npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
35
- bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
36
- if len(npu_ranks) != len(bench_ranks):
37
- logger.error(
38
- "The number of ranks in the two runs are different. "
39
- "Unable to match the ranks. "
40
- "Please use another folder to compare or use compare() api and manually match the ranks.")
41
- raise CompareException(CompareException.INVALID_PATH_ERROR)
42
- for nr, br in zip(npu_ranks, bench_ranks):
43
- npu_data_dir = os.path.join(npu_dump_dir, nr)
44
- bench_data_dir = os.path.join(bench_dump_dir, br)
45
- npu_path = extract_json(npu_data_dir, stack_json=False)
46
- bench_path = extract_json(bench_data_dir, stack_json=False)
47
-
48
- dump_result_param = {
49
- "npu_json_path": npu_path,
50
- "bench_json_path": bench_path,
51
- "is_print_compare_log": is_print_compare_log
52
- }
53
- compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs)
21
+ compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, compare, **kwargs)
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,92 +13,21 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import os.path
16
+ from msprobe.core.compare.acc_compare import Comparator, ModeConfig, MappingConfig, setup_comparison
17
+ from msprobe.pytorch.compare.utils import read_pt_data
17
18
 
18
- import torch
19
19
 
20
- from msprobe.core.common.const import FileCheckConst
21
- from msprobe.core.common.exceptions import FileCheckException
22
- from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
23
- from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
24
- set_dump_path
25
- from msprobe.core.compare.acc_compare import Comparator, ModeConfig
26
- from msprobe.core.compare.utils import set_stack_json_path
27
- from msprobe.pytorch.common.log import logger
28
- from msprobe.pytorch.common.utils import load_pt
29
-
30
-
31
- class PTComparator(Comparator):
32
- def __init__(self, mode_config, data_mapping=None):
33
- super().__init__(mode_config)
34
-
35
- self.stack_mode = mode_config.stack_mode
36
- self.auto_analyze = mode_config.auto_analyze
37
- self.fuzzy_match = mode_config.fuzzy_match
38
- self.dump_mode = mode_config.dump_mode
39
-
40
- self.frame_name = PTComparator.__name__
41
- self.data_mapping = data_mapping
42
- if isinstance(self.data_mapping, str) or self.data_mapping is None:
43
- self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
44
- elif isinstance(self.data_mapping, dict):
45
- self.data_mapping_dict = self.data_mapping
46
- else:
47
- raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
48
- f"{type(self.data_mapping)}")
49
-
50
- @staticmethod
51
- def load_mapping_file(mapping_file):
52
- if isinstance(mapping_file, str):
53
- mapping_dict = load_yaml(mapping_file)
54
- else:
55
- mapping_dict = {}
56
- return mapping_dict
57
-
58
- def read_npy_data(self, dir_path, file_name):
59
- if not file_name:
60
- return None
61
- data_path = os.path.join(dir_path, file_name)
62
- path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
63
- FileCheckConst.PT_SUFFIX, False)
64
- data_path = path_checker.common_check()
65
- try:
66
- # detach because numpy can not process gradient information
67
- data_value = load_pt(data_path, to_cpu=True).detach()
68
- except RuntimeError as e:
69
- # 这里捕获 load_pt 中抛出的异常
70
- logger.error(f"Failed to load the .pt file at {data_path}.")
71
- raise CompareException(CompareException.INVALID_FILE_ERROR) from e
72
- except AttributeError as e:
73
- # 这里捕获 detach 方法抛出的异常
74
- logger.error(f"Failed to detach the loaded tensor.")
75
- raise CompareException(CompareException.DETACH_ERROR) from e
76
- if data_value.dtype == torch.bfloat16:
77
- data_value = data_value.to(torch.float32)
78
- data_value = data_value.numpy()
79
- return data_value
20
def read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name, _) -> tuple:
    """Load the saved NPU and bench tensors for one comparison entry.

    The trailing ``_`` parameter is accepted for interface compatibility with
    the comparator callback signature and is ignored here.

    Returns:
        tuple: ``(npu_value, bench_value)`` as produced by ``read_pt_data``
        (``None`` for an entry whose file name is empty).
    """
    return (
        read_pt_data(npu_dir, npu_data_name),
        read_pt_data(bench_dir, bench_data_name),
    )
80
24
 
81
25
 
82
26
def compare(input_param, output_path, **kwargs):
    """Entry point for comparing a single pair of PyTorch dump results.

    Validates and normalizes the inputs via ``setup_comparison``, then runs
    the core comparison with a PyTorch-specific real-data reader.

    Args:
        input_param: dict with the NPU/bench dump json paths and options.
        output_path: directory where the comparison result is written.
        **kwargs: comparison options (stack_mode, auto_analyze, fuzzy_match,
            data_mapping, suffix, ...), consumed by ``setup_comparison``.
    """
    config = setup_comparison(input_param, output_path, **kwargs)

    comparator = Comparator(
        read_real_data,
        ModeConfig(
            config.stack_mode,
            config.auto_analyze,
            config.fuzzy_match,
            config.dump_mode,
            config.compared_file_type,
        ),
        MappingConfig(data_mapping=config.data_mapping),
    )
    comparator.compare_core(input_param, output_path, suffix=config.suffix)
@@ -0,0 +1,47 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ import torch
19
+
20
+ from msprobe.core.common.utils import logger, CompareException
21
+ from msprobe.core.common.file_utils import FileChecker, FileCheckConst
22
+ from msprobe.pytorch.common.utils import load_pt
23
+
24
+
25
def read_pt_data(dir_path, file_name):
    """Load a dumped ``.pt`` tensor file and return its data as a numpy array.

    Args:
        dir_path: directory containing the dump file.
        file_name: file name of the ``.pt`` dump; falsy means "no data".

    Returns:
        numpy.ndarray with the tensor's values, or ``None`` when
        ``file_name`` is empty.

    Raises:
        CompareException: INVALID_FILE_ERROR when the file cannot be loaded,
            DETACH_ERROR when the loaded object is not a detachable tensor.
    """
    if not file_name:
        return None

    data_path = os.path.join(dir_path, file_name)
    path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
                               FileCheckConst.PT_SUFFIX, False)
    data_path = path_checker.common_check()
    try:
        # detach because numpy cannot process gradient information
        data_value = load_pt(data_path, to_cpu=True).detach()
    except RuntimeError as e:
        # raised by load_pt when the file is missing, corrupt or not a valid .pt file
        logger.error(f"Failed to load the .pt file at {data_path}.")
        raise CompareException(CompareException.INVALID_FILE_ERROR) from e
    except AttributeError as e:
        # raised by .detach() when the loaded object is not a tensor
        logger.error("Failed to detach the loaded tensor.")
        raise CompareException(CompareException.DETACH_ERROR) from e
    if data_value.dtype == torch.bfloat16:
        # numpy has no bfloat16 dtype; upcast losslessly to float32 first
        data_value = data_value.to(torch.float32)
    data_value = data_value.numpy()
    return data_value
@@ -13,11 +13,10 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import torch
17
-
18
16
  from msprobe.core.common.const import Const
19
17
  from msprobe.core.common.exceptions import MsprobeException
20
18
  from msprobe.pytorch.common.log import logger
19
+ from msprobe.pytorch.common.utils import is_torch_nn_module
21
20
 
22
21
 
23
22
  class DebuggerConfig:
@@ -60,6 +59,7 @@ class DebuggerConfig:
60
59
  if isinstance(task_config.online_run_ut_recompute, bool) else False
61
60
 
62
61
  self.check()
62
+ self._check_statistics_config(task_config)
63
63
 
64
64
  if self.level == Const.LEVEL_L2:
65
65
  self.is_backward_kernel_dump = False
@@ -78,10 +78,13 @@ class DebuggerConfig:
78
78
  if not isinstance(self.async_dump, bool):
79
79
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
80
80
  f"The parameters async_dump should be bool.")
81
- if self.async_dump and self.task == Const.TENSOR and not self.list:
82
- raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
83
- f"The parameters async_dump is true in tensor task, the parameters list cannot be "
84
- f"empty.")
81
+ if self.async_dump and self.task == Const.TENSOR:
82
+ if self.level == Const.LEVEL_DEBUG:
83
+ self.list = [] # async_dump + debug level case ignore list
84
+ if not self.list and self.level != Const.LEVEL_DEBUG:
85
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
86
+ f"The parameters async_dump is true in tensor task, the parameters list cannot be "
87
+ f"empty.")
85
88
  if self.task == Const.STRUCTURE and self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
86
89
  logger.warning_on_rank_0(
87
90
  f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
@@ -93,25 +96,24 @@ class DebuggerConfig:
93
96
  self.check_kwargs()
94
97
  return True
95
98
 
96
- def check_model(self, instance, start_model):
97
- if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
98
- if instance.model is not None or start_model is not None:
99
- logger.info_on_rank_0(
100
- f"The current level is not L0 or mix level, so the model parameters will not be used.")
99
+ def check_model(self, instance, start_model, token_range=None):
100
+ instance.model = start_model if start_model is not None else instance.model
101
+ if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX] and token_range is None:
101
102
  return
102
- if start_model is None and instance.model is None:
103
+
104
+ if instance.model is None:
103
105
  logger.error_on_rank_0(
104
- f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' parameter.")
106
+ f"For level {self.level} or non-empty token_range, "
107
+ f"PrecisionDebugger or start interface must receive a 'model' parameter.")
105
108
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
106
109
 
107
- instance.model = start_model if start_model is not None else instance.model
108
- if isinstance(instance.model, torch.nn.Module):
110
+ if is_torch_nn_module(instance.model):
109
111
  return
110
112
 
111
113
  error_model = None
112
114
  if isinstance(instance.model, (list, tuple)):
113
115
  for model in instance.model:
114
- if not isinstance(model, torch.nn.Module):
116
+ if not is_torch_nn_module(model):
115
117
  error_model = model
116
118
  break
117
119
  else:
@@ -119,7 +121,7 @@ class DebuggerConfig:
119
121
 
120
122
  if error_model is not None:
121
123
  error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
122
- f"type, currently there is a {type(error_model)} type.")
124
+ f"type, currently there is an unsupported {type(error_model)} type.")
123
125
  raise MsprobeException(
124
126
  MsprobeException.INVALID_PARAM_ERROR, error_info)
125
127
 
@@ -130,8 +132,23 @@ class DebuggerConfig:
130
132
  if not self.list or len(self.list) != 1:
131
133
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
132
134
  f"When level is set to L2, the list must be configured as a list with one api name.")
135
+ if self.task != Const.TENSOR:
136
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
137
+ f"When level is set to L2, the task must be set to tensor.")
138
+
133
139
  api_name = self.list[0]
134
140
  if api_name.endswith(Const.BACKWARD):
135
141
  self.is_backward_kernel_dump = True
136
142
  api_forward_name = api_name[:-len(Const.BACKWARD)] + Const.FORWARD
137
143
  self.list.append(api_forward_name)
144
+
145
def _check_statistics_config(self, task_config):
    """Initialize ``self.tensor_list`` from the task config for the statistics task.

    No-op for non-statistics tasks. Otherwise ``tensor_list`` defaults to an
    empty list and is only taken from ``task_config`` when present; at debug
    level a configured tensor_list is ignored with a warning.
    """
    if self.task != Const.STATISTICS:
        return
    # Default to empty; only override when the config provides a usable value.
    self.tensor_list = []
    _missing = object()
    configured = getattr(task_config, "tensor_list", _missing)
    if configured is _missing:
        return
    if self.level == Const.LEVEL_DEBUG and configured:
        logger.warning_on_rank_0("When level is set to debug, the tensor_list will be invalid.")
        return
    self.tensor_list = configured