mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299):
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -16,7 +16,6 @@
16
16
  import glob
17
17
  import os.path
18
18
  import time
19
- import re
20
19
  from multiprocessing import Queue
21
20
  from typing import Optional, Union, Dict, Any
22
21
  from dataclasses import dataclass
@@ -26,9 +25,8 @@ import torch
26
25
  from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
27
26
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
28
27
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
29
- from msprobe.pytorch.common.utils import logger
30
28
  from msprobe.core.common.file_utils import remove_path
31
- from msprobe.pytorch.common.utils import save_api_data, load_api_data, save_pt, load_pt
29
+ from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
32
30
 
33
31
  BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]]
34
32
 
@@ -55,7 +53,6 @@ class ATTL:
55
53
  self.dequeue_list = []
56
54
  self.message_end = False
57
55
  self.kill_progress = False
58
- self.check_attl_config()
59
56
  self.nfs_path = None
60
57
  if self.session_config.nfs_path:
61
58
  self.nfs_path = self.session_config.nfs_path
@@ -73,18 +70,6 @@ class ATTL:
73
70
  self.session_config.tls_path)
74
71
  self.socket_manager.start()
75
72
 
76
- def check_attl_config(self):
77
- if self.session_config.nfs_path:
78
- if os.path.exists(self.session_config.nfs_path):
79
- return
80
- else:
81
- raise Exception(f"nfs path {self.session_config.nfs_path} doesn't exists.")
82
- ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
83
- if not re.match(ipv4_pattern, self.session_config.connect_ip):
84
- raise Exception(f"host {self.session_config.connect_ip} is invalid.")
85
- if not (0 < self.session_config.connect_port <= 65535):
86
- raise Exception(f"port {self.session_config.connect_port} is invalid.")
87
-
88
73
  def stop_serve(self):
89
74
  if isinstance(self.socket_manager, TCPServer):
90
75
  self.socket_manager.stop()
@@ -115,21 +100,21 @@ class ATTL:
115
100
  self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)
116
101
 
117
102
  def recv(self, timeout_ms=0) -> Optional[BufferType]:
118
- buffer = None
119
- while buffer is None:
103
+ buffer = ''
104
+ while not buffer:
120
105
  if timeout_ms > 0:
121
106
  time.sleep(timeout_ms / 1000.0)
122
- if buffer is None and not self.data_queue.empty():
107
+ if not buffer and not self.data_queue.empty():
123
108
  buffer = self.data_queue.get()
124
109
  break
125
- if buffer is None and timeout_ms > 0: # timeout is the only case we give up and return None
110
+ if not buffer and timeout_ms > 0: # timeout is the only case we give up and return None
126
111
  break
127
112
  if self.message_end and self.data_queue.empty():
128
113
  buffer = b"KILL_CONFIRM"
129
114
  self.kill_progress = True
130
115
  break
131
116
  time.sleep(0.1) # waiting outside the lock before next attempt
132
- if buffer is None:
117
+ if not buffer:
133
118
  # this is a result of a timeout
134
119
  self.logger.info(f"RECEIVE API DATA TIMED OUT")
135
120
  else:
@@ -146,7 +131,7 @@ class ATTL:
146
131
  except Exception as e:
147
132
  self.logger.warning("there is something error. please check it. %s", e)
148
133
  if isinstance(buffer, bytes):
149
- return None
134
+ return ''
150
135
  if isinstance(buffer, str):
151
136
  return buffer
152
137
 
@@ -160,7 +145,7 @@ class ATTL:
160
145
  file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")
161
146
 
162
147
  try:
163
- save_pt(buffer, file_path)
148
+ save_pkl(buffer, file_path)
164
149
  except Exception as e:
165
150
  self.logger.warning("there is something error in save_pt. please check it. %s", e)
166
151
 
@@ -176,7 +161,7 @@ class ATTL:
176
161
 
177
162
  if cur_file is not None:
178
163
  try:
179
- buffer = load_pt(cur_file)
164
+ buffer = load_pkl(cur_file)
180
165
  except Exception as e:
181
166
  self.logger.warning("there is something error. please check it. %s", e)
182
167
  remove_path(cur_file)
@@ -27,8 +27,8 @@ from twisted.internet import reactor, protocol, endpoints
27
27
  from twisted.protocols.basic import FileSender
28
28
 
