mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -13,19 +13,63 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import inspect
16
17
  import os
17
18
  import random
19
+ import types
18
20
 
19
21
  import mindspore as ms
20
-
21
22
  from mindspore import ops
23
+ from mindspore.common.jit_config import JitConfig
22
24
  from mindspore.mint import nn
23
25
 
26
+ from msprobe.core.common.const import Const
27
+ from msprobe.core.common.decorator import recursion_depth_decorator
24
28
  from msprobe.core.common.exceptions import DistributedNotInitializedError
25
29
  from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy
26
30
  from msprobe.core.common.log import logger
27
- from msprobe.core.common.const import Const
28
- from msprobe.core.common.utils import CompareException, check_seed_all
31
+ from msprobe.core.common.utils import CompareException, check_seed_all, is_save_variable_valid
32
+ from msprobe.mindspore.common.const import Const as MsConst
33
+
34
+ try:
35
+ from mindspore._c_expression import _set_init_iter
36
+ except ImportError:
37
+ enable_dynamic_kbyk_dump = False
38
+ else:
39
+ enable_dynamic_kbyk_dump = True
40
+
41
+ mindtorch_check_result = None
42
+ register_backward_hook_functions = {}
43
+ kwargs_exist_in_forward_hook = None
44
+
45
+
46
class MsprobeStep(ms.train.Callback):
    """Training callback that drives a debugger object around every step.

    The only methods invoked on ``debugger`` are ``start()``, ``stop()`` and
    ``step()``.
    """

    def __init__(self, debugger):
        super().__init__()
        self.debugger = debugger

    def on_train_begin(self, run_context):
        """Start the debugger; reset the init iteration to 0 when supported."""
        self.debugger.start()
        if not enable_dynamic_kbyk_dump:
            # `_set_init_iter` could not be imported on this MindSpore build.
            return
        _set_init_iter(0)

    def on_train_step_begin(self, run_context):
        """Start the debugger at the beginning of each training step."""
        self.debugger.start()

    def on_train_step_end(self, run_context):
        """Stop the debugger and advance it once the training step is over."""
        debugger = self.debugger
        debugger.stop()
        debugger.step()
62
+
63
+
64
class MsprobeInitStep(ms.train.Callback):
    """Training callback that seeds the init iteration from the current step.

    On ``on_train_begin`` it reads ``cur_step_num`` from the run context's
    original arguments and forwards it to ``_set_init_iter``.
    """

    def on_train_begin(self, run_context):
        """Forward the current step number to ``_set_init_iter`` if available.

        Logs a warning and returns without doing anything when
        ``_set_init_iter`` cannot be imported.
        """
        try:
            # `ms` is only the local alias created by `import mindspore as ms`;
            # it is not an importable package name, so the import must go
            # through the real package path or it fails unconditionally.
            from mindspore._c_expression import _set_init_iter
        except ImportError:
            logger.warning('MsprobeInitStep does not work on this version of MindSpore.')
            return
        cb_params = run_context.original_args()
        _set_init_iter(cb_params.cur_step_num)
29
73
 
30
74
 
31
75
  def get_rank_if_initialized():
@@ -58,8 +102,8 @@ def convert_to_int(value):
58
102
 
59
103
 
60
104
  def clean_input_kwargs(cell):
61
- if hasattr(cell, 'input_kwargs'):
62
- del cell.input_kwargs
105
+ if hasattr(cell, 'msprobe_input_kwargs'):
106
+ del cell.msprobe_input_kwargs
63
107
 
64
108
 
65
109
  def list_lowest_level_directories(root_dir):
@@ -93,20 +137,6 @@ def seed_all(seed=1234, mode=False, rm_dropout=True):
93
137
  remove_dropout()
94
138
 
95
139
 
96
- class MsprobeStep(ms.train.Callback):
97
-
98
- def __init__(self, debugger):
99
- super(MsprobeStep, self).__init__()
100
- self.debugger = debugger
101
-
102
- def on_train_step_begin(self, run_context):
103
- self.debugger.start()
104
-
105
- def on_train_step_end(self, run_context):
106
- self.debugger.stop()
107
- self.debugger.step()
108
-
109
-
110
140
  class Dropout(ops.Dropout):
