mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,26 +1,52 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import hashlib
1
17
  import zlib
2
18
  from dataclasses import asdict
3
19
  from typing import List
4
20
 
5
21
  import numpy as np
6
22
  import torch
7
- from msprobe.core.common.file_utils import path_len_exceeds_limit, change_mode
23
+ from torch import distributed as dist
24
+
25
+ from msprobe.core.common.const import Const
26
+ from msprobe.core.common.file_utils import path_len_exceeds_limit
8
27
  from msprobe.core.common.log import logger
9
- from msprobe.core.common.const import Const, OverflowConst, FileCheckConst
28
+ from msprobe.core.common.utils import convert_tuple
10
29
  from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
11
30
  ModuleForwardInputsOutputs, TensorStatInfo
12
- from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
13
31
  from msprobe.pytorch.common.utils import save_pt, load_pt
32
+ from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
33
+ from msprobe.core.common.utils import recursion_depth_decorator
14
34
 
35
+ is_gpu = False
15
36
  try:
16
37
  import torch_npu
17
- is_gpu = False
18
38
  except ImportError:
19
39
  is_gpu = True
20
40
 
21
41
 
22
42
  class PytorchDataProcessor(BaseDataProcessor):
23
- pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor)
43
+ pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor, torch.memory_format, dist.ProcessGroup)
44
+ memory_format = {
45
+ torch.contiguous_format: "contiguous_format",
46
+ torch.channels_last: "channels_last",
47
+ torch.channels_last_3d: "channels_last_3d",
48
+ torch.preserve_format: "preserve_format"
49
+ }
24
50
 
25
51
  def __init__(self, config, data_writer):
26
52
  super().__init__(config, data_writer)
@@ -64,8 +90,8 @@ class PytorchDataProcessor(BaseDataProcessor):
64
90
  if data_clone.numel() == 0:
65
91
  return tensor_stat
66
92
  elif data_clone.dtype == torch.bool:
67
- tensor_stat.max = True in data_clone
68
- tensor_stat.min = False not in data_clone
93
+ tensor_stat.max = torch._C._VariableFunctionsClass.any(data_clone).item()
94
+ tensor_stat.min = torch._C._VariableFunctionsClass.all(data_clone).item()
69
95
  elif not data_clone.shape:
70
96
  tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.item()
71
97
  elif torch.is_complex(data_clone):
@@ -89,20 +115,46 @@ class PytorchDataProcessor(BaseDataProcessor):
89
115
  data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
90
116
  if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
91
117
  return float('nan')
118
+
92
119
  finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
93
120
  if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
94
- finite_values = data_clone[finite_mask]
121
+ finite_values = getattr(torch._C._TensorBase, "__getitem__")(data_clone, finite_mask)
95
122
  return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
96
123
  torch._C._VariableFunctionsClass.min(finite_values).item()
97
124
  else:
98
- data_no_nan = data_clone[~data_nan]
125
+ data_no_nan = getattr(torch._C._TensorBase, "__getitem__")(data_clone, ~data_nan)
99
126
  return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
100
127
  torch._C._VariableFunctionsClass.min(data_no_nan).item()
101
128
 
129
+ @staticmethod
130
+ def process_group_hash(arg):
131
+ group_ranks = dist.get_process_group_ranks(arg)
132
+ group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest()
133
+ return group_ranks_hash
134
+
102
135
  @staticmethod
103
136
  def _analyze_torch_size(arg):
104
137
  return {"type": "torch.Size", "value": list(arg)}
105
138
 
139
+ @staticmethod
140
+ def _analyze_memory_format(arg):
141
+ # 获取内存格式
142
+ format_type = PytorchDataProcessor.memory_format.get(arg)
143
+
144
+ return {"type": "torch.memory_format", "format": format_type}
145
+
146
+ @staticmethod
147
+ def _analyze_process_group(arg):
148
+ group_info = {"type": "torch.ProcessGroup"}
149
+ try:
150
+ group_ranks = dist.get_process_group_ranks(arg)
151
+ group_info.update({"group_ranks": group_ranks})
152
+ group_id = PytorchDataProcessor.process_group_hash(arg)
153
+ group_info.update({"group_id": group_id})
154
+ except Exception as e:
155
+ logger.warning(f"Failed to get process group(id: {group_id}) ranks info with error info: {e}.")
156
+ return group_info
157
+
106
158
  @classmethod