29
29
  from msprobe.pytorch.common.utils import logger
30
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import struct_unpack_mode as unpack_mode, \
31
- str_to_bytes_order as bytes_order
30
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import STRUCT_UNPACK_MODE as unpack_mode, \
31
+ STR_TO_BYTES_ORDER as bytes_order
32
32
 
33
33
  MAX_SENDING_QUEUE_SIZE = 20
34
34
 
@@ -84,15 +84,6 @@ class TCPClient:
84
84
  def run_reactor():
85
85
  reactor.run(installSignalHandlers=False)
86
86
 
87
- def check_tls_path(self):
88
- client_key = os.path.join(self.tls_path, "client.key")
89
- client_crt = os.path.join(self.tls_path, "client.crt")
90
- if not os.path.exists(client_key):
91
- raise Exception(f"client_key: {client_key} is not exists.")
92
- if not os.path.exists(client_crt):
93
- raise Exception(f"client_crt: {client_crt} is not exists.")
94
- return client_key, client_crt
95
-
96
87
  def start(self):
97
88
  def conn_callback(cur_protocol):
98
89
  if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host:
@@ -114,7 +105,8 @@ class TCPClient:
114
105
  self.factory.protocol = cur_protocol
115
106
  if self.tls_path:
116
107
  from twisted.internet import ssl
117
- client_key, client_crt = self.check_tls_path()
108
+ client_key = os.path.join(self.tls_path, "client.key")
109
+ client_crt = os.path.join(self.tls_path, "client.crt")
118
110
  client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt)
119
111
  endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory)
120
112
  else:
@@ -24,7 +24,7 @@ from msprobe.core.common.const import Const, CompareConst
24
24
  from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import online_api_precision_compare
25
25
  from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import DETAIL_TEST_ROWS, thousandth_standard_api, \
26
26
  binary_standard_api, absolute_standard_api
27
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api
27
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api, ExecParams
28
28
  from msprobe.pytorch.common.log import logger
29
29
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device
30
30
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params
@@ -92,8 +92,10 @@ def online_precision_compare(api_data, device, common_config, api_precision_csv_
92
92
 
93
93
  try:
94
94
  # NPU vs CPU
95
- cpu_args, cpu_kwargs = generate_cpu_params(npu_args, npu_kwargs, False, api_name)
96
- cpu_out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs)
95
+ cpu_params = generate_cpu_params(npu_args, npu_kwargs, False, api_name)
96
+ cpu_args, cpu_kwargs = cpu_params.cpu_args, cpu_params.cpu_kwargs
97
+ cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs, False, None)
98
+ cpu_out = exec_api(cpu_exec_params)
97
99
  npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank)
98
100
  npu_detail = compare.compare_output(api_full_name, npu_data_info, True)
99
101
  npu_data = pd.DataFrame(npu_detail, columns=DETAIL_TEST_ROWS[-1])
@@ -1,3 +1,4 @@
1
+
1
2
  # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
3
  # All rights reserved.
3
4
  #
@@ -14,6 +15,7 @@
14
15
  # limitations under the License.
15
16
 
16
17
  import os
18
+ from collections import defaultdict
17
19
  from functools import wraps
18
20
 
19
21
  import torch
@@ -39,7 +41,7 @@ def singleton(cls):
39
41
  @singleton
40
42
  class Counter:
41
43
  def __init__(self) -> None:
42
- self.index_dict = {}
44
+ self.index_dict = defaultdict(int)
43
45
 
44
46
 
45
47
  counter = Counter()
@@ -67,9 +69,9 @@ class AccuracyCheckerDispatch(TorchDispatchMode):
67
69
 
68
70
  res = func(*args, **kwargs)
69
71
  cur_rank = get_tensor_rank(args, res)
70
- cur_api_number = self.counter.index_dict.setdefault(aten_api, 0)
72
+ cur_api_number = self.counter.index_dict[aten_api]
71
73
  api_name = f'{Const.ATEN}{Const.SEP}{aten_api}{Const.SEP}{cur_api_number}'
72
- logger.info(f"tools is dumping api: {api_name}")
74
+ logger.info(f"tools is dumping api: {api_name}, rank: {cur_rank}")
73
75
  api_data = ApiData(api_name, args, kwargs, res, 0, cur_rank)
74
76
  if "device" in api_data.kwargs:
75
77
  api_data.kwargs.pop("device")
@@ -98,7 +100,7 @@ def dispatch4data(func, attl, status):
98
100
  return wrapper
99
101
 
100
102
 
101
- def run_ut_dispatch(attl, status):
103
+ def run_ut_dispatch(attl, status, is_recompute=False):
102
104
  """
103
105
  This function called by online_run_ut.
104
106
  It is used to enable or disable dispatch for torch.autograd.backward function.
@@ -106,5 +108,8 @@ def run_ut_dispatch(attl, status):
106
108
  Args:
107
109
  attl (ATTL): online_run_ut class ATTL, which is used to upload or send api data to server.
108
110
  status (bool): True means enable dispatch, False means disable dispatch.
111
+ is_recompute (bool): Flag of recompute, which is conflicted with aten api, then skip dispatch4data.
109
112
  """
113
+ if is_recompute:
114
+ return
110
115
  torch.autograd.backward = dispatch4data(torch.autograd.backward, attl, status)
@@ -24,7 +24,7 @@ from twisted.internet import reactor, protocol, endpoints
24
24
 
25
25
  from msprobe.pytorch.common.utils import logger
26
26
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import cipher_list, \
27
- struct_unpack_mode as unpack_mode, str_to_bytes_order as bytes_order
27
+ STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order
28
28
 
29
29
 
30
30
  class TCPServer:
@@ -40,22 +40,14 @@ class TCPServer:
40
40
  def run_reactor():
41
41
  reactor.run(installSignalHandlers=False)
42
42
 
43
- def check_tls_path(self):
44
- server_key = os.path.join(self.tls_path, "server.key")
45
- server_crt = os.path.join(self.tls_path, "server.crt")
46
- if not os.path.exists(server_key):
47
- raise Exception(f"server_key: {server_key} is not exists.")
48
- if not os.path.exists(server_crt):
49
- raise Exception(f"server_crt: {server_crt} is not exists.")
50
- return server_key, server_crt
51
-
52
43
  def start(self):
53
44
  self.factory.protocol = self.build_protocol
54
45
 
55
46
  if self.tls_path:
56
47
  from OpenSSL import SSL
57
48
  from twisted.internet import ssl
58
- server_key, server_crt = self.check_tls_path()
49
+ server_key = os.path.join(self.tls_path, "server.key")
50
+ server_crt = os.path.join(self.tls_path, "server.crt")
59
51
  server_context_factory = ssl.DefaultOpenSSLContextFactory(server_key, server_crt, SSL.TLSv1_2_METHOD)
60
52
  server_context_ = server_context_factory.getContext()
61
53
  server_context_.set_cipher_list(cipher_list)
@@ -40,5 +40,5 @@ cipher_list = ":".join(
40
40
  "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"]
41
41
  ).encode()
42
42
 
43
- struct_unpack_mode = "!Q"
44
- str_to_bytes_order = "big"
43
+ STRUCT_UNPACK_MODE = "!Q"
44
+ STR_TO_BYTES_ORDER = "big"
@@ -22,7 +22,11 @@ def npu_confusion_transpose(data, perm, shape, transpose_first):
22
22
 
23
23
 
24
24
  def npu_confusion_transpose_backward(grad, perm, shape, transpose_first):
25
- shape_cal = shape if transpose_first else [shape[perm_dim] for perm_dim in perm]
25
+ try:
26
+ shape_cal = shape if transpose_first else [shape[perm_dim] for perm_dim in perm]
27
+ except IndexError as e:
28
+ raise IndexError("npu_confusion_transpose_backward: Invalid perm index for shape") from e
29
+
26
30
  perm_cal = [0] * len(perm)
27
31
  for i, perm_dim in enumerate(perm):
28
32
  perm_cal[perm_dim] = i
@@ -17,6 +17,9 @@ import torch
17
17
 
18
18
 
19
19
  def matmul_backward(grad, self, other, mask):
20
+ if len(mask) < 2:
21
+ raise RuntimeError("Mask size at least 2")
22
+
20
23
  grad_self, grad_other = None, None
21
24
  dim_self = self.dim()
22
25
  dim_other = other.dim()
@@ -24,6 +27,7 @@ def matmul_backward(grad, self, other, mask):
24
27
  size_grad = list(grad.size())
25
28
  size_self = list(self.size())
26
29
  size_other = list(other.size())