111
141
  def __init__(self, keep_prob=0.5, seed0=0, seed1=1):
112
142
  super().__init__(1., seed0, seed1)
@@ -142,9 +172,6 @@ def remove_dropout():
142
172
  nn.functional.dropout = dropout_ext
143
173
 
144
174
 
145
- mindtorch_check_result = None
146
-
147
-
148
175
  def is_mindtorch():
149
176
  global mindtorch_check_result
150
177
  if mindtorch_check_result is None:
@@ -159,17 +186,17 @@ def is_mindtorch():
159
186
  return mindtorch_check_result
160
187
 
161
188
 
162
- register_backward_hook_functions = {}
163
-
164
-
165
189
  def set_register_backward_hook_functions():
166
190
  global register_backward_hook_functions
191
+ if register_backward_hook_functions:
192
+ return
193
+
167
194
  if is_mindtorch():
168
195
  import torch
169
196
  from msprobe.mindspore.mindtorch import (_call_impl,
170
197
  register_full_backward_pre_hook,
171
198
  register_full_backward_hook)
172
- if not hasattr(torch, "register_full_backward_hook"):
199
+ if not hasattr(torch.nn.Module, "register_full_backward_hook"):
173
200
  setattr(torch.nn.Module, "_call_impl", _call_impl)
174
201
  setattr(torch.nn.Module, "register_full_backward_pre_hook", register_full_backward_pre_hook)
175
202
  setattr(torch.nn.Module, "register_full_backward_hook", register_full_backward_hook)
@@ -182,9 +209,11 @@ def set_register_backward_hook_functions():
182
209
 
183
210
  def check_save_param(variable, name, save_backward):
184
211
  # try catch this api to skip invalid call
185
- if not isinstance(variable, (list, dict, ms.Tensor, int, float, str)):
212
+ valid_data_types = (ms.Tensor, int, float, str)
213
+ if not is_save_variable_valid(variable, valid_data_types):
214
+ valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
186
215
  logger.warning("PrecisionDebugger.save variable type not valid, "
187
- "should be one of list, dict, ms.Tensor, int, float or string. "
216
+ f"should be one of {valid_data_types_with_nested_types}"
188
217
  "Skip current save process.")
189
218
  raise ValueError
190
219
  if not isinstance(name, str):
@@ -196,4 +225,103 @@ def check_save_param(variable, name, save_backward):
196
225
  logger.warning("PrecisionDebugger.save_backward name not valid, "
197
226
  "should be bool. "
198
227
  "Skip current save process.")
199
- raise ValueError
228
+ raise ValueError
229
+
230
+
231
def is_graph_mode_cell_dump_allowed(config):
    """Tell whether graph-mode cell dump may be used for ``config``.

    Returns False when the task is neither tensor nor statistics, when running
    under mindtorch, or when ``ops`` has no ``DumpGradient`` attribute.
    Otherwise returns True for a mix level in PyNative execution mode, or for
    a cell / L0 level.
    """
    if config.task not in [Const.TENSOR, Const.STATISTICS]:
        return False
    if is_mindtorch():
        return False
    if not hasattr(ops, 'DumpGradient'):
        return False

    mix_levels = [MsConst.CELL_AND_API, Const.LEVEL_MIX]
    if config.level in mix_levels and config.execution_mode == MsConst.PYNATIVE_MODE:
        return True
    if config.level == MsConst.CELL:
        return True
    return config.level == Const.LEVEL_L0
238
+
239
+
240
@recursion_depth_decorator('msprobe.mindspore.common.utils.is_decorated_by_jit')
def is_decorated_by_jit(func):
    """Return True if a ``JitConfig`` is reachable through ``func``'s closure.

    The closure cells of ``func`` are inspected; any captured plain function
    is searched recursively in the same way.
    """
    cells = getattr(func, '__closure__', [])
    if not cells:
        return False

    for cell in cells:
        captured = cell.cell_contents
        if isinstance(captured, JitConfig):
            return True
        is_nested_function = isinstance(captured, types.FunctionType) and hasattr(captured, '__closure__')
        if is_nested_function and is_decorated_by_jit(captured):
            return True
    return False