107
159
  def get_special_types(cls):
108
160
  return super().get_special_types() + cls.pytorch_special_type
@@ -112,6 +164,10 @@ class PytorchDataProcessor(BaseDataProcessor):
112
164
  return self.torch_object_key[suffix_stack[-1]](element)
113
165
  if isinstance(element, torch.Size):
114
166
  return self._analyze_torch_size(element)
167
+ if isinstance(element, torch.memory_format):
168
+ return self._analyze_memory_format(element)
169
+ if isinstance(element, dist.ProcessGroup):
170
+ return self._analyze_process_group(element)
115
171
  converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
116
172
  if converted_numpy is not element:
117
173
  return self._analyze_numpy(converted_numpy, numpy_type)
@@ -153,7 +209,7 @@ class StatisticsDataProcessor(PytorchDataProcessor):
153
209
  class TensorDataProcessor(PytorchDataProcessor):
154
210
  def _analyze_tensor(self, tensor, suffix):
155
211
  dump_data_name, file_path = self.get_save_file_path(suffix)
156
- saved_tensor = tensor.contiguous().detach()
212
+ saved_tensor = tensor.clone().contiguous().detach()
157
213
  save_pt(saved_tensor, file_path)
158
214
  single_arg = super()._analyze_tensor(tensor, suffix)
159
215
  single_arg.update({"data_name": dump_data_name})
@@ -178,7 +234,6 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
178
234
  if self.overflow_nums == -1:
179
235
  return False
180
236
  if self.real_overflow_nums >= self.overflow_nums:
181
- logger.info(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_nums}")
182
237
  return True
183
238
  return False
184
239
 
@@ -219,6 +274,9 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
219
274
  for file_path, tensor in self.cached_tensors_and_file_paths.items():
220
275
  save_pt(tensor, file_path)
221
276
  self.real_overflow_nums += 1
277
+ if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
278
+ logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, "
279
+ f"current overflow times: {self.real_overflow_nums}.")
222
280
  self.cached_tensors_and_file_paths = {}
223
281
 
224
282
  def _is_support_inf_nan(self):
@@ -243,7 +301,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
243
301
  if tensor_json['Max'] is None or tensor_json['Min'] is None:
244
302
  return
245
303
  self.has_overflow = np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']) or \
246
- np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min'])
304
+ np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min'])
247
305
 
248
306
  def _analyze_tensor(self, tensor, suffix):
249
307
  dump_data_name, file_path = self.get_save_file_path(suffix)
@@ -303,64 +361,120 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor):
303
361
 
304
362
 
305
363
  class KernelDumpDataProcessor(PytorchDataProcessor):
306
- forward_init_status = False
307
- multi_output_apis = ["_sort_", "npu_flash_attention"]
308
-
309
364
  def __init__(self, config, data_writer):
310
365
  super().__init__(config, data_writer)
