mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -14,117 +14,46 @@
14
14
  # limitations under the License.
15
15
 
16
16
  from msprobe.core.common.log import logger
17
- from msprobe.core.compare.utils import rename_api
18
17
  from msprobe.core.common.utils import check_op_str_pattern_valid, CompareException
19
- from msprobe.core.common.const import CompareConst, Const
20
-
21
- dtype_mapping = {
22
- "Int8": "torch.int8",
23
- "UInt8": "torch.uint8",
24
- "Int16": "torch.int16",
25
- "UInt16": "torch.uint16",
26
- "Int32": "torch.int32",
27
- "UInt32": "torch.uint32",
28
- "Int64": "torch.int64",
29
- "UInt64": "torch.uint64",
30
- "Float16": "torch.float16",
31
- "Float32": "torch.float32",
32
- "Float64": "torch.float64",
33
- "Bool": "torch.bool",
34
- "BFloat16": "torch.bfloat16",
35
- "Complex64": "torch.complex64",
36
- "Complex128": "torch.complex128"
18
+ from msprobe.core.common.const import Const
19
+
20
+ cross_dtype_mapping = {
21
+ "Int8": "int",
22
+ "torch.int8": "int",
23
+ "UInt8": "int",
24
+ "torch.uint8": "int",
25
+ "Int16": "int",
26
+ "torch.int16": "int",
27
+ "UInt16": "int",
28
+ "torch.uint16": "int",
29
+ "Int32": "int",
30
+ "torch.int32": "int",
31
+ "UInt32": "int",
32
+ "torch.uint32": "int",
33
+ "Int64": "int",
34
+ "torch.int64": "int",
35
+ "UInt64": "int",
36
+ "torch.uint64": "int",
37
+
38
+ "Float16": "float",
39
+ "torch.float16": "float",
40
+ "Float32": "float",
41
+ "torch.float32": "float",
42
+ "Float64": "float",
43
+ "torch.float64": "float",
44
+ "BFloat16": "float",
45
+ "torch.bfloat16": "float",
46
+
47
+ "Bool": "bool",
48
+ "torch.bool": "bool",
49
+
50
+ "Complex64": "complex",
51
+ "torch.complex64": "complex",
52
+ "Complex128": "complex",
53
+ "torch.complex128": "complex",
37
54
  }
38
55
 
39
56
 
40
- def compare_op_dict_struct(npu_dict, bench_dict):
41
- return all(npu_dict.get(key) == bench_dict.get(key) for key in CompareConst.STRUCT_COMPARE_KEY)
42
-
43
-
44
- def check_struct_match(npu_dict, bench_dict):
45
- is_match = compare_op_dict_struct(npu_dict, bench_dict)
46
- if not is_match:
47
- struct_match_list = []
48
- try:
49
- for i, key in enumerate(CompareConst.STRUCT_COMPARE_KEY):
50
- # 首先额外检查input_struct是否空,input_struct不可能为空
51
- if i == 0 and (not npu_dict.get(key, []) or not bench_dict.get(key, [])):
52
- return False
53
- struct_match_list.append(check_type_shape_match(npu_dict.get(key, []), bench_dict.get(key, [])))
54
- except CompareException as error:
55
- err_msg = f'index out of bounds error occurs in npu or bench api, please check!\n' \
56
- f'npu_dict: {npu_dict}' \
57
- f'bench_dict: {bench_dict}'
58
- logger.error(err_msg)
59
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
60
- is_match = all(struct_match_list)
61
- return is_match
62
-
63
-
64
- def check_type_shape_match(npu_struct, bench_struct):
65
- """
66
- further check dtypes with a dtype mapping list when dtypes are not entirely consistent.
67
- """
68
- if len(npu_struct) != len(bench_struct):
69
- return False
70
- if not npu_struct and not bench_struct:
71
- return True
72
-
73
- struct_match = False
74
- for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct):
75
- try:
76
- npu_type = npu_type_shape[0]
77
- npu_shape = npu_type_shape[1]
78
- bench_type = bench_type_shape[0]
79
- bench_shape = bench_type_shape[1]
80
- except IndexError as error:
81
- logger.error(f'length of npu_type_shape: {npu_type_shape} and bench_type_shape: {bench_type_shape} '
82
- f'should both be 2, please check!')
83
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
84
- shape_match = npu_shape == bench_shape
85
- type_match = npu_type == bench_type
86
- if not type_match:
87
- if ([npu_type, bench_type] in CompareConst.MS_TYPE) or ([npu_type, bench_type] in CompareConst.TORCH_TYPE):
88
- type_match = True
89
- else:
90
- type_match = False
91
- struct_match = shape_match and type_match
92
- if not struct_match:
93
- return False
94
- return struct_match
95
-
96
-
97
- def check_graph_mode(a_op_name, b_op_name):
98
- if Const.ATEN in a_op_name and Const.ATEN not in b_op_name:
99
- return True
100
- if Const.ATEN not in a_op_name and Const.ATEN in b_op_name:
101
- return True
102
- return False
103
-
104
-
105
- def fuzzy_check_op(npu_name_list, bench_name_list):
106
- # 先检查api里的item长度是否相等,如果不是parameters_grad, 必然有input或者output,长度不可能为0
107
- # 如果是parameters_grad, "parameters_grad"字段的字典不会是空字典,因此len>=1
108
- if len(npu_name_list) == 0 or len(bench_name_list) == 0 or len(npu_name_list) != len(bench_name_list):
109
- return False
110
- is_match = True
111
- for npu_name, bench_name in zip(npu_name_list, bench_name_list):
112
- is_match = fuzzy_check_name(npu_name, bench_name)
113
- if not is_match:
114
- break
115
- return is_match
116
-
117
-
118
- def fuzzy_check_name(npu_name, bench_name):
119
- if Const.FORWARD in npu_name and Const.FORWARD in bench_name:
120
- is_match = rename_api(npu_name, Const.FORWARD) == rename_api(bench_name, Const.FORWARD)
121
- elif Const.BACKWARD in npu_name and Const.BACKWARD in bench_name:
122
- is_match = rename_api(npu_name, Const.BACKWARD) == rename_api(bench_name, Const.BACKWARD)
123
- else:
124
- is_match = npu_name == bench_name
125
- return is_match
126
-
127
-
128
57
  def check_dump_json_str(op_data, op_name):