251
+
252
+
253
@recursion_depth_decorator('msprobe.mindspore.common.utils.get_cells_and_names')
def get_cells_and_names(model, cells_set=None, name_prefix=''):
    """Yield ``(name, cell, jit_decorated)`` for ``model`` and its sub-cells.

    Args:
        model: Root cell; its children are read from its ``_cells`` mapping.
        cells_set: Set of cells already visited, used to yield each cell only
            once. A fresh set is created when None is given.
        name_prefix: Dotted name of ``model``; child names are appended to it
            with ``Const.SEP``.

    Yields:
        Tuples of the cell's name, the cell, and whether its ``construct`` is
        detected as jit-decorated. Traversal does not descend below a
        jit-decorated cell.
    """
    # Test against None so that an empty set supplied by the caller is used
    # (and filled) instead of being silently replaced.
    if cells_set is None:
        cells_set = set()
    if model in cells_set:
        return

    cells_set.add(model)
    jit_decorated = is_decorated_by_jit(model.construct)
    yield name_prefix, model, jit_decorated
    if jit_decorated:
        return

    # From here on `model` is known not to be jit-decorated, so re-checking
    # `model.construct` for every child could never succeed. Jit-decorated
    # children are reported by the recursive call, which yields them and then
    # stops descending.
    for name, cell in getattr(model, '_cells').items():
        if not cell:
            continue
        cells_name_prefix = f'{name_prefix}{Const.SEP}{name}' if name_prefix else name
        yield from get_cells_and_names(cell, cells_set, cells_name_prefix)
275
+
276
+
277
def get_cells_and_names_with_index(models):
    """Collect the cells of one or several models, keyed by model index.

    Args:
        models: A single model, or a list/tuple of models.

    Returns:
        A pair of dicts ``(pynative_cells, graph_cells)``. Keys are the model
        index as a string, or ``"-1"`` when a single model was given. Under
        mindtorch the first dict maps to ``named_modules()`` and the second
        stays empty; otherwise both map to lists of ``(name, cell)`` split by
        whether the cell was reported as jit-decorated.
    """
    pynative_cells = {}
    graph_cells = {}
    under_mindtorch = is_mindtorch()

    if isinstance(models, (list, tuple)):
        keyed_models = [(str(position), net) for position, net in enumerate(models)]
    else:
        keyed_models = [("-1", models)]

    for key, net in keyed_models:
        if under_mindtorch:
            pynative_cells[key] = net.named_modules()
            continue

        eager_cells = []
        compiled_cells = []
        for name, cell, jit_decorated in get_cells_and_names(net):
            bucket = compiled_cells if jit_decorated else eager_cells
            bucket.append((name, cell))
        pynative_cells[key] = eager_cells
        graph_cells[key] = compiled_cells

    return pynative_cells, graph_cells
311
+
312
+
313
def has_kwargs_in_forward_hook():
    """Report whether forward hooks accept a ``with_kwargs`` parameter.

    The answer is computed once and cached in the module-level
    ``kwargs_exist_in_forward_hook``. Under mindtorch it is True; otherwise it
    is read from the signature of ``nn.Cell.register_forward_hook``, falling
    back to False if the signature cannot be inspected.
    """
    global kwargs_exist_in_forward_hook

    if kwargs_exist_in_forward_hook is not None:
        return kwargs_exist_in_forward_hook

    if is_mindtorch():
        kwargs_exist_in_forward_hook = True
    else:
        try:
            hook_signature = inspect.signature(nn.Cell.register_forward_hook)
            kwargs_exist_in_forward_hook = 'with_kwargs' in hook_signature.parameters
        except Exception:
            kwargs_exist_in_forward_hook = False

    return kwargs_exist_in_forward_hook