366
+ self.enable_kernel_dump = True
367
+ self.is_found_output_tensor = False
368
+ self.is_found_grad_input_tensor = False
369
+ self.forward_args = None
370
+ self.forward_kwargs = None
371
+ self.forward_output_tensor = None
372
+ self.grad_input_tensor = None
373
+
374
+ @staticmethod
375
+ def start_kernel_dump(config_path):
376
+ torch_npu.npu.synchronize()
377
+ torch_npu.npu.init_dump()
378
+ torch_npu.npu.set_dump(config_path)
379
+ torch_npu.npu.synchronize()
380
+
381
+ @staticmethod
382
+ def stop_kernel_dump():
383
+ torch_npu.npu.synchronize()
384
+ torch_npu.npu.finalize_dump()
385
+ torch_npu.npu.synchronize()
386
+
387
+ @staticmethod
388
+ def _print_unsupported_log(api_name):
389
+ logger.warning(f"The kernel dump does not support the {api_name} API.")
390
+
391
+ def analyze_pre_forward(self, name, module, module_input_output):
392
+ if not self.enable_kernel_dump:
393
+ return
394
+ if is_gpu:
395
+ logger.warning("The current environment is not a complete NPU environment, and kernel dump cannot be used.")
396
+ self.enable_kernel_dump = False
397
+ return
398
+
399
+ if self.config.is_backward_kernel_dump:
400
+ self.forward_args = self.clone_and_detach_tensor(module_input_output.args)
401
+ self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs)
402
+ try:
403
+ output = module.forward(*self.forward_args, **self.forward_kwargs)
404
+ except Exception:
405
+ self._print_unsupported_log(name)
406
+ self.enable_kernel_dump = False
407
+ return
408
+
409
+ self.analyze_element(convert_tuple(output))
410
+ if not self.is_found_output_tensor:
411
+ self._print_unsupported_log(name)
412
+ self.enable_kernel_dump = False
413
+ return
414
+ self.start_kernel_dump(self.config.kernel_config_path)
311
415
 
312
416
  def analyze_forward(self, name, module, module_input_output):
313
- if self.config.is_forward_acl_dump:
314
- self.forward_acl_dump(name, module, module_input_output)
417
+ if not self.enable_kernel_dump:
418
+ return
419
+ if self.config.is_backward_kernel_dump:
420
+ return
421
+ self.enable_kernel_dump = False
422
+ self.stop_kernel_dump()
423
+ logger.info(f"The kernel data of {name} is dumped successfully.")
424
+
425
+ def analyze_backward(self, name, module, module_input_output):
426
+ if not self.enable_kernel_dump:
427
+ return
428
+ self.enable_kernel_dump = False
429
+
430
+ self.analyze_element(module_input_output.grad_input)
431
+ if not self.is_found_grad_input_tensor:
432
+ self._print_unsupported_log(name)
433
+ return
434
+ self.start_kernel_dump(self.config.kernel_config_path)
435
+
436
+ try:
437
+ self.forward_output_tensor.backward(self.grad_input_tensor, retain_graph=True)
438
+ except Exception:
439
+ self._print_unsupported_log(name)
440
+ self.stop_kernel_dump()
441
+ return
442
+
443
+ self.stop_kernel_dump()
444
+ logger.info(f"The kernel data of {name} is dumped successfully.")
445
+
446
+ @recursion_depth_decorator("KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor")
447
+ def clone_and_detach_tensor(self, input_params):
448
+ if isinstance(input_params, torch.Tensor):
449
+ if input_params.requires_grad:
450
+ return input_params.clone().detach().requires_grad_()
451
+ return input_params.clone()
452
+ elif isinstance(input_params, tuple):
453
+ return tuple(self.clone_and_detach_tensor(x) for x in input_params)
454
+ elif isinstance(input_params, list):
455
+ return list(self.clone_and_detach_tensor(x) for x in input_params)
456
+ elif isinstance(input_params, dict):
457
+ return {k: self.clone_and_detach_tensor(v) for k, v in input_params.items()}
315
458
  else:
316
- self.dump_mode_backward_acl_dump(name, module, module_input_output)
317
-
318
- def forward_acl_dump(self, name, module, module_input_output):
319
- if not KernelDumpDataProcessor.forward_init_status:
320
- KernelDumpDataProcessor.forward_init_status = True
321
- torch_npu.npu.synchronize()
322
- torch_npu.npu.init_dump()
323
- torch_npu.npu.set_dump(self.config.acl_config)
324
- torch_npu.npu.synchronize()
325
- if self.op_need_trigger(name):
326
- module.forward(*module_input_output.args, **module_input_output.kwargs).cpu()
327
- else:
328
- module.forward(*module_input_output.args, **module_input_output.kwargs)
329
- torch_npu.npu.synchronize()
330
- torch_npu.npu.finalize_dump()
331
- torch_npu.npu.synchronize()
332
- KernelDumpDataProcessor.forward_init_status = False
333
- logger.info("Dump %s op file." % name)
334
-
335
- def acl_backward_dump_status(self, output, grad, module_name):
336
- if isinstance(output, torch.Tensor):
337
- output.backward(grad, retain_graph=True)
338
- return True
459
+ return input_params
339
460
 
