mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,9 +1,25 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import atexit
1
17
  import os
2
18
 
3
- from msprobe.core.data_dump.scope import build_scope, ListScope
19
+ from msprobe.core.data_dump.scope import ScopeFactory
4
20
  from msprobe.core.data_dump.json_writer import DataWriter
5
21
  from msprobe.core.common.log import logger
6
- from msprobe.core.common.const import Const, MsgConst
22
+ from msprobe.core.common.const import Const
7
23
  from msprobe.core.data_dump.data_processor.factory import DataProcessorFactory
8
24
 
9
25
 
@@ -12,24 +28,17 @@ def build_data_collector(config):
12
28
 
13
29
 
14
30
  class DataCollector:
15
- multi_output_apis = ["_sort_", "npu_flash_attention"]
16
31
  tasks_need_tensor_data = [Const.OVERFLOW_CHECK, Const.TENSOR, Const.FREE_BENCHMARK]
17
- level_without_construct = ["L1", "L2"]
32
+ level_without_construct = [Const.LEVEL_L1, Const.LEVEL_L2]
18
33
 
19
34
  def __init__(self, config):
20
35
  self.config = config
21
36
  self.data_writer = DataWriter()
22
37
  self.data_processor = DataProcessorFactory.create_processor(self.config, self.data_writer)
23
- self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework) \
24
- if self.config.framework == Const.PT_FRAMEWORK else None
38
+ self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework)
25
39
  self.module_count = {}
26
- if self.config.task == Const.FREE_BENCHMARK:
27
- self.scope = build_scope(ListScope, self.config.scope, self.config.list)
28
- else:
29
- self.scope = build_scope(None, self.config.scope, self.config.list)
30
-
31
- def __del__(self):
32
- self.write_json()
40
+ self.scope = ScopeFactory(self.config).build_scope()
41
+ atexit.register(self.write_json)
33
42
 
34
43
  @property
35
44
  def dump_data_dir(self):
@@ -59,18 +68,22 @@ class DataCollector:
59
68
  def write_json(self):
60
69
  self.data_writer.write_json()
61
70
 
62
- def update_data(self, data_info, msg=''):
71
+ def update_data(self, name, data_info):
72
+ msg = f"msprobe is collecting data on {name}."
63
73
  if self.config.task == Const.OVERFLOW_CHECK:
64
74
  if self.data_processor.has_overflow:
75
+ msg += " Overflow detected."
76
+ logger.warning(msg)
65
77
  self.data_writer.update_data(data_info)
66
- msg += "Overflow detected."
67
- else:
68
- msg += "No Overflow, OK."
69
- else:
70
- self.data_writer.update_data(data_info)
71
- return msg
78
+ return
79
+ logger.debug(msg)
80
+ self.data_writer.update_data(data_info)
72
81
 
73
82
  def pre_forward_data_collect(self, name, module, pid, module_input_output):
83
+ if self.config.level == Const.LEVEL_L2 and self.check_scope_and_pid(self.scope, name, pid):
84
+ self.data_processor.analyze_pre_forward(name, module, module_input_output)
85
+ return
86
+
74
87
  backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
75
88
  if self.check_scope_and_pid(self.scope, backward_name, pid):
76
89
  self.data_processor.analyze_pre_forward(backward_name, module, module_input_output)
@@ -78,27 +91,22 @@ class DataCollector:
78
91
  return
79
92
  logger.info(f"API {name} is inplace.")
80
93
  data_info = self.data_processor.analyze_pre_forward_inplace(name, module_input_output)
81
- self.handle_data(name, data_info)
94
+ self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
82
95
 
83
96
  def forward_data_collect(self, name, module, pid, module_input_output):
84
97
  self.update_construct(name)
85
98
  if not self.check_scope_and_pid(self.scope, name, pid):
86
99
  return
100
+ if self.config.level == Const.LEVEL_L2:
101
+ self.data_processor.analyze_forward(name, module, module_input_output)
102
+ return
87
103
 
