mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
 
3
18
 
@@ -25,15 +40,22 @@ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
25
40
  x_shape = x.shape
26
41
  h = x.float()
27
42
  grad = dy_tensor.float()
28
- condition_1 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and
29
- ((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and
30
- (r1_shape[1] == x_shape[1]) and (r1_shape[3] == x_shape[3]))
31
- condition_2 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and
32
- ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and
33
- (r1_shape[2] == x_shape[2]) and (r1_shape[3] == x_shape[3]))
34
- condition_3 = (((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and
35
- ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and
36
- (r1_shape[0] == x_shape[0]) and (r1_shape[3] == x_shape[3]))
43
+ if len(r1_shape) < 4 or len(x_shape) < 4:
44
+ raise RuntimeError(f"Shape of r1 and x should at least be 4-dimension, "
45
+ f"but got r1 shape:{r1_shape}, x shape:{x_shape}")
46
+ condition_1 = (r1_shape[0] == 1
47
+ and r1_shape[1] == x_shape[1]
48
+ and r1_shape[2] == 1
49
+ and r1_shape[3] == x_shape[3])
50
+ condition_2 = (r1_shape[0] == 1
51
+ and r1_shape[1] == 1
52
+ and r1_shape[2] == x_shape[2]
53
+ and r1_shape[3] == x_shape[3])
54
+ condition_3 = (r1_shape[0] == x_shape[0]
55
+ and r1_shape[1] == 1
56
+ and r1_shape[2] == 1
57
+ and r1_shape[3] == x_shape[3])
58
+
37
59
  if condition_1:
38
60
  for i in range(x_shape[0]):
39
61
  for j in range(x_shape[2]):
@@ -49,4 +71,5 @@ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
49
71
  for j in range(x_shape[2]):
50
72
  r2_grad[:, 0, 0, :] += (x_new2[:, i, j, :] * grad[:, i, j, :])
51
73
  r1_grad[:, 0, 0, :] += (h[:, i, j, :] * grad[:, i, j, :])
74
+
52
75
  return x.grad.cpu(), r1_grad.cpu(), r2_grad.cpu()
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
 
3
18
 
@@ -1,16 +1,35 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
 
3
18
 
4
19
  def npu_swiglu(x, dim=-1):
5
20
  tensor_dtype = x.dtype
6
21
 
7
- inTensors = torch.chunk(x, 2, dim=dim)
22
+ try:
23
+ in_tensors = torch.chunk(x, 2, dim=dim)
24
+ except Exception as e:
25
+ raise RuntimeError(f"Invalid chunk x into 2 tensors with shape {x.shape} and dimension {dim}") from e
26
+
8
27
  if tensor_dtype == torch.float32:
9
- tensor_scalar = torch.sigmoid(torch.mul(inTensors[0], 1.0))
10
- output_data = torch.mul(torch.mul(tensor_scalar, inTensors[0]), inTensors[1])
28
+ tensor_scalar = torch.sigmoid(torch.mul(in_tensors[0], 1.0))
29
+ output_data = torch.mul(torch.mul(tensor_scalar, in_tensors[0]), in_tensors[1])
11
30
  else:
12
- tensor_self_float = inTensors[0].type(torch.float)
13
- tensor_other_float = inTensors[1].type(torch.float)
31
+ tensor_self_float = in_tensors[0].type(torch.float)
32
+ tensor_other_float = in_tensors[1].type(torch.float)
14
33
  tensor_out_float = torch.nn.functional.silu(tensor_self_float).type(tensor_dtype).type(
15
34
  torch.float32) * tensor_other_float
16
35
  output_data = tensor_out_float.type(tensor_dtype)
@@ -19,7 +38,11 @@ def npu_swiglu(x, dim=-1):
19
38
 
20
39
  def npu_swiglu_backward(grad, x, dim=-1):
21
40
  tensor_dtype = grad.dtype
22
- in_tensors = torch.chunk(x, 2, dim=dim)
41
+ try:
42
+ in_tensors = torch.chunk(x, 2, dim=dim)
43
+ except Exception as e:
44
+ raise RuntimeError(f"Invalid chunk x into 2 tensors with shape {x.shape} and dimension {dim}") from e
45
+
23
46
  tensor_grad_out = grad
24
47
 
25
48
  if tensor_dtype == torch.float16:
@@ -1,2 +1,17 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from .parse_json import parse_json_info_forward_backward
2
17
  from .utils import seed_all
@@ -1,9 +1,21 @@
1
- import os
2
- import time
3
- import sys
4
- from msprobe.pytorch.common.utils import get_rank_if_initialized
5
- from msprobe.core.common.log import BaseLogger
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
6
16
  from msprobe.core.common.exceptions import DistributedNotInitializedError
17
+ from msprobe.core.common.log import BaseLogger
18
+ from msprobe.pytorch.common.utils import get_rank_if_initialized
7
19
 
8
20
 
9
21
  class PyTorchLogger(BaseLogger):
@@ -18,4 +30,4 @@ class PyTorchLogger(BaseLogger):
18
30
  return current_rank
19
31
 
20
32
 
21
- logger = PyTorchLogger()
33
+ logger = PyTorchLogger()
@@ -1,25 +1,32 @@
1
- import json
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
2
15
 
3
16
  from msprobe.core.common.exceptions import ParseJsonException
4
- from msprobe.core.common.file_utils import FileOpen
17
+ from msprobe.core.common.file_utils import load_json
18
+ from msprobe.core.common.log import logger
5
19
 
6
20
 
7
21
  def parse_json_info_forward_backward(json_path):
8
- def parse_data_name_with_pattern(data_name, pattern):
9
- name_struct = data_name.split('.')
10
- if not name_struct[-1] == pattern:
11
- raise ParseJsonException(ParseJsonException.UnexpectedNameStruct,
12
- f"{data_name} in file {json_path}")
13
- api_name = '.'.join(name_struct[:-1])
14
- return api_name
15
-
16
- with FileOpen(json_path, 'r') as f:
17
- dump_json = json.load(f)
22
+ dump_json = load_json(json_path)
18
23
 
19
24
  real_data_path = dump_json.get("dump_data_dir")
20
25
  dump_data = dump_json.get("data")
26
+ if dump_data is None:
27
+ raise ParseJsonException(ParseJsonException.InvalidDumpJson, "something wrong with dump, no data found in dump.json")
21
28
  if not dump_data:
22
- raise ParseJsonException(ParseJsonException.InvalidDumpJson, "dump数据中没有data字段")
29
+ logger.warning("data field is empty, no overflow data found.")
23
30
 
24
31
  forward_data = {}
25
32
  backward_data = {}
@@ -27,13 +34,21 @@ def parse_json_info_forward_backward(json_path):
27
34
  if "Module" in data_name:
28
35
  continue
29
36
  if "forward" in data_name:
30
- api_name = parse_data_name_with_pattern(data_name, "forward")
37
+ api_name = parse_data_name_with_pattern(data_name, "forward", json_path)
31
38
  forward_data.update({api_name: data_item})
32
39
  elif "backward" in data_name:
33
- api_name = parse_data_name_with_pattern(data_name, "backward")
40
+ api_name = parse_data_name_with_pattern(data_name, "backward", json_path)
34
41
  backward_data.update({api_name: data_item})
35
42
  else:
36
43
  raise ParseJsonException(ParseJsonException.UnexpectedNameStruct,
37
- f"{data_name} in file {json_path}.")
44
+ f"{data_name} in file {json_path}.")
38
45
 
39
46
  return forward_data, backward_data, real_data_path
47
+
48
+
49
+ def parse_data_name_with_pattern(data_name, pattern, json_path):
50
+ name_struct = data_name.split('.')
51
+ if not name_struct[-1] == pattern:
52
+ raise ParseJsonException(ParseJsonException.UnexpectedNameStruct, f"{data_name} in file {json_path}")
53
+ api_name = '.'.join(name_struct[:-1])
54
+ return api_name
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,20 +12,23 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
15
+
17
16
  import io
18
17
  import os
18
+ import pickle
19
19
  import random
20
20
  import stat
21
+ from functools import wraps
22
+
23
+ import numpy as np
21
24
  import torch
22
25
  import torch.distributed as dist
23
- import numpy as np
24
- from functools import wraps
25
26
  from msprobe.core.common.exceptions import DistributedNotInitializedError
26
- from msprobe.core.common.log import logger
27
27
  from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
28
- check_file_or_directory_path, check_path_before_create)
29
-
28
+ check_file_or_directory_path, check_path_before_create, FileOpen)
29
+ from msprobe.core.common.log import logger
30
+ from msprobe.core.common.utils import check_seed_all
31
+ from packaging import version
30
32
 