@@ -0,0 +1,382 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import multiprocessing
18
+ from dataclasses import dataclass
19
+ from typing import Dict, List, Tuple, Optional, Any
20
+ from concurrent.futures import ProcessPoolExecutor
21
+ from functools import partial
22
+ from pathlib import Path
23
+
24
+ import pandas as pd
25
+ import numpy as np
26
+ from tqdm import tqdm
27
+
28
+ from msprobe.core.common.log import logger
29
+ from msprobe.core.common.utils import CompareException
30
+ from msprobe.core.common.exceptions import FileCheckException
31
+ from msprobe.core.common.file_utils import check_file_or_directory_path, write_df_to_csv, create_directory, \
32
+ check_path_before_create, load_npy
33
+ from msprobe.core.common.const import CompareConst, FileCheckConst
34
+ from msprobe.core.compare.npy_compare import compare_ops_apply
35
+ from msprobe.core.compare.multiprocessing_compute import check_accuracy
36
+
37
+
38
def common_dir_compare(input_params: Dict, output_dir: str) -> Optional[pd.DataFrame]:
    """Compare two directory trees, mirroring the input layout in the output.

    Args:
        input_params: Dict holding ``npu_path`` and ``bench_path``; an optional
            ``map_dict`` entry is forwarded to each directory comparison.
        output_dir: Root directory for the output.

    Returns:
        Always None; each directory pair is handled by
        ``process_directory_pair`` in a worker process.
    """
    npu_root = Path(input_params.get('npu_path'))
    bench_root = Path(input_params.get('bench_path'))
    name_map_dict = input_params.get('map_dict', {})
    file_tree = build_mirror_file_tree(npu_root, bench_root)

    # Handle every directory pair in a worker process.
    compare_pair = partial(process_directory_pair, name_map_dict=name_map_dict, output_dir=output_dir)
    with ProcessPoolExecutor() as executor:
        outcomes = executor.map(compare_pair, file_tree.items())
        progress = tqdm(outcomes, total=len(file_tree), desc="Processing directories")
        # Drain the iterator so that every task is awaited and the bar advances.
        for _ in progress:
            pass
    return None
65
+
66
+
67
def process_directory_pair(item: Tuple[Path, Tuple[Path, Path]], name_map_dict: Dict, output_dir: str):
    """Compare one NPU/bench directory pair and write ``result.csv``.

    Args:
        item: ``(relative path, (npu directory, bench directory))`` tuple.
        name_map_dict: Name mapping forwarded to ``generate_map_dict``.
        output_dir: Root directory for the output.

    Returns:
        Always None. When no file pairs are found only a warning is logged.
    """
    rel_path, dir_pair = item
    npu_dir, bench_dir = dir_pair

    # Create the mirrored output directory.
    output_path = Path(output_dir) / rel_path
    create_directory(output_path)

    # Build the NPU-to-bench file mapping.
    map_dict = generate_map_dict(find_npy_files(npu_dir), find_npy_files(bench_dir), name_map_dict)

    if map_dict:
        # Run the comparison, then save the result.
        result_df = do_multi_process(process_chunk, map_dict)
        check_path_before_create(output_path)
        result_path = os.path.join(output_path, 'result.csv')
        write_df_to_csv(result_df, result_path)
        logger.info(f"Results saved to {result_path}")
    else:
        logger.warning(f"No file pairs found in {rel_path}")
    return None
101
+
102
+
103
def build_mirror_file_tree(npu_root: Path, bench_root: Path) -> Dict[Path, Tuple[str, str]]:
    """Build the mirrored directory tree shared by the NPU and bench roots.

    Every directory under ``npu_root`` that contains at least one ``.npy``
    file is looked up at the same relative location under ``bench_root``.
    Directories whose bench counterpart fails
    ``check_file_or_directory_path`` are left out.

    Args:
        npu_root: Root directory of the NPU data.
        bench_root: Root directory of the benchmark data.

    Returns:
        Dict mapping the relative directory path to a
        ``(npu directory, bench directory)`` tuple of joined path strings.
    """
    file_tree = {}
    # Several .npy files usually share one directory; remember directories
    # already examined (accepted or rejected) so each is validated only once.
    examined_dirs = set()

    for npu_path in npu_root.rglob('*.npy'):
        dir_path = npu_path.relative_to(npu_root).parent
        if dir_path in examined_dirs:
            continue
        examined_dirs.add(dir_path)

        bench_dir_pair = os.path.join(bench_root, dir_path)
        try:
            check_file_or_directory_path(bench_dir_pair, isdir=True)
        except FileCheckException:
            continue
        file_tree[dir_path] = (os.path.join(npu_root, dir_path), bench_dir_pair)

    return file_tree