129
58
  input_list = op_data.get(Const.INPUT_ARGS, None) if op_data.get(Const.INPUT_ARGS, None) else op_data.get(
130
59
  Const.INPUT, None)
@@ -38,6 +38,7 @@ def compare_cli(args):
38
38
  else:
39
39
  from msprobe.mindspore.compare.ms_compare import ms_compare
40
40
  from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare
41
+ from msprobe.mindspore.compare.common_dir_compare import common_dir_compare
41
42
 
42
43
  common_kwargs = {
43
44
  "auto_analyze": auto_analyze,
@@ -78,6 +79,9 @@ def compare_cli(args):
78
79
  if input_param.get("rank_id") is not None:
79
80
  ms_graph_compare(input_param, args.output_path)
80
81
  return
82
+ if input_param.get('common', False):
83
+ common_dir_compare(input_param, args.output_path)
84
+ return
81
85
  if frame_name == Const.PT_FRAMEWORK:
82
86
  compare_distributed(npu_path, bench_path, args.output_path, **kwargs)
83
87
  else:
@@ -0,0 +1,72 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ from msprobe.core.common.const import Const, CompareConst
19
+ from msprobe.core.common.file_utils import load_yaml
20
+
21
+
22
+ class ModeConfig:
23
+ def __init__(self, stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.SUMMARY,
24
+ compared_file_type=Const.DUMP_JSON_FILE):
25
+ self.stack_mode = stack_mode
26
+ self.auto_analyze = auto_analyze
27
+ self.fuzzy_match = fuzzy_match
28
+ self.dump_mode = dump_mode
29
+ self.compared_file_type = compared_file_type
30
+
31
+
32
+ class MappingConfig:
33
+ def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None):
34
+ self.cell_mapping = cell_mapping
35
+ self.api_mapping = api_mapping
36
+ self.data_mapping = data_mapping
37
+
38
+
39
+ class MappingDict:
40
+ def __init__(self, mapping_config: MappingConfig):
41
+ self.cell_mapping_dict = self.load_mapping_file(mapping_config.cell_mapping)
42
+ self.api_mapping_dict = self.load_mapping_file(mapping_config.api_mapping)
43
+ if mapping_config.api_mapping is not None:
44
+ self.ms_to_pt_mapping = self.load_internal_api()
45
+ self.data_mapping_dict = self.init_data_mapping(mapping_config.data_mapping)
46
+
47
+ @staticmethod
48
+ def load_internal_api():
49
+ cur_path = os.path.dirname(os.path.realpath(__file__))
50
+ yaml_path = os.path.abspath(os.path.join(cur_path, CompareConst.INTERNAL_API_MAPPING_FILE))
51
+ return load_yaml(yaml_path)
52
+
53
+ @staticmethod
54
+ def load_mapping_file(mapping_file):
55
+ if isinstance(mapping_file, str):
56
+ mapping_dict = load_yaml(mapping_file)
57
+ else:
58
+ mapping_dict = {}
59
+ return mapping_dict
60
+
61
+ def init_data_mapping(self, data_mapping):
62
+ """
63
+ 初始化data_mapping_dict
64
+ """
65
+ if isinstance(data_mapping, str) or data_mapping is None:
66
+ data_mapping_dict = self.load_mapping_file(data_mapping)
67
+ elif isinstance(data_mapping, dict):
68
+ data_mapping_dict = data_mapping
69
+ else:
70
+ raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
71
+ f"{type(data_mapping)}")
72
+ return data_mapping_dict