88
104
  if not self.is_inplace(module):
89
105
  data_info = self.data_processor.analyze_forward(name, module, module_input_output)
90
106
  else:
91
107
  data_info = self.data_processor.analyze_forward_inplace(name, module_input_output)
92
- if self.config.level == "L2":
93
- return
94
108
  self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
95
- if self.config.framework == Const.MS_FRAMEWORK:
96
- self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
97
- else:
98
- if self.data_processor.is_terminated:
99
- self.handle_data(name, data_info, flush=True)
100
- raise Exception(f"[{Const.TOOL_NAME}] exit")
101
- self.handle_data(name, data_info)
109
+ self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
102
110
 
103
111
  def backward_data_collect(self, name, module, pid, module_input_output):
104
112
  self.update_construct(name)
@@ -106,13 +114,9 @@ class DataCollector:
106
114
  return
107
115
 
108
116
  data_info = self.data_processor.analyze_backward(name, module, module_input_output)
109
- if self.config.framework == Const.MS_FRAMEWORK:
110
- self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
111
- else:
112
- if self.data_processor.is_terminated:
113
- self.handle_data(name, data_info, flush=True)
114
- raise Exception(f"[{Const.TOOL_NAME}] exit")
115
- self.handle_data(name, data_info)
117
+ if self.config.level == Const.LEVEL_L2:
118
+ return
119
+ self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
116
120
 
117
121
  def backward_input_data_collect(self, name, module, pid, module_input_output):
118
122
  self.update_construct(name)
@@ -131,18 +135,15 @@ class DataCollector:
131
135
  self.handle_data(name, data_info)
132
136
 
133
137
  def update_construct(self, name):
134
- if self.config.framework == Const.PT_FRAMEWORK and \
135
- self.config.level not in DataCollector.level_without_construct:
138
+ if self.config.level not in DataCollector.level_without_construct:
136
139
  self.data_writer.update_construct({name: self.module_processor.api_parent_node})
137
140
  self.data_writer.update_construct(self.module_processor.module_node)
138
141
 
139
142
  def handle_data(self, name, data_info, flush=False):
140
143
  if data_info:
141
- msg = f"msprobe is collecting data on {name}. "
142
- msg = self.update_data(data_info, msg)
143
- logger.info(MsgConst.CLEAR_SYMBOL + msg, end='\r')
144
+ self.update_data(name, data_info)
144
145
  if not flush:
145
- self.data_writer.flush_data_when_buffer_is_full()
146
+ self.data_writer.flush_data_periodically()
146
147
  else:
147
148
  self.write_json()
148
149
 
@@ -1,11 +1,28 @@
1
- import os
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  import inspect
3
- from dataclasses import dataclass
17
+ import os
18
+ from dataclasses import dataclass, is_dataclass
4
19
  from typing import Tuple, Dict, Optional, Any
20
+
5
21
  import numpy as np
6
- from msprobe.core.common.log import logger
7
- from msprobe.core.common.utils import convert_tuple
22
+
8
23
  from msprobe.core.common.const import Const
24
+ from msprobe.core.common.log import logger
25
+ from msprobe.core.common.utils import convert_tuple, CompareException
9
26
 
10
27
 
11
28
  @dataclass
@@ -69,8 +86,11 @@ class TensorStatInfo:
69
86
 
70
87
  class BaseDataProcessor:
71
88
  _recursive_key_stack = []
72
- special_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
73
- bool, int, float, str, slice, type(Ellipsis))
89
+ special_type = (
90
+ np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
91
+ bool, int, float, str, slice,
92
+ type(Ellipsis)
93
+ )
74
94
 
75
95
  def __init__(self, config, data_writer):
76
96
  self.data_writer = data_writer
@@ -82,30 +102,33 @@ class BaseDataProcessor:
82
102
  self.current_iter = 0
83
103
  self._return_forward_new_output = False