130
+
131
+
132
def find_npy_files(directory):
    """Group the ``.npy`` files found under ``directory`` by name prefix.

    The file name is split on ``'_'`` and the last two pieces are dropped; the
    remaining pieces, re-joined with ``'_'``, form the grouping key. Names
    with no underscore at all are ignored.

    Args:
        directory: Directory that is walked recursively.

    Returns:
        Dict mapping each key to the list of full file paths, in sorted
        traversal order.
    """
    npy_files_dict = {}
    for root, dirs, files in os.walk(directory):
        # os.walk yields entries in arbitrary order, while callers pair NPU and
        # bench files by their position in these lists. Sort both levels so
        # that two trees with the same file names produce the same order.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith(".npy"):
                continue
            name_parts = file.split('_')
            # NOTE(review): a name with exactly two parts passes this guard and
            # is grouped under the empty key '' - confirm this is intended.
            if len(name_parts) < 2:
                continue
            key = '_'.join(name_parts[:-2])
            npy_files_dict.setdefault(key, []).append(os.path.join(root, file))
    return npy_files_dict
149
+
150
+
151
def generate_map_dict(npu_file_dict, bench_file_dict, name_map_dict=None):
    """Pair NPU files with bench files, key by key and position by position.

    Args:
        npu_file_dict: Dict mapping a key to a list of NPU file paths.
        bench_file_dict: Dict mapping a key to a list of bench file paths.
        name_map_dict: Optional dict translating an NPU key into the bench key
            to use when the bench side has no entry under the NPU key.

    Returns:
        Dict mapping ``"<key>_<index>"`` to ``(npu_file, bench_file)`` for
        every key present on both sides. Surplus files on the longer side are
        ignored. Empty when nothing matches.
    """
    name_map_dict = name_map_dict or {}
    # Accumulate across all keys; creating the dict inside the loop kept only
    # the last key's pairs and left the name unbound when nothing matched.
    result_dict = {}
    for key, npu_file_list in npu_file_dict.items():
        bench_file_list = bench_file_dict.get(key)
        if not bench_file_list and key in name_map_dict:
            bench_file_list = bench_file_dict.get(name_map_dict[key])
        # Covers both a missing key (None) and an empty list.
        if not bench_file_list:
            continue
        # zip stops at the shorter list, so unmatched extras are skipped.
        for index, (npu_file, bench_file) in enumerate(zip(npu_file_list, bench_file_list)):
            result_dict[f"{key}_{index}"] = (npu_file, bench_file)
    return result_dict