31
33
  try:
32
34
  import torch_npu
@@ -35,10 +37,8 @@ except ImportError:
35
37
  else:
36
38
  is_gpu = False
37
39
 
38
-
39
40
  torch_without_guard_version = torch.__version__ >= '2.1'
40
41
 
41
-
42
42
  if not is_gpu and not torch_without_guard_version:
43
43
  from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard
44
44
 
@@ -46,7 +46,6 @@ npu_distributed_api = ['isend', 'irecv']
46
46
 
47
47
 
48
48
  def parameter_adapter(func):
49
-
50
49
  def handle_masked_select(input_tensor, indices):
51
50
  masked_select_func = getattr(torch._C._VariableFunctionsClass, "masked_select")
52
51
  if input_tensor.dtype == torch.bfloat16:
@@ -77,20 +76,22 @@ def parameter_adapter(func):
77
76
  else:
78
77
  res = [input_tensor[tensor_index] for tensor_index in indices]
79
78
  return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0)
80
- if self.op_name_ == "__eq__" and args[1] is None:
79
+ if self.op_name_ == "__eq__" and len(args) > 1 and args[1] is None:
81
80
  return False
82
81
  return func(self, *args, **kwargs)
82
+
83
83
  return inner
84
84
 
85
85
 
86
86
  def torch_device_guard(func):