30
+
27
31
  if dim_self == 1 and dim_other == 1:
28
32
  grad_self = other.mul(grad) if mask[0] else grad_self
29
33
  grad_other = self.mul(grad) if mask[1] else grad_other
@@ -34,19 +38,27 @@ def matmul_backward(grad, self, other, mask):
34
38
  grad_self = grad.unsqueeze(0).mm(other.transpose(-1, -2)).squeeze_(0) if mask[0] else grad_self
35
39
  grad_other = self.unsqueeze(1).mm(grad.unsqueeze(0)) if mask[1] else grad_other
36
40
  elif dim_self >= 3 and (dim_other == 1 or dim_other == 2):
41
+ if len(size_grad) < 1:
42
+ raise RuntimeError("size_grad's length at least 1")
37
43
  view_size = 1 if dim_other == 1 else size_grad[-1]
38
44
  unfolded_grad = (grad.unsqueeze(-1) if dim_other == 1 else grad).contiguous().view(-1, view_size)
39
45
  if mask[0]:
40
46
  grad_self = unfolded_grad.mm(other.unsqueeze(0) if dim_other == 1 else other.transpose(-1, -2)) \
41
47
  .view(size_self)
42
48
  if mask[1]:
49
+ if len(size_self) < 1:
50
+ raise RuntimeError("size_self's length at least 1")
43
51
  unfolded_self = self.contiguous().view([-1, size_self[-1]])
44
52
  grad_other = unfolded_self.transpose(-1, -2).mm(unfolded_grad).view(size_other)
45
53
  elif (dim_self == 1 or dim_self == 2) and dim_other >= 3:
54
+ if len(size_grad) < 2:
55
+ raise RuntimeError("size_grad's length at least 2")
46
56
  view_size = 1 if dim_self == 1 else size_grad[-2]
47
57
  unfolded_grad_t = grad.view([-1, view_size]) \
48
58
  if dim_self == 1 else grad.transpose(-1, -2).contiguous().view([-1, view_size])
49
59
  if mask[0]:
60
+ if len(size_other) < 2:
61
+ raise RuntimeError("size_other's length at least 2")
50
62
  # create a 2D-matrix from other
51
63
  unfolded_other_t = \
52
64
  other.transpose(-1, -2).contiguous().view([-1, size_other[-2]]).transpose(-1, -2)
@@ -30,6 +30,7 @@
30
30
  numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