340
- for api_name in KernelDumpDataProcessor.multi_output_apis:
341
- if api_name in module_name:
342
- output[0].backward(grad, retain_graph=True)
343
- return True
344
- return False
461
+ def analyze_single_element(self, element, suffix_stack):
462
+ if isinstance(element, torch.Tensor):
463
+ if not self.is_found_output_tensor:
464
+ if element.requires_grad:
465
+ self.forward_output_tensor = element
466
+ self.is_found_output_tensor = True
467
+ return {}
468
+ if not self.is_found_grad_input_tensor:
469
+ self.grad_input_tensor = element.clone()
470
+ self.is_found_grad_input_tensor = True
471
+ return {}
345
472
 
346
- def dump_mode_backward_acl_dump(self, name, module, module_input_output):
347
- grad_path = self.config.backward_input.get(name)
348
- if not KernelDumpDataProcessor.forward_init_status:
349
- KernelDumpDataProcessor.forward_init_status = True
350
- output = module.forward(*module_input_output.args, **module_input_output.kwargs)
351
- pt = load_pt(grad_path)
352
- grad = pt.to("npu").requires_grad_()
353
- torch_npu.npu.init_dump()
354
- torch_npu.npu.set_dump(self.config.acl_config)
355
- torch_npu.npu.synchronize()
356
- if not self.acl_backward_dump_status(output, grad, name):
357
- logger.warning("The output of {} is not of tensor type and cannot be automatically derived. "
358
- "you can manually construct a single API backward case for ACL dump.".format(
359
- name))
360
- torch_npu.npu.synchronize()
361
- torch_npu.npu.finalize_dump()
362
- KernelDumpDataProcessor.forward_init_status = False
363
- logger.info("Dump %s op file." % name)
364
-
365
- def op_need_trigger(self, module_name):
366
- return 'Tensor.__getitem__.' in module_name
473
+ def reset_status(self):
474
+ self.enable_kernel_dump = True
475
+ self.is_found_output_tensor = False
476
+ self.is_found_grad_input_tensor = False
477
+ self.forward_args = None
478
+ self.forward_kwargs = None
479
+ self.forward_output_tensor = None
480
+ self.grad_input_tensor = None
@@ -1,24 +1,36 @@
1
- import os
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  import csv
17
+ import os
3
18
 
4
- from msprobe.core.common.file_utils import change_mode, FileOpen
5
- from msprobe.core.common.log import logger
6
19
  from msprobe.core.common.const import Const, FileCheckConst
7
- from msprobe.core.common.file_utils import remove_path, load_json, save_json
20
+ from msprobe.core.common.file_utils import change_mode, FileOpen, save_json
21
+ from msprobe.core.common.log import logger
8
22
 
9
23
 
10
24
  class DataWriter:
11
25
 
12
- def __init__(self, init_json=None) -> None:
13
- self.dump_count = 0
14
- self.init_json = init_json
15
- self.dump_file_path = None # os.path.join(dump_dir, DataWriter.dump_json_name)
16
- self.stack_file_path = None # os.path.join(dump_dir, DataWriter.stack_json_name)
17
- self.construct_file_path = None # os.path.join(dump_dir, DataWriter.construct_json_name)
26
+ def __init__(self) -> None:
27
+ self.dump_file_path = None
28
+ self.stack_file_path = None
29
+ self.construct_file_path = None
18
30
  self.free_benchmark_file_path = None
19
31
  self.dump_tensor_data_dir = None
20
- self.buffer_size = 1000
21
- self.cache_data = {Const.DATA: {}}
32
+ self.flush_size = 1000
33
+ self.cache_data = {}
22
34
  self.cache_stack = {}
23
35
  self.cache_construct = {}
24
36
 
@@ -37,18 +49,22 @@ class DataWriter:
37
49
  if is_new_file:
38
50
  change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
39
51
 