87
87
  if is_gpu or torch_without_guard_version:
88
88
  return func
89
- # Parse args/kwargs matched torch.device objects
90
89
 
90
+ # Parse args/kwargs matched torch.device objects
91
91
  @torch_npu_device_guard
92
92
  def wrapper(*args, **kwargs):
93
93
  return func(*args, **kwargs)
94
+
94
95
  return wrapper
95
96
 
96
97
 
@@ -105,20 +106,28 @@ def get_rank_if_initialized():
105
106
 
106
107
 
107
108
  def seed_all(seed=1234, mode=False):
108
- random.seed(seed)
109
- os.environ['PYTHONHASHSEED'] = str(seed)
110
- np.random.seed(seed)
111
- torch.manual_seed(seed)
112
- torch.use_deterministic_algorithms(mode)
113
- if is_gpu:
114
- torch.cuda.manual_seed_all(seed)
115
- torch.cuda.manual_seed(seed)
116
- torch.backends.cudnn.deterministic = True
117
- torch.backends.cudnn.enable = False
118
- torch.backends.cudnn.benchmark = False
119
- else:
120
- torch_npu.npu.manual_seed_all(seed)
121
- torch_npu.npu.manual_seed(seed)
109
+ check_seed_all(seed, mode)
110
+ try:
111
+ random.seed(seed)
112
+ os.environ['PYTHONHASHSEED'] = str(seed)
113
+ np.random.seed(seed)
114
+ torch.manual_seed(seed)
115
+ cuda_version = torch.version.cuda
116
+ if cuda_version is not None and version.parse(cuda_version) >= version.parse("10.2"):
117
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
118
+ os.environ['HCCL_DETERMINISTIC'] = str(mode)
119
+ torch.use_deterministic_algorithms(mode)
120
+ if is_gpu:
121
+ torch.cuda.manual_seed_all(seed)
122
+ torch.cuda.manual_seed(seed)
123
+ torch.backends.cudnn.deterministic = True
124
+ torch.backends.cudnn.enable = False
125
+ torch.backends.cudnn.benchmark = False
126
+ else:
127
+ torch_npu.npu.manual_seed_all(seed)
128
+ torch_npu.npu.manual_seed(seed)
129
+ except Exception as e:
130
+ logger.error(f"There is an unexpected error while determinating randomness. {e}")
122
131
 
123
132
 
124
133
  class Const:
@@ -191,10 +200,7 @@ class Const:
191
200
  ENV_ENABLE = "1"
192
201
  ENV_DISABLE = "0"
193
202
 
194
- MAX_SEED_VALUE = 2**32 - 1
195
-
196
- INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
197
- "_reduce_scatter_base", "_all_gather_base", "all_to_all_single"]
203
+ MAX_SEED_VALUE = 2 ** 32 - 1
198
204
 
199
205
  TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"]
200
206
  LEVEL_LIST = ["L0", "L1", "L2", "mix"]
@@ -257,34 +263,84 @@ def print_rank_0(message):
257
263
  logger.info(message)
258
264
  else:
259
265
  logger.info(message)