84
104
  self._forward_new_output = None
105
+ if hasattr(config, "data_mode"):
106
+ self.allowed_data_mode = self._get_allowed_data_mode(config.data_mode)
85
107
 
86
108
  @property
87
109
  def data_path(self):
88
110
  return self.data_writer.dump_tensor_data_dir
89
-
111
+
90
112
  @property
91
113
  def is_terminated(self):
92
114
  return False
93
115
 
94
116
  @staticmethod
95
117
  def analyze_api_call_stack(name):
118
+ try:
119
+ api_stack = inspect.stack()[5:]
120
+ except Exception as e:
121
+ logger.warning(f"The call stack of <{name}> failed to retrieve, {e}.")
122
+ api_stack = None
96
123
  stack_str = []
97
- for (_, path, line, func, code, _) in inspect.stack()[5:]:
98
- if not code:
99
- continue
100
- stack_line = " ".join([
101
- "File", ", ".join([
102
- path,
103
- " ".join(["line", str(line)]),
104
- " ".join(["in", func]),
105
- " ".join(["\n", code[0].strip()])
106
- ])
107
- ])
108
- stack_str.append(stack_line)
124
+ if api_stack:
125
+ for (_, path, line, func, code, _) in api_stack:
126
+ if not code:
127
+ continue
128
+ stack_line = f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()}"
129
+ stack_str.append(stack_line)
130
+ else:
131
+ stack_str.append(Const.WITHOUT_CALL_STACK)
109
132
  stack_info_struct = {name: stack_str}
110
133
  return stack_info_struct
111
134
 
@@ -162,34 +185,66 @@ class BaseDataProcessor:
162
185
  def _analyze_numpy(value, numpy_type):
163
186
  return {"type": numpy_type, "value": value}
164
187
 
188
+ @staticmethod
189
+ def _get_allowed_data_mode(data_mode):
190
+ if Const.ALL in data_mode:
191
+ allowed_data_mode = [Const.FORWARD, Const.BACKWARD, Const.INPUT, Const.OUTPUT]
192
+ else:
193
+ allowed_data_mode = list(set(data_mode))
194
+ if Const.FORWARD not in allowed_data_mode and Const.BACKWARD not in allowed_data_mode:
195
+ allowed_data_mode += [Const.FORWARD, Const.BACKWARD]
196
+ if Const.INPUT not in allowed_data_mode and Const.OUTPUT not in allowed_data_mode:
197
+ allowed_data_mode += [Const.INPUT, Const.OUTPUT]
198
+ return allowed_data_mode
199
+
165
200
  @classmethod
166
201
  def get_special_types(cls):
167
202
  return cls.special_type
168
203
 
169
204
  @classmethod
170
- def recursive_apply_transform(cls, args, transform):
205
+ def recursive_apply_transform(cls, args, transform, depth=0):
206
+ if depth > Const.MAX_DEPTH:
207
+ logger.error(f"The maximum depth of recursive transform, {Const.MAX_DEPTH} is reached.")
208
+ raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
171
209
  if isinstance(args, cls.get_special_types()):
172
210
  arg_transform = transform(args, cls._recursive_key_stack)
173
211
  return arg_transform
212
+ elif isinstance(args, tuple) and hasattr(args, '_fields'):
213
+ # namedtuple to dict
214
+ args_dict = {field: getattr(args, field) for field in args._fields}
215
+ return cls.apply_transform_dict(args_dict, transform, depth)
216
+ elif is_dataclass(args):
217
+ # dataclass to dict
218
+ args_dict = {field: getattr(args, field) for field in args.__dataclass_fields__}
219
+ return cls.apply_transform_dict(args_dict, transform, depth)
174
220
  elif isinstance(args, (list, tuple)):
175
- result_list = []
176
- for i, arg in enumerate(args):
177
- cls._recursive_key_stack.append(str(i))
178
- result_list.append(cls.recursive_apply_transform(arg, transform))
179
- cls._recursive_key_stack.pop()
221
+ result_list = cls.apply_transform_list(args, transform, depth)
180
222
  return type(args)(result_list)