31
31
  """
32
32
 
33
+ from collections import namedtuple
33
34
  import torch
34
35
  import numpy as np
35
36
  from einops import rearrange
@@ -50,8 +51,16 @@ else:
50
51
  from msprobe.pytorch.common.utils import logger
51
52
  from msprobe.core.common.const import Const, CompareConst
52
53
 
53
- gtype = torch.float64 # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86
54
- softmax_build_mode = "QKV" # "MAX_SUM"
54
+ GTYPE = torch.float64 # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86
55
+ SOFTMAX_BUILD_MODE = "QKV" # "MAX_SUM"
56
+
57
+
58
+ FaForwardParams = namedtuple("FaForwardParams",
59
+ ["q", "k", "v", "drop_mask", "atten_mask", "pse", "scale", "keep_prob"])
60
+ FaBackwardParams = namedtuple("FaBackwardParams",
61
+ ["dx", "q", "k", "v", "softmax_res", "drop_mask", "pse", "scale", "keep_prob"])
62
+ RebuildSoftmaxParams = namedtuple("RebuildSoftmaxParams",
63
+ ["q", "k", "atten_mask", "pse", "scale", "softmax_max", "softmax_sum"])
55
64
 
56
65
 
57
66
  def softmax_forward(x):
@@ -99,7 +108,15 @@ def calculate_qk(q, k, atten_mask, pse, scale):
99
108
  return qk
100
109
 
101
110
 
102
- def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_prob):
111
+ def fusion_attention_forward(forward_params):
112
+ q = forward_params.q
113
+ k = forward_params.k
114
+ v = forward_params.v
115
+ drop_mask = forward_params.drop_mask
116
+ atten_mask = forward_params.atten_mask
117
+ pse = forward_params.pse
118
+ scale = forward_params.scale
119
+ keep_prob = forward_params.keep_prob
103
120
  qk = calculate_qk(q, k, atten_mask, pse, scale)
104
121
  softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
105
122
  if drop_mask is None or len(drop_mask.shape) == 0:
@@ -110,7 +127,16 @@ def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_pr
110
127
  return y, softmax_max, softmax_sum
111
128
 
112
129
 
113
- def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob):
130
+ def fusion_attention_backward(backward_params):
131
+ dx = backward_params.dx
132
+ q = backward_params.q
133
+ k = backward_params.k
134
+ v = backward_params.v
135
+ softmax_res = backward_params.softmax_res
136
+ drop_mask = backward_params.drop_mask
137
+ pse = backward_params.pse
138
+ scale = backward_params.scale
139
+ keep_prob = backward_params.keep_prob
114
140
  dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
115
141
  if drop_mask is None or len(drop_mask.shape) == 0:
116
142
  drop_res = softmax_res.permute(0, 1, 3, 2)
@@ -166,6 +192,18 @@ def parse_bsnd_args(query, key, head_num, input_layout):
166
192
 
167
193
 
168
194
  def convert_from_bnsd(_input, input_layout):
195
+ """
196
+ transform qkv from bnsd to input_layout.
197
+ B: batch_size
198
+ S: sequence_length
199
+ N: num_heads
200
+ D: head_dim
201
+ Args:
202
+ _input (torch.Tensor): tensor of shape (B,N,S,D)
203
+ input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
204
+ Returns:
205
+ tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
206
+ """
169
207
  if input_layout == "BSH":
170
208
  # (B,N,S,D)=>(B,S,N*D)
171
209
  out = rearrange(_input, 'b n s d -> b s (n d)').contiguous()
@@ -183,7 +221,19 @@ def convert_from_bnsd(_input, input_layout):
183
221
 
184
222
 
185
223
  def convert_to_bnsd(_input, n, input_layout):
186
- # 默认"BNSD"无需处理
224
+ """
225
+ transform qkv from input_layout to bnsd.
226
+ B: batch_size
227
+ S: sequence_length
228
+ N: num_heads
229
+ D: head_dim
230
+ Args:
231
+ _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
232
+ n (int): num_heads
233
+ input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND"
234
+ Returns:
235
+ tensor of shape (B,N,S,D)
236
+ """
187
237
  if input_layout == "BSH":
188
238
  # (B,S,N*D)=>(B,N,S,D)
189
239
  out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
@@ -199,7 +249,68 @@ def convert_to_bnsd(_input, n, input_layout):
199
249
  out = _input
200
250
  if out.dim() != 4:
201
251
  raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
202
- return out.to(gtype)
252
+ return out.to(GTYPE)
253
+
254
+
255
+ def convert_from_bsnd(_input, input_layout):
256
+ """
257
+ transform qkv from bsnd to input_layout.
258
+ B: batch_size
259
+ S: sequence_length
260
+ N: num_heads
261
+ D: head_dim
262
+ Args:
263
+ _input (torch.Tensor): tensor of shape (B,S,N,D)
264
+ input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
265
+ Returns:
266
+ tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
267
+ """
268
+ if input_layout == "BSH":
269
+ # (B,S,N,D)=>(B,S,N*D)
270
+ out = rearrange(_input, 'b s n d -> b s (n d)').contiguous()
271
+ elif input_layout == "SBH":
272
+ # (B,S,N,D)=>(S,B,N*D)
273
+ out = rearrange(_input, 'b s n d -> s b (n d)').contiguous()
274
+ elif input_layout == "BNSD":
275
+ # (B,S,N,D)=>(B,N,S,D)
276
+ out = rearrange(_input, 'b s n d -> b n s d').contiguous()
277
+ elif input_layout == "TND":
278
+ raise ValueError(f"input_layout {input_layout} does not supported for now.")
279
+ else:
280
+ out = _input
281
+ return out
282
+
283
+
284
+ def convert_to_bsnd(_input, n, input_layout):
285
+ """
286
+ transform qkv from input_layout to bsnd.
287
+ B: batch_size
288
+ S: sequence_length
289
+ N: num_heads
290
+ D: head_dim
291
+ Args:
292
+ _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
293
+ n (int): num_heads
294
+ input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND"
295
+ Returns:
296
+ tensor of shape (B,S,N,D)
297
+ """
298
+ if input_layout == "BSH":
299
+ # (B,S,N*D)=>(B,S,N,D)
300
+ out = rearrange(_input, 'b s (n d) -> b s n d', n=n)
301
+ elif input_layout == "SBH":
302
+ # (S,B,N*D)=>(B,S,N,D)
303
+ out = rearrange(_input, 's b (n d) -> b s n d', n=n)
304
+ elif input_layout == "BNSD":
305
+ # (B,N,S,D)=>(B,S,N,D)
306
+ out = rearrange(_input, 'b n s d -> b s n d', n=n)
307
+ elif input_layout == "TND":
308
+ raise ValueError(f"input_layout {input_layout} does not supported for now.")
309
+ else:
310
+ out = _input
311
+ if out.dim() != 4:
312
+ raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
313
+ return out
203
314
 
204
315
 
205
316
  def generate_atten_mask(*args):
@@ -279,15 +390,22 @@ def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
279
390
  """