260
-
266
+
261
267
 
262
268
  def load_pt(pt_path, to_cpu=False):
263
269
  pt_path = os.path.realpath(pt_path)
264
270
  check_file_or_directory_path(pt_path)
265
271
  try:
266
272
  if to_cpu:
267
- pt = torch.load(pt_path, map_location=torch.device("cpu"))
273
+ pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True)
268
274
  else:
269
- pt = torch.load(pt_path)
275
+ pt = torch.load(pt_path, weights_only=True)
270
276
  except Exception as e:
271
277
  raise RuntimeError(f"load pt file {pt_path} failed") from e
272
278
  return pt
273
279
 
274
280
 
275
281
  def save_pt(tensor, filepath):
276
- filepath = os.path.realpath(filepath)
277
282
  check_path_before_create(filepath)
283
+ filepath = os.path.realpath(filepath)
278
284
  try:
279
285
  torch.save(tensor, filepath)
280
286
  except Exception as e:
281
287
  logger.error("Save pt file failed, please check according possible error causes: "
282
- "1. out of disk space or disk error, "
283
- "2. no permission to write files, etc.")
288
+ "1. out of disk space or disk error, "
289
+ "2. no permission to write files, etc.")
284
290
  raise RuntimeError(f"save pt file {filepath} failed") from e
285
291
  change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
286
292
 
287
293
 
294
+ class TypeCheckingUnpickler(pickle.Unpickler):
295
+ """
296
+ This class is a subclass of pickle.Unpickler, which is used to unpickle pickled objects.
297
+ It overrides the find_class method to add type checking functionality.
298
+ """
299
+ allowed_types = [
300
+ "str",
301
+ "ApiData",
302
+ "OrderedDict",
303
+ "_rebuild_tensor_v2", # from torch.utils
304
+ "_load_from_bytes" # from torch.storage
305
+ ]
306
+
307
+ def find_class(self, module, name):
308
+ """
309
+ Method to find the class of the object to be unpickled.
310
+ Throws pickle.UnpicklingError If the object type is not in the allowed types list.
311
+ """
312
+ if name in self.allowed_types:
313
+ return super().find_class(module, name)
314
+ raise pickle.UnpicklingError("Unsupported object type: {}.{}".format(module, name))
315
+
316
+
317
+ def save_pkl(tensor, filepath):
318
+ """Save ApiData or str objection by pickle"""
319
+ check_path_before_create(filepath)
320
+ filepath = os.path.realpath(filepath)
321
+ try:
322
+ with FileOpen(filepath, 'wb') as f:
323
+ pickle.dump(tensor, f)
324
+ except Exception as e:
325
+ logger.error("Save pt file failed, please check according possible error causes: "
326
+ "1. out of disk space or disk error, "
327
+ "2. no permission to write files, etc.")
328
+ raise RuntimeError(f"save pt file {filepath} failed") from e
329
+ change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
330
+
331
+
332
+ def load_pkl(pt_path):
333
+ """Load ApiData or str objection by pickle for accuracy_checker_online"""
334
+ check_file_or_directory_path(pt_path)
335
+ pt_path = os.path.realpath(pt_path)
336
+ try:
337
+ with FileOpen(pt_path, 'rb') as f:
338
+ pt = TypeCheckingUnpickler(f).load()
339
+ except Exception as e:
340
+ raise RuntimeError(f"load pt file {pt_path} failed: {e}") from e
341
+ return pt
342
+
343
+
288
344
  def save_api_data(api_data):
289
345
  """Save data to io stream"""
290
346
  try:
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2019-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,14 +12,13 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
15
+
17
16
  import os
18
17
  from msprobe.core.common.utils import CompareException, check_compare_param, \
19
- check_configuration_param, task_dumppath_get
18
+ check_configuration_param, set_dump_path, get_dump_mode
20
19
  from msprobe.core.common.file_utils import create_directory
21
20
  from msprobe.core.common.exceptions import FileCheckException
22
21
  from msprobe.pytorch.common.log import logger
23
- from msprobe.core.common.const import Const
24
22
  from msprobe.pytorch.compare.pt_compare import PTComparator
25
23
  from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
26
24
 
@@ -32,6 +30,7 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
32
30
  stack_mode = kwargs.get('stack_mode', False)
33
31
  auto_analyze = kwargs.get('auto_analyze', True)
34
32
  fuzzy_match = kwargs.get('fuzzy_match', False)