166
+
167
+
168
def do_multi_process(func, map_dict):
    """Run ``func`` over chunks of ``map_dict`` in a process pool.

    A result DataFrame with one row per entry of ``map_dict`` is pre-allocated
    and cut into chunks; every chunk is submitted to the pool together with
    the matching slice of ``map_dict``.

    Args:
        func: Worker called as ``func(df_chunk, start_idx, map_chunk, lock)``;
            its return values are concatenated with ``pd.concat``.
        map_dict: Dict of items to compare.

    Returns:
        The concatenated DataFrame, or an empty DataFrame if any exception is
        raised while submitting tasks or collecting their results.
    """
    result_len = len(map_dict)
    process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
    # every block size
    df_chunk_size = result_len // process_num

    # generate the same len of map_dict df
    result_df = initialize_result_df(result_len)
    if df_chunk_size > 0:
        df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
    else:
        df_chunks = [result_df]
        process_num = 1
    logger.info(f"Using {process_num} processes with chunk size {df_chunk_size}")

    # Split the dict into chunks.
    map_chunks = split_dict(map_dict, df_chunk_size)

    # Use the manager as a context manager so that its server process is shut
    # down once all work is finished instead of being left running.
    with multiprocessing.Manager() as manager:
        lock = manager.RLock()
        pool = multiprocessing.Pool(process_num)
        progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)

        def update_progress(size, progress_lock, extra_param=None):
            # `extra_param` receives the task result passed by the pool callback.
            with progress_lock:
                progress_bar.update(size)

        def err_call(args):
            logger.error('multiprocess compare failed! Reason: {}'.format(args))
            try:
                pool.close()
            except OSError as e:
                logger.error(f'pool terminate failed: {str(e)}')

        try:
            # Submit the tasks to the pool.
            results = [
                pool.apply_async(
                    func,
                    args=(df_chunk, df_chunk_size * process_idx, map_chunk, lock),
                    error_callback=err_call,
                    callback=partial(update_progress, len(map_chunk), lock)
                )
                for process_idx, (df_chunk, map_chunk) in enumerate(zip(df_chunks, map_chunks))
            ]

            final_results = [r.get() for r in results]
            # Wait for all tasks to finish.
            pool.close()
            pool.join()
            return pd.concat(final_results, ignore_index=True)
        except Exception as e:
            logger.error(f"\nMain process error: {str(e)}")
            pool.terminate()
            return pd.DataFrame({})
        finally:
            pool.close()
            # Release the progress bar, which was previously never closed.
            progress_bar.close()
225
+ finally:
226
+ pool.close()
227
+
228
+
229
def initialize_result_df(total_size):
    """Pre-allocate the result DataFrame.

    Args:
        total_size: number of rows; the index is ``range(total_size)``.

    Returns:
        pandas.DataFrame: a frame with no values filled in and one column per
        comparison field.
    """
    identity_cols = [
        CompareConst.NAME,
        CompareConst.NPU_DTYPE,
        CompareConst.BENCH_DTYPE,
        CompareConst.NPU_SHAPE,
        CompareConst.BENCH_SHAPE,
    ]
    metric_cols = [
        CompareConst.COSINE,
        CompareConst.EUC_DIST,
        CompareConst.MAX_ABS_ERR,
        CompareConst.MAX_RELATIVE_ERR,
        CompareConst.ONE_THOUSANDTH_ERR_RATIO,
        CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
    ]
    npu_stat_cols = [
        CompareConst.NPU_MAX,
        CompareConst.NPU_MIN,
        CompareConst.NPU_MEAN,
        CompareConst.NPU_NORM,
    ]
    bench_stat_cols = [
        CompareConst.BENCH_MAX,
        CompareConst.BENCH_MIN,
        CompareConst.BENCH_MEAN,
        CompareConst.BENCH_NORM,
    ]
    verdict_cols = [
        CompareConst.ACCURACY,
        CompareConst.ERROR_MESSAGE,
        CompareConst.DATA_NAME,
    ]
    # Concatenation keeps the column order: identity, metrics, NPU stats, bench stats, verdict.
    all_columns = identity_cols + metric_cols + npu_stat_cols + bench_stat_cols + verdict_cols
    row_labels = range(total_size)
    return pd.DataFrame(index=row_labels, columns=all_columns)
256
+
257
+
258
def split_dict(input_dict, chunk_size):
    """Split a dictionary into consecutive sub-dictionaries of ``chunk_size`` items.

    Args:
        input_dict: dictionary to split; insertion order is preserved.
        chunk_size: maximum number of items per chunk.

    Returns:
        list: the chunk dictionaries in order (the last one may be shorter). When
        ``chunk_size`` is not positive, a one-element list holding ``input_dict``
        itself.
    """
    if not chunk_size > 0:
        return [input_dict]
    pairs = list(input_dict.items())
    chunks = []
    for start in range(0, len(pairs), chunk_size):
        chunks.append(dict(pairs[start:start + chunk_size]))
    return chunks
264
+
265
+
266
def get_tensor_stats(tensor: np.ndarray) -> Tuple[float, float, float, float]:
    """Return summary statistics of ``tensor``.

    Returns:
        A 4-tuple ``(max, min, mean, norm)`` computed with ``np.max``, ``np.min``,
        ``np.mean`` and ``np.linalg.norm`` (default arguments).
    """
    reducers = (np.max, np.min, np.mean, np.linalg.norm)
    stats = tuple(reduce_fn(tensor) for reduce_fn in reducers)
    return stats