40
- def initialize_json_file(self, **kwargs):
41
- kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
42
- save_json(self.dump_file_path, kwargs)
43
-
44
- empty_dict = {}
45
- remove_path(self.stack_file_path)
46
- save_json(self.stack_file_path, empty_dict)
47
-
48
- remove_path(self.construct_file_path)
49
- save_json(self.construct_file_path, empty_dict)
52
+ def reset_cache(self):
53
+ self.cache_data = {}
54
+ self.cache_stack = {}
55
+ self.cache_construct = {}
50
56
 
51
- def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir,
57
+ def initialize_json_file(self, **kwargs):
58
+ if not self.cache_data:
59
+ kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
60
+ self.cache_data = kwargs
61
+ save_json(self.dump_file_path, self.cache_data, indent=1)
62
+ if not self.cache_stack:
63
+ save_json(self.stack_file_path, self.cache_stack, indent=1)
64
+ if not self.cache_construct:
65
+ save_json(self.construct_file_path, self.cache_construct, indent=1)
66
+
67
+ def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir,
52
68
  free_benchmark_file_path):
53
69
  self.dump_file_path = dump_file_path
54
70
  self.stack_file_path = stack_file_path
@@ -56,16 +72,25 @@ class DataWriter:
56
72
  self.dump_tensor_data_dir = dump_data_dir
57
73
  self.free_benchmark_file_path = free_benchmark_file_path
58
74
 
75
+ def flush_data_periodically(self):
76
+ dump_data = self.cache_data.get(Const.DATA)
77
+ if dump_data and isinstance(dump_data, dict) and len(dump_data) % self.flush_size == 0:
78
+ self.write_json()
79
+
59
80
  def update_data(self, new_data):
60
- key = next(iter(new_data.keys())) # assert len(new_data.keys()) == 1
61
- if key in self.cache_data[Const.DATA]:
62
- self.cache_data[Const.DATA][key].update(new_data[key])
63
- else:
64
- self.cache_data[Const.DATA].update(new_data)
81
+ if not isinstance(new_data, dict) or len(new_data.keys()) != 1:
82
+ logger.warning(f"The data info({new_data}) should be a dict with only one outer key.")
83
+ return
84
+ dump_data = self.cache_data.get(Const.DATA)
85
+ if not isinstance(dump_data, dict):
86
+ logger.warning(f"The dump data({dump_data}) should be a dict.")
87
+ return
65
88
 
66
- def flush_data_when_buffer_is_full(self):
67
- if len(self.cache_data[Const.DATA]) >= self.buffer_size:
68
- self.write_data_json(self.dump_file_path)
89
+ key = next(iter(new_data.keys()))
90
+ if key in dump_data:
91
+ dump_data.get(key).update(new_data.get(key))
92
+ else:
93
+ dump_data.update(new_data)
69
94
 
70
95
  def update_stack(self, new_data):
71
96
  self.cache_stack.update(new_data)
@@ -75,14 +100,7 @@ class DataWriter:
75
100
 
76
101
  def write_data_json(self, file_path):
77
102
  logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
78
- if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
79
- data_to_write = load_json(file_path)
80
- else:
81
- self.init_json['data_path'] = self.dump_tensor_data_dir
82
- data_to_write = self.init_json
83
- data_to_write[Const.DATA].update(self.cache_data[Const.DATA])
84
- save_json(file_path, data_to_write, indent=1)
85
- self.cache_data[Const.DATA].clear()
103
+ save_json(file_path, self.cache_data, indent=1)
86
104
 
87
105
  def write_stack_info_json(self, file_path):
88
106
  save_json(file_path, self.cache_stack, indent=1)
@@ -91,6 +109,9 @@ class DataWriter:
91
109
  save_json(file_path, self.cache_construct, indent=1)
92
110
 
93
111
  def write_json(self):
94
- self.write_data_json(self.dump_file_path)
95
- self.write_stack_info_json(self.stack_file_path)
96
- self.write_construct_info_json(self.construct_file_path)
112
+ if self.cache_data:
113
+ self.write_data_json(self.dump_file_path)
114
+ if self.cache_stack:
115
+ self.write_stack_info_json(self.stack_file_path)
116
+ if self.cache_construct:
117
+ self.write_construct_info_json(self.construct_file_path)