280
391
  logger.info("Using QKV to rebuild original softmax")
281
392
  qk = calculate_qk(q, k, atten_mask, pse, scale)
282
- softmax_res, x_max, x_sum = softmax_forward(qk)
393
+ softmax_res, _, _ = softmax_forward(qk)
283
394
  return softmax_res
284
395
 
285
396
 
286
- def rebuild_softmax_by_max_sum(q, k, atten_mask, pse, scale, softmax_max, softmax_sum):
397
+ def rebuild_softmax_by_max_sum(softmax_params):
287
398
  """
288
399
  attention = softmax(QK^T/sqrt(d))V
289
400
  softmax(x_i) = e^(x_i - x_max_i) / x_sum_i)
290
401
  """
402
+ q = softmax_params.q
403
+ k = softmax_params.k
404
+ atten_mask = softmax_params.atten_mask
405
+ pse = softmax_params.pse
406
+ scale = softmax_params.scale
407
+ softmax_max = softmax_params.softmax_max
408
+ softmax_sum = softmax_params.softmax_sum
291
409
  logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
292
410
  qk = calculate_qk(q, k, atten_mask, pse, scale)
293
411
  if softmax_max.shape[-1] == 0:
@@ -319,6 +437,10 @@ def get_input_layout(*args, **kwargs):
319
437
 
320
438
 
321
439
  def npu_fusion_attention_forward_patch(*args, **kwargs):
440
+
441
+ if len(args) < 2:
442
+ raise RuntimeError("npu_fusion_attention_forward_patch: length of args should greater than or equal to 2.")
443
+
322
444
  # query, key, value, head_num, input_layout
323
445
  head_num = get_head_num(*args, **kwargs)
324
446
  input_layout = get_input_layout(*args, **kwargs)
@@ -413,10 +535,8 @@ def npu_fusion_attention(*args, **kwargs):
413
535
  key = convert_to_bnsd(key, n2, input_layout)
414
536
  value = convert_to_bnsd(value, n2, input_layout)
415
537
  k_new, v_new = generate_kv(key, value, n1, n2)
416
- out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new,
417
- drop_mask=None, atten_mask=atten_mask,
418
- pse=pse, scale=scale,
419
- keep_prob=keep_prob)
538
+ forward_params = FaForwardParams(query, k_new, v_new, None, atten_mask, pse, scale, keep_prob)
539
+ out_golden, softmax_max, softmax_sum = fusion_attention_forward(forward_params)
420
540
  if out_golden.dim() == 5:
421
541
  out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3),
422
542
  out_golden.size(4))
@@ -454,12 +574,13 @@ def npu_fusion_attention_grad(*args, **kwargs):
454
574
  value = convert_to_bnsd(value, n2, input_layout)
455
575
  k_new, v_new = generate_kv(key, value, n1, n2)
456
576
 
457
- if softmax_build_mode == "QKV":
577
+ if SOFTMAX_BUILD_MODE == "QKV":
458
578
  softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
459
579
  else:
460
- softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
461
-
462
- dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
580
+ softmax_params = RebuildSoftmaxParams(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
581
+ softmax_res = rebuild_softmax_by_max_sum(softmax_params)
582
+ backward_params = FaBackwardParams(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
583
+ dq, dk, dv = fusion_attention_backward(backward_params)
463
584
 
464
585
  # N不等长适配by cdy
465
586
  if not (n1 == n2):
@@ -531,8 +652,13 @@ def gpu_fusion_attention(*args, **kwargs):
531
652
  else:
532
653
  alibi_slopes = None
533
654
 
655
+ input_layout = get_input_layout(*args, **kwargs)
656
+ query = convert_to_bsnd(query, n1, input_layout)
657
+ key = convert_to_bsnd(key, n2, input_layout)
658
+ value = convert_to_bsnd(value, n2, input_layout)
534
659
  out = flash_attn_func(
535
660
  query, key, value, dropout_p=(1 - keep_prob), softmax_scale=scale, causal=causal_switch,
536
661
  window_size=(window_left, window_right), alibi_slopes=alibi_slopes, deterministic=deterministic
537
662
  )
663
+ out = convert_from_bsnd(out, input_layout)
538
664
  return out, Const.NONE, Const.NONE
@@ -40,6 +40,9 @@ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
40
40
  x_shape = x.shape
41
41
  h = x.float()
42
42
  grad = dy_tensor.float()
43
+ if len(r1_shape) < 4 or len(x_shape) < 4:
44
+ raise RuntimeError(f"Shape of r1 and x should at least be 4-dimension, "
45
+ f"but got r1 shape:{r1_shape}, x shape:{x_shape}")
43
46
  condition_1 = (r1_shape[0] == 1
44
47
  and r1_shape[1] == x_shape[1]
45
48
  and r1_shape[2] == 1
@@ -68,4 +71,5 @@ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
68
71
  for j in range(x_shape[2]):
69
72
  r2_grad[:, 0, 0, :] += (x_new2[:, i, j, :] * grad[:, i, j, :])
70
73
  r1_grad[:, 0, 0, :] += (h[:, i, j, :] * grad[:, i, j, :])
74
+
71
75
  return x.grad.cpu(), r1_grad.cpu(), r2_grad.cpu()
@@ -19,7 +19,11 @@ import torch
19
19
  def npu_swiglu(x, dim=-1):
20
20
  tensor_dtype = x.dtype
21
21
 
22
- in_tensors = torch.chunk(x, 2, dim=dim)
22
+ try:
23
+ in_tensors = torch.chunk(x, 2, dim=dim)
24
+ except Exception as e:
25
+ raise RuntimeError(f"Invalid chunk x into 2 tensors with shape {x.shape} and dimension {dim}") from e
26
+
23
27
  if tensor_dtype == torch.float32:
24
28
  tensor_scalar = torch.sigmoid(torch.mul(in_tensors[0], 1.0))
25
29
  output_data = torch.mul(torch.mul(tensor_scalar, in_tensors[0]), in_tensors[1])
@@ -34,7 +38,11 @@ def npu_swiglu(x, dim=-1):
34
38
 
35
39
  def npu_swiglu_backward(grad, x, dim=-1):
36
40
  tensor_dtype = grad.dtype
37
- in_tensors = torch.chunk(x, 2, dim=dim)
41
+ try:
42
+ in_tensors = torch.chunk(x, 2, dim=dim)
43
+ except Exception as e:
44
+ raise RuntimeError(f"Invalid chunk x into 2 tensors with shape {x.shape} and dimension {dim}") from e
45
+
38
46
  tensor_grad_out = grad
39
47
 
40
48
  if tensor_dtype == torch.float16:
@@ -13,20 +13,21 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import json
17
-
18
16
  from msprobe.core.common.exceptions import ParseJsonException
19
- from msprobe.core.common.file_utils import FileOpen
17
+ from msprobe.core.common.file_utils import load_json
18
+ from msprobe.core.common.log import logger
20
19
 
21
20
 
22
21
  def parse_json_info_forward_backward(json_path):
23
- with FileOpen(json_path, 'r') as f:
24
- dump_json = json.load(f)
22
+ dump_json = load_json(json_path)
25
23
 
26
24
  real_data_path = dump_json.get("dump_data_dir")
27
25
  dump_data = dump_json.get("data")
26
+ if dump_data is None:
27
+ raise ParseJsonException(ParseJsonException.InvalidDumpJson,
28
+ "something wrong with dump, no data found in dump.json")
28
29
  if not dump_data:
29
- raise ParseJsonException(ParseJsonException.InvalidDumpJson, "dump数据中没有data字段")
30
+ logger.warning("data field is empty, no overflow data found.")
30
31
 
31
32
  forward_data = {}
32
33
  backward_data = {}