273
+
274
+
275
def process_chunk(df, start_idx, map_chunk, lock):
    """Compare every file pair of one chunk and store the rows in ``df``.

    Args:
        df: result DataFrame chunk, handed to ``_save_part_df``.
        start_idx: row offset of this chunk, handed to ``_save_part_df``.
        map_chunk: mapping of item name -> ``(npu_file, bench_file)``.
        lock: lock handed to ``_save_part_df``.

    Returns:
        Whatever ``_save_part_df`` returns for this chunk.
    """
    chunk_results = []
    for name, (npu_file, bench_file) in map_chunk.items():
        n_value = load_npy(npu_file)
        # To support cross-framework comparison, b_value would additionally need a load_pt path.
        b_value = load_npy(bench_file)

        # Start every pair with a cleared error flag and an empty error message.
        err_list, err_msg = compare_ops_apply(n_value, b_value, False, "")
        (cos_sim, euc_dist, max_abs_err, max_relative_err,
         one_thousand_err_ratio, five_thousand_err_ratio) = err_list
        npu_stats = get_tensor_stats(n_value)
        bench_stats = get_tensor_stats(b_value)

        chunk_results.append(ComparisonResult(
            name=name,
            npu_dtype=n_value.dtype,
            bench_dtype=b_value.dtype,
            npu_shape=n_value.shape,
            bench_shape=b_value.shape,
            cosine=cos_sim,
            euc_dist=euc_dist,
            max_abs_err=max_abs_err,
            max_relative_err=max_relative_err,
            one_thousandth_err_ratio=one_thousand_err_ratio,
            five_thousandth_err_ratio=five_thousand_err_ratio,
            npu_max=npu_stats[0],
            npu_min=npu_stats[1],
            npu_mean=npu_stats[2],
            npu_norm=npu_stats[3],
            bench_max=bench_stats[0],
            bench_min=bench_stats[1],
            bench_mean=bench_stats[2],
            bench_norm=bench_stats[3],
            accuracy=check_accuracy(cos_sim, max_abs_err),
            error_message=err_msg,
            data_name=[npu_file, bench_file],
        ))
    return _save_part_df(df, start_idx, chunk_results, lock)
319
+
320
+
321
@dataclass
class ComparisonResult:
    """One comparison row for a single NPU/bench file pair.

    Instances are built in ``process_chunk`` and written to the result DataFrame
    by ``_save_part_df``. The comment after each field names the ``CompareConst``
    column the field is stored under.
    """
    name: str  # CompareConst.NAME
    npu_dtype: Any  # CompareConst.NPU_DTYPE
    bench_dtype: Any  # CompareConst.BENCH_DTYPE
    npu_shape: Tuple[int, ...]  # CompareConst.NPU_SHAPE
    bench_shape: Tuple[int, ...]  # CompareConst.BENCH_SHAPE
    cosine: float  # CompareConst.COSINE
    euc_dist: float  # CompareConst.EUC_DIST
    max_abs_err: float  # CompareConst.MAX_ABS_ERR
    max_relative_err: float  # CompareConst.MAX_RELATIVE_ERR
    one_thousandth_err_ratio: float  # CompareConst.ONE_THOUSANDTH_ERR_RATIO
    five_thousandth_err_ratio: float  # CompareConst.FIVE_THOUSANDTHS_ERR_RATIO
    npu_max: float  # CompareConst.NPU_MAX
    npu_min: float  # CompareConst.NPU_MIN
    npu_mean: float  # CompareConst.NPU_MEAN
    npu_norm: float  # CompareConst.NPU_NORM
    bench_max: float  # CompareConst.BENCH_MAX
    bench_min: float  # CompareConst.BENCH_MIN
    bench_mean: float  # CompareConst.BENCH_MEAN
    bench_norm: float  # CompareConst.BENCH_NORM
    accuracy: bool  # CompareConst.ACCURACY
    error_message: str  # CompareConst.ERROR_MESSAGE
    data_name: List[str]  # CompareConst.DATA_NAME