181
223
  elif isinstance(args, dict):
182
- result_dict = {}
183
- for k, arg in args.items():
184
- cls._recursive_key_stack.append(str(k))
185
- result_dict[k] = cls.recursive_apply_transform(arg, transform)
186
- cls._recursive_key_stack.pop()
187
- return result_dict
224
+ return cls.apply_transform_dict(args, transform, depth)
188
225
  elif args is not None:
189
226
  logger.warning(f"Data type {type(args)} is not supported.")
190
227
  return None
191
228
  else:
192
229
  return None
230
+
231
+ @classmethod
232
+ def apply_transform_dict(cls, args, transform, depth):
233
+ result_dict = {}
234
+ for k, arg in args.items():
235
+ cls._recursive_key_stack.append(str(k))
236
+ result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
237
+ cls._recursive_key_stack.pop()
238
+ return result_dict
239
+
240
+ @classmethod
241
+ def apply_transform_list(cls, args, transform, depth):
242
+ result_list = []
243
+ for i, arg in enumerate(args):
244
+ cls._recursive_key_stack.append(str(i))
245
+ result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
246
+ cls._recursive_key_stack.pop()
247
+ return result_list
193
248
 
194
249
  def if_return_forward_new_output(self):
195
250
  return self._return_forward_new_output
@@ -216,13 +271,11 @@ class BaseDataProcessor:
216
271
  Return:
217
272
  bool: True if the parameters are in data_mode or data_mode is all, False otherwise.