33
+ is_print_compare_log = kwargs.get('is_print_compare_log', True)
35
34
  # get the ranks and match by order
36
35
  npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
37
36
  bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
@@ -51,16 +50,16 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
51
50
  'npu_json_path': npu_path,
52
51
  'bench_json_path': bench_path,
53
52
  'stack_json_path': stack_path,
54
- 'is_print_compare_log': True
53
+ 'is_print_compare_log': is_print_compare_log
55
54
  }
56
55
  try:
57
- summary_compare, md5_compare = task_dumppath_get(dump_result_param)
58
- check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
56
+ set_dump_path(dump_result_param)
57
+ dump_mode = get_dump_mode(dump_result_param)
58
+ check_configuration_param(stack_mode, auto_analyze, fuzzy_match, is_print_compare_log)
59
59
  create_directory(output_path)
60
- check_compare_param(dump_result_param, output_path, summary_compare=summary_compare, md5_compare=md5_compare)
60
+ check_compare_param(dump_result_param, output_path, dump_mode)
61
61
  except (CompareException, FileCheckException) as error:
62
62
  logger.error('Compare failed. Please check the arguments and do it again!')
63
63
  raise CompareException(error.code) from error
64
64
  pt_comparator = PTComparator()
65
- pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare,
66
- md5_compare=md5_compare, **kwargs)
65
+ pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', dump_mode=dump_mode, **kwargs)
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
  from msprobe.core.common.utils import CompareException
3
18
  from msprobe.core.common.file_utils import load_yaml
@@ -1,19 +1,52 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os.path
2
17
  import torch
3
18
  from msprobe.core.common.const import FileCheckConst
4
19
  from msprobe.pytorch.common.log import logger
5
20
  from msprobe.core.common.exceptions import FileCheckException
6
21
  from msprobe.core.compare.acc_compare import Comparator
7
- from msprobe.core.common.utils import check_configuration_param, task_dumppath_get, check_compare_param, CompareException
8
- from msprobe.core.common.file_utils import FileChecker, create_directory
22
+ from msprobe.core.common.utils import check_configuration_param, check_compare_param, \
23
+ CompareException, set_dump_path, get_dump_mode
24
+ from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
9
25
  from msprobe.pytorch.common.utils import load_pt
10
26
 
11
27
 
12
28
  class PTComparator (Comparator):
13
- def __init__(self):
29
+ def __init__(self, data_mapping=None):
14
30
  self.frame_name = PTComparator.__name__
31
+ self.data_mapping = data_mapping
32
+ if isinstance(self.data_mapping, str) or self.data_mapping is None:
33
+ self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
34
+ elif isinstance(self.data_mapping, dict):
35
+ self.data_mapping_dict = self.data_mapping
36
+ else:
37
+ raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
38
+ f"{type(self.data_mapping)}")
39
+
40
+ def load_mapping_file(self, mapping_file):
41
+ if isinstance(mapping_file, str):
42
+ mapping_dict = load_yaml(mapping_file)
43
+ else:
44
+ mapping_dict = {}
45
+ return mapping_dict
15
46
 
16
47
  def read_npy_data(self, dir_path, file_name):
48
+ if not file_name:
49
+ return None
17
50
  data_path = os.path.join(dir_path, file_name)
18
51
  path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
19
52
  FileCheckConst.PT_SUFFIX, False)
@@ -35,16 +68,17 @@ class PTComparator (Comparator):
35
68
  return data_value
36
69
 
37
70
 
38
- def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False):
71
+ def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False, **kwargs):
39
72
  try:
40
- summary_compare, md5_compare = task_dumppath_get(input_param)
41
- check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
73
+ set_dump_path(input_param)
74
+ dump_mode = get_dump_mode(input_param)
75
+ check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
42
76
  create_directory(output_path)
43
- check_compare_param(input_param, output_path, summary_compare, md5_compare)
77
+ check_compare_param(input_param, output_path, dump_mode)
78
+ data_mapping = kwargs.get('data_mapping', None)
44
79
  except (CompareException, FileCheckException) as error:
45
80
  logger.error('Compare failed. Please check the arguments and do it again!')
46
81
  raise CompareException(error.code) from error
47
- pt_comparator = PTComparator()
82
+ pt_comparator = PTComparator(data_mapping)
48
83
  pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
49
- auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
50
- md5_compare=md5_compare)
84
+ auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, dump_mode=dump_mode)