345
+
346
+
347
def _save_part_df(df, start_idx, results, lock):
    """Write ``results`` into ``df`` while holding ``lock``.

    The i-th result is stored in the row labelled ``start_idx + i``.

    Args:
        df: DataFrame whose columns are the ``CompareConst`` result fields.
        start_idx: row label of the first result.
        results: sequence of ``ComparisonResult`` objects.
        lock: object with ``acquire()``/``release()``; always released on exit.

    Returns:
        The same ``df`` after all rows have been written.

    Raises:
        CompareException: ``INVALID_DATA_ERROR`` when a ``ValueError`` occurs,
            ``INDEX_OUT_OF_BOUNDS_ERROR`` when an ``IndexError`` occurs.
    """
    lock.acquire()
    try:
        for offset, item in enumerate(results):
            row = start_idx + offset
            cells = (
                (CompareConst.NAME, item.name),
                (CompareConst.NPU_DTYPE, item.npu_dtype),
                (CompareConst.BENCH_DTYPE, item.bench_dtype),
                # Shape tuples are stored as their string representation.
                (CompareConst.NPU_SHAPE, str(item.npu_shape)),
                (CompareConst.BENCH_SHAPE, str(item.bench_shape)),
                (CompareConst.COSINE, item.cosine),
                (CompareConst.EUC_DIST, item.euc_dist),
                (CompareConst.MAX_ABS_ERR, item.max_abs_err),
                (CompareConst.MAX_RELATIVE_ERR, item.max_relative_err),
                (CompareConst.ONE_THOUSANDTH_ERR_RATIO, item.one_thousandth_err_ratio),
                (CompareConst.FIVE_THOUSANDTHS_ERR_RATIO, item.five_thousandth_err_ratio),
                (CompareConst.NPU_MAX, item.npu_max),
                (CompareConst.NPU_MIN, item.npu_min),
                (CompareConst.NPU_MEAN, item.npu_mean),
                (CompareConst.NPU_NORM, item.npu_norm),
                (CompareConst.BENCH_MAX, item.bench_max),
                (CompareConst.BENCH_MIN, item.bench_min),
                (CompareConst.BENCH_MEAN, item.bench_mean),
                (CompareConst.BENCH_NORM, item.bench_norm),
                (CompareConst.ACCURACY, item.accuracy),
                (CompareConst.ERROR_MESSAGE, item.error_message),
                # The file-name list is stored as its string representation.
                (CompareConst.DATA_NAME, str(item.data_name)),
            )
            for column, value in cells:
                df.loc[row, column] = value
        return df
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
    except IndexError as e:
        logger.error('result dataframe elements can not be access.')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
    finally:
        lock.release()
@@ -13,41 +13,17 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import os
17
16
  from msprobe.core.common.utils import CompareException
18
17
  from msprobe.core.common.file_utils import create_directory
19
18
  from msprobe.core.common.exceptions import FileCheckException
20
19
  from msprobe.mindspore.common.log import logger
21
20
  from msprobe.mindspore.compare.ms_compare import ms_compare
22
- from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
21
+ from msprobe.core.compare.utils import compare_distributed_inner
23
22
  from msprobe.mindspore.compare.ms_graph_compare import GraphMSComparator
24
23
 
25
24
 
26
25
def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
    """Compare two distributed MindSpore dumps.

    Thin wrapper that forwards everything to ``compare_distributed_inner`` and
    passes ``ms_compare`` as the fourth positional argument.
    NOTE(review): presumably ``compare_distributed_inner`` matches the rank
    directories and calls ``ms_compare`` per rank pair, as the previous inline
    implementation did; confirm against ``msprobe.core.compare.utils``.

    Args:
        npu_dump_dir: dump directory of the NPU run.
        bench_dump_dir: dump directory of the bench run.
        output_path: where the comparison output goes.
        **kwargs: forwarded unchanged to ``compare_distributed_inner``.
    """
    compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, ms_compare, **kwargs)
51
27
 
52
28
 
53
29
  def ms_graph_compare(inputs, outputs):