218
273
  """
219
- return (Const.ALL in self.config.data_mode or
220
- forward_backward in self.config.data_mode or
221
- input_output in self.config.data_mode)
274
+ return forward_backward in self.allowed_data_mode and input_output in self.allowed_data_mode
222
275
 
223
276
  def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
224
277
  pass
225
-
278
+
226
279
  def analyze_element(self, element):
227
280
  return self.recursive_apply_transform(element, self.analyze_single_element)
228
281
 
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.core.common.const import Const
2
17
 
3
18
 
@@ -34,14 +49,14 @@ class DataProcessorFactory:
34
49
  @classmethod
35
50
  def register_processors(cls, framework):
36
51
  if framework == Const.PT_FRAMEWORK:
37
- from .pytorch_processor import (
52
+ from msprobe.core.data_dump.data_processor.pytorch_processor import (
38
53
  StatisticsDataProcessor as PytorchStatisticsDataProcessor,
39
54
  TensorDataProcessor as PytorchTensorDataProcessor,
40
55
  OverflowCheckDataProcessor as PytorchOverflowCheckDataProcessor,
41
56
  FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor,
42
57
  KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
43
58
  )
44
- from ....pytorch.module_processer import ModuleProcesser
59
+ from msprobe.pytorch.module_processer import ModuleProcesser
45
60
  cls.register_processor(Const.PT_FRAMEWORK, Const.STATISTICS, PytorchStatisticsDataProcessor)
46
61
  cls.register_processor(Const.PT_FRAMEWORK, Const.TENSOR, PytorchTensorDataProcessor)
47
62
  cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
@@ -49,11 +64,13 @@ class DataProcessorFactory:
49
64
  cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor)
50
65
  cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
51
66
  elif framework == Const.MS_FRAMEWORK:
52
- from .mindspore_processor import (
67
+ from msprobe.core.data_dump.data_processor.mindspore_processor import (
53
68
  StatisticsDataProcessor as MindsporeStatisticsDataProcessor,
54
69
  TensorDataProcessor as MindsporeTensorDataProcessor,
55
70
  OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor
56
71
  )
72
+ from msprobe.mindspore.cell_processor import CellProcessor
57
73
  cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor)
58
74
  cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
59
75
  cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
76
+ cls.register_module_processor(Const.MS_FRAMEWORK, CellProcessor)
@@ -17,6 +17,7 @@ import zlib
17
17
 
18
18
  import mindspore as ms
19
19
  from mindspore import mint, ops
20
+ from mindspore._c_expression.typing import Number
20
21
  import numpy as np
21
22
 
22
23
  from msprobe.core.common.const import Const
@@ -29,7 +30,7 @@ from msprobe.mindspore.dump.hook_cell.api_registry import api_register
29
30
 
30
31
 
31
32
  class MindsporeDataProcessor(BaseDataProcessor):
32
- mindspore_special_type = tuple([ms.Tensor])
33
+ mindspore_special_type = tuple([ms.Tensor, Number])
33
34
 
34
35
  def __init__(self, config, data_writer):
35
36
  super().__init__(config, data_writer)
@@ -40,7 +41,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
40
41
  @staticmethod
41
42
  def get_md5_for_tensor(x):
42
43
  x = convert_bf16_to_fp32(x)
43
- tensor_bytes = x.asnumpy().tobytes()
44
+ tensor_bytes = x.contiguous().asnumpy().tobytes()
44
45
  crc32_hash = zlib.crc32(tensor_bytes)
45
46
  return f"{crc32_hash:08x}"
46
47
 
@@ -57,25 +58,28 @@ class MindsporeDataProcessor(BaseDataProcessor):
57
58
  if data.numel() == 0:
58
59
  return tensor_stat
59
60
  elif data.dtype == ms.bool_:
60
- data_np = data.asnumpy()
61
+ data_np = data.contiguous().asnumpy()
61
62
  tensor_stat.max = np.max(data_np).item()
62
63
  tensor_stat.min = np.min(data_np).item()
63
64
  elif not data.shape:
64
65
  tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
65
66
  elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
66
- data_abs = np.abs(data.asnumpy())
67
+ data_abs = np.abs(data.contiguous().asnumpy())
67
68
  tensor_stat.max = np.max(data_abs).item()
68
69
  tensor_stat.min = np.min(data_abs).item()
69
70
  tensor_stat.mean = np.mean(data_abs).item()
70
71
  tensor_stat.norm = np.linalg.norm(data_abs).item()
71
72
  else:
72
- if data.dtype == ms.bfloat16 or not ops.is_floating_point(data):
73
+ if not ops.is_floating_point(data) or data.dtype == ms.float64:
73
74
  data = data.to(ms.float32)
74
75
  api_register.norm_inner_op_set_ori_func()
75
76
  get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
76
77
  get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min)
77
78
  get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean)
78
- get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
79
+ if hasattr(mint, "norm"):
80
+ get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
81
+ else:
82
+ get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
79
83
  tensor_stat.max = get_max_value(data).item()
80
84
  tensor_stat.min = get_min_value(data).item()
81
85
  tensor_stat.mean = get_mean_value(data).item()
@@ -90,9 +94,10 @@ class MindsporeDataProcessor(BaseDataProcessor):
90
94
  converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
91
95
  if converted_numpy is not element:
92
96
  return self._analyze_numpy(converted_numpy, numpy_type)
97
+ if isinstance(element, Number):
98
+ return self.analyze_dtype_in_kwargs(element)
93
99
  if isinstance(element, ms.Tensor):
94
100
  return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
95
-
96
101
  if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
97
102
  return self._analyze_builtin(element)
98
103
  return {}
@@ -163,7 +168,8 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
163
168
  save_tensor_as_npy(tensor, file_path)
164
169
  self.real_overflow_nums += 1
165
170
  if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
166
- logger.info(f"[{Const.TOOL_NAME}] 超过预设溢出次数 当前溢出次数: {self.real_overflow_nums}")
171
+ logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, "
172
+ f"current overflow times: {self.real_overflow_nums}.")
167
173
  self.cached_tensors_and_file_paths = {}
168
174
 
169
175
  def _analyze_maybe_overflow_tensor(self, tensor_json):