mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
msprobe/pytorch/service.py

@@ -1,22 +1,42 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import functools
  import os
-
  from collections import namedtuple
+
  import torch
  from msprobe.core.common.const import Const
  from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
  from msprobe.core.common.file_utils import create_directory
+ from msprobe.core.common.utils import print_tools_ends_info
  from msprobe.core.data_dump.data_collector import build_data_collector
  from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
  from msprobe.core.data_dump.scope import BaseScope
+ from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
  from msprobe.pytorch.common.log import logger
  from msprobe.pytorch.common.utils import get_rank_if_initialized
+ from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
  from msprobe.pytorch.hook_module import remove_dropout
  from msprobe.pytorch.hook_module.api_registry import api_register
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
  from msprobe.pytorch.module_processer import ModuleProcesser
- from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
+
  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
+ if torch_version_above_or_equal_2:
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch

  HookFn = namedtuple('hookFn', ['pre_hook', 'forward_hook', 'backward_hook', 'forward_hook_torch_version_below_2'])

@@ -32,6 +52,7 @@ class Service:
  self.first_start = True
  self.current_rank = None
  self.dump_iter_dir = None
+ self.should_stop_service = False
  self.attl = None

  @staticmethod
@@ -39,14 +60,29 @@
  logger.info_on_rank_0("Data needed ends here.")
  api_register.api_originality()

+ @staticmethod
+ def is_registered_backward_hook(module):
+ if hasattr(module, '_backward_hooks') and \
+ len(module._backward_hooks) > 0 and \
+ module._is_full_backward_hook is False:
+ return True
+ return False
+
+ def check_register_full_backward_hook(self, module):
+ if self.is_registered_backward_hook(module):
+ module._backward_hooks.clear()
+ module._is_full_backward_hook = None
+ logger.warning("Found deprecated backward hooks. Removing them and switching to full backward hooks.")
+
  def build_hook(self, module_type, name):
  def pre_hook(api_or_module_name, module, args, kwargs):
+ if not self.should_execute_hook():
+ return args, kwargs
+
  if module_type == BaseScope.Module_Type_Module:
  api_or_module_name = module.mindstudio_reserved_name
  self.data_collector.update_api_or_module_name(api_or_module_name)

- if not self.switch:
- return args, kwargs
  if self.config.online_run_ut:
  return None, None
  if self.data_collector:
@@ -55,13 +91,13 @@
  return args, kwargs

  def forward_hook(api_or_module_name, module, args, kwargs, output):
+ if not self.should_execute_hook():
+ return None
+
  if module_type == BaseScope.Module_Type_Module:
  api_or_module_name = module.mindstudio_reserved_name
  self.data_collector.update_api_or_module_name(api_or_module_name)

- if not self.switch:
- return None
-
  if self.config.online_run_ut:
  if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
  return None
@@ -80,18 +116,14 @@
  return forward_hook(api_or_module_name, module, args, {}, output)

  def backward_hook(api_or_module_name, module, grad_input, grad_output):
+ if not self.should_execute_hook():
+ return
+
  if module_type == BaseScope.Module_Type_Module:
  api_or_module_name = module.mindstudio_reserved_name
  self.data_collector.update_api_or_module_name(api_or_module_name)

- if not self.switch:
- return
-
  if self.config.online_run_ut:
- if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
- return
- api_data = ApiData(name[:-1], grad_input, {}, grad_output, self.current_iter, self.current_rank)
- self.attl_send(api_data)
  return

  if self.data_collector:
@@ -105,26 +137,15 @@
  pre_forward_hook_fn = functools.partial(pre_hook, forward_name_template)
  forward_hook_fn = functools.partial(forward_hook, forward_name_template)
  backward_hook_fn = functools.partial(backward_hook, backward_name_template)
- forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2, forward_name_template)
+ forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2,
+ forward_name_template)
  return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)

- def step(self):
- self.current_iter += 1
- self.data_collector.update_iter(self.current_iter)
-
- ModuleProcesser.reset_module_stats()
- HOOKModule.reset_module_stats()
-
  def start(self, model, api_origin=False):
- self.model = model
- if self.config.step and self.current_iter > max(self.config.step):
- if self.config.online_run_ut:
- # send stop signal if online_run_ut
- self.attl_stop()
- self.stop()
- raise Exception("msprobe: exit after iteration {}".format(max(self.config.step)))
- if self.config.step and self.current_iter not in self.config.step:
+ if self.need_stop_service():
  return
+
+ self.model = model
  if self.first_start:
  try:
  self.current_rank = get_rank_if_initialized()
@@ -138,13 +159,17 @@
  self.first_start = False
  if api_origin:
  api_register.api_modularity()
+ if self.config.online_run_ut and torch_version_above_or_equal_2:
+ run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute)
  self.switch = True
  logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ")
- if self.config.level != "L2" and not self.config.online_run_ut:
+ if not self.config.online_run_ut:
  self.create_dirs()
  logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")

  def stop(self):
+ if self.should_stop_service:
+ return
  if self.config.level == "L2":
  return
  if self.config.step and self.current_iter not in self.config.step:
@@ -152,14 +177,60 @@
  if self.config.rank and self.current_rank not in self.config.rank:
  return
  self.switch = False
- if self.config.online_run_ut:
+ if self.config.online_run_ut and torch_version_above_or_equal_2:
+ run_ut_dispatch(self.attl, False, self.config.online_run_ut_recompute)
  return
  self.data_collector.write_json()

+ def step(self):
+ if self.should_stop_service:
+ return
+ self.current_iter += 1
+ self.data_collector.update_iter(self.current_iter)
+
+ ModuleProcesser.reset_module_stats()
+ HOOKModule.reset_module_stats()
+ self.data_collector.data_writer.reset_cache()
+
+ if self.config.level == Const.LEVEL_L2:
+ self.data_collector.data_processor.reset_status()
+
+ def need_stop_service(self):
+ if self.should_stop_service:
+ return True
+ end_service = self.config.step and self.current_iter > max(self.config.step) or \
+ self.data_collector and self.data_collector.data_processor.is_terminated
+ if end_service:
+ if self.config.online_run_ut:
+ # send stop signal if online_run_ut
+ self.attl_stop()
+ if self.config.level in [Const.LEVEL_L1, Const.LEVEL_L2, Const.LEVEL_MIX]:
+ api_register.api_originality()
+ self.switch = False
+ self.should_stop_service = True
+ print_tools_ends_info()
+ return True
+ if self.config.step and self.current_iter not in self.config.step:
+ return True
+ return False
+
+ def should_execute_hook(self):
+ if not self.switch:
+ return False
+ if self.data_collector and self.data_collector.data_processor.is_terminated:
+ return False
+ return True
+
  def create_dirs(self):
  create_directory(self.config.dump_path)
  self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
  cur_rank = self.current_rank if self.current_rank is not None else ''
+ if self.config.level == Const.LEVEL_L2:
+ create_directory(self.dump_iter_dir)
+ kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
+ self.config.kernel_config_path = kernel_config_path
+ return
+
  dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
  create_directory(dump_dir)
  if self.config.task in self.data_collector.tasks_need_tensor_data:
@@ -187,14 +258,16 @@
  prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \
  module.__class__.__name__ + Const.SEP

- pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 \
- = self.build_hook(BaseScope.Module_Type_Module, prefix)
+ pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.build_hook(
+ BaseScope.Module_Type_Module, prefix)
  if torch_version_above_or_equal_2:
  module.register_forward_hook(forward_hook, with_kwargs=True)
  else:
+ self.check_register_full_backward_hook(module)
  module.register_full_backward_hook(
  self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
  module.register_forward_hook(forward_hook_torch_version_below_2)
+ self.check_register_full_backward_hook(module)
  module.register_full_backward_hook(backward_hook)

  module.register_forward_pre_hook(
@@ -204,11 +277,13 @@
  if torch_version_above_or_equal_2:
  module.register_full_backward_pre_hook(
  self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
+ self.check_register_full_backward_hook(module)
  module.register_full_backward_hook(
  self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))

  if self.config.level in ["mix", "L1", "L2"]:
- api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
+ api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API),
+ self.config.online_run_ut)
  api_register.api_modularity()

  if Const.STATISTICS == self.config.task or Const.TENSOR == self.config.task:
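Illustrative note (not part of the diff): the reworked start/stop/step gating above is normally driven from a training loop through the PrecisionDebugger wrapper that msprobe.pytorch exports. The sketch below assumes that wrapper exposes start/stop/step as documented in msprobe/docs/05.data_dump_PyTorch.md; the config path, toy model, and loop are placeholders, not values taken from this package.

    # Minimal sketch, assuming the documented PrecisionDebugger wrapper; names and
    # signatures are illustrative, not confirmed by this diff.
    import torch
    from msprobe.pytorch import PrecisionDebugger

    debugger = PrecisionDebugger(config_path="./config.json")  # hypothetical config file
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for _ in range(3):
        debugger.start(model)                  # becomes a no-op once need_stop_service() returns True
        loss = model(torch.randn(8, 4)).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        debugger.stop()                        # skipped once should_stop_service has been set
        debugger.step()                        # advances current_iter and resets per-step caches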
@@ -0,0 +1,14 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,14 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
msprobe/visualization/builder/graph_builder.py

@@ -0,0 +1,165 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ from msprobe.visualization.graph.graph import Graph
+ from msprobe.visualization.graph.node_op import NodeOp
+ from msprobe.visualization.utils import save_json_file, GraphConst
+ from msprobe.visualization.builder.msprobe_adapter import get_input_output
+ from msprobe.core.common.file_utils import load_json
+
+
+ class GraphBuilder:
+ @staticmethod
+ def build(construct_path, data_path, stack_path, model_name='DefaultModel'):
+ """
+ The graph-construction method exposed by GraphBuilder
+ Args:
+ construct_path: path to construct.json
+ data_path: path to dump.json
+ stack_path: path to stack.json
+ model_name: model name, supplied by the caller
+ Returns: Graph, the data structure representing the graph
+ """
+ construct_dict = load_json(construct_path)
+ dump_dict = load_json(data_path)
+ stack_dict = load_json(stack_path)
+ data_dict = dump_dict.get(GraphConst.DATA_KEY, {})
+ graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict)
+ GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
+ GraphBuilder._collect_apis_between_modules(graph)
+ return graph
+
+ @staticmethod
+ def to_json(filename, config):
+ """
+ Interface for exporting a graph to a .vis file
+ """
+ result = {}
+ if config.graph_b:
+ result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict()
+ result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict()
+ else:
+ result = config.graph_n.to_dict()
+ if config.tool_tip:
+ result[GraphConst.JSON_TIP_KEY] = config.tool_tip
+ if config.node_colors:
+ result[GraphConst.COLORS] = config.node_colors
+ if config.micro_steps:
+ result[GraphConst.MICRO_STEPS] = config.micro_steps
+ if config.task:
+ result[GraphConst.JSON_TASK_KEY] = config.task
+ save_json_file(filename, result)
+
+ @staticmethod
+ def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id):
+ """
+ If a backward node's parent node is null, try to find the parent via the forward node with the same name
+ """
+ # Match names ending with .backward. followed by one or more digits
+ backward_pattern = r"(\.backward\.)(\d+)$"
+ forward_pattern = r"(\.forward\.)(\d+)$"
+ if re.search(backward_pattern, subnode_id) and not upnode_id:
+ forward_upnode_id = construct_dict.get(re.sub(backward_pattern, r".forward.\2", subnode_id))
+ if forward_upnode_id:
+ new_upnode_id = re.sub(forward_pattern, r".backward.\2", forward_upnode_id)
+ if new_upnode_id in construct_dict:
+ return new_upnode_id
+ return upnode_id
+
+ @staticmethod
+ def _init_nodes(graph, construct_dict, data_dict, stack_dict):
+ for subnode_id, upnode_id in construct_dict.items():
+ upnode_id = GraphBuilder._handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id)
+ if upnode_id:
+ upnode_op = NodeOp.get_node_op(upnode_id)
+ upnode = GraphBuilder._create_or_get_node(graph, [data_dict, stack_dict], upnode_op, upnode_id)
+ else:
+ upnode = graph.root
+ node_op = NodeOp.get_node_op(subnode_id)
+ GraphBuilder._create_or_get_node(graph, [data_dict, stack_dict], node_op, subnode_id, upnode)
+
+ @staticmethod
+ def _create_or_get_node(graph, data_stack_list, op, name, upnode=None):
+ if name in graph.node_map:
+ node = graph.get_node(name)
+ else:
+ graph.add_node(op, name, upnode)
+ node = graph.get_node(name)
+ node_data = data_stack_list[0].get(name, {})
+ node_stack_info = data_stack_list[1].get(name, [])
+ # Attach input and output data
+ input_data, output_data = get_input_output(node_data, node.id)
+ # Update the node data
+ node.set_input_output(input_data, output_data)
+ node.stack_info = node_stack_info
+ # Attach the node to its parent node
+ node.add_upnode(upnode)
+ return node
+
+ @staticmethod
+ def _collect_apis_between_modules(graph):
+ """
+ When the graph is first expanded, the top-level nodes contain many modules and apis; the large number of apis stretches the graph and makes it hard to read, so the apis between modules are collected into a single node
+ Args:
+ graph: the model structure
+
+ Returns: None
+ """
+ i = 0
+ output = []
+ node_list = graph.root.subnodes
+ while i < len(node_list):
+ current_node = node_list[i]
+
+ # The current node is an api; check whether more apis follow
+ if current_node.op == NodeOp.function_api:
+ temp_nodes = [current_node]
+ i += 1
+ while i < len(node_list) and node_list[i].op == NodeOp.function_api:
+ temp_nodes.append(node_list[i])
+ i += 1
+
+ # Check whether there are at least two api nodes
+ if len(temp_nodes) >= 2:
+ # Create a new node and put these api nodes into its subnodes attribute
+ node_id = graph.add_node(NodeOp.api_collection, GraphConst.APIS_BETWEEN_MODULES,
+ id_accumulation=True)
+ api_collection_node = graph.get_node(node_id)
+ api_collection_node.subnodes = temp_nodes
+ # Re-establish the parent-child relationships
+ for node in temp_nodes:
+ node.upnode = api_collection_node
+ api_collection_node.upnode = graph.root
+ output.append(api_collection_node)
+ else:
+ # Fewer than two consecutive api nodes: add them to the output list as-is
+ output.extend(temp_nodes)
+ else:
+ # The current node is a module: add it to the output list directly
+ output.append(current_node)
+ i += 1
+
+ graph.root.subnodes = output
+
+
+ class GraphExportConfig:
+ def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task=''):
+ self.graph_n = graph_n
+ self.graph_b = graph_b
+ self.tool_tip = tool_tip
+ self.node_colors = node_colors
+ self.micro_steps = micro_steps
+ self.task = task
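Sketch (illustrative): how the new GraphBuilder might be invoked directly. The directory layout follows the docstring above (construct.json, dump.json, and stack.json produced by a dump run); the paths and model name are placeholders, and in normal use this is driven through graph_service rather than called by hand.

    from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig

    dump_dir = "./step0/rank0"  # hypothetical dump output directory
    graph = GraphBuilder.build(
        construct_path=f"{dump_dir}/construct.json",
        data_path=f"{dump_dir}/dump.json",
        stack_path=f"{dump_dir}/stack.json",
        model_name="MyModel",
    )
    # Export a single (non-compare) graph to a .vis file for the visualization front end.
    GraphBuilder.to_json("my_model.vis", GraphExportConfig(graph_n=graph))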
msprobe/visualization/builder/msprobe_adapter.py

@@ -0,0 +1,205 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import re
+ import math
+ from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy
+ from msprobe.core.common.utils import set_dump_path, get_dump_mode
+ from msprobe.visualization.utils import GraphConst
+ from msprobe.core.common.const import Const
+
+ # Rules used to parse a node name into the corresponding NodeOp
+ op_patterns = [
+ # NodeOp.module
+ r'^(Module.|Cell.)',
+ # NodeOp.function_api
+ r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.)'
+ ]
+
+
+ def get_compare_mode(dump_path_param):
+ """
+ Get the compare mode; there are three modes: summary, MD5, and real data
+ Args:
+ dump_path_param: the parameters required by the acc_compare interface
+ Returns: 0 summary mode, 1 md5 mode, 2 true data mode
+ """
+ set_dump_path(dump_path_param)
+ dump_mode = get_dump_mode(dump_path_param)
+ compare_mode = GraphConst.DUMP_MODE_TO_GRAPHCOMPARE_MODE_MAPPING.get(dump_mode)
+ return compare_mode
+
+
+ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
+ """
+ Run the real-data comparison with multiple processes
+ Args:
+ dump_path_param: the parameters required by the acc_compare interface
+ csv_path: path of the generated file
+ framework: framework type, pytorch or mindspore
+ is_cross_frame: whether to run a cross-framework comparison; only mindspore vs. pytorch is supported, with pytorch as the benchmark
+ """
+ if framework == Const.PT_FRAMEWORK:
+ from msprobe.pytorch.compare.pt_compare import PTComparator
+ return PTComparator().do_multi_process(dump_path_param, csv_path)
+ else:
+ from msprobe.mindspore.compare.ms_compare import MSComparator
+ ms_comparator = MSComparator()
+ ms_comparator.cross_frame = is_cross_frame
+ return ms_comparator.do_multi_process(dump_path_param, csv_path)
+
+
+ def get_input_output(node_data, node_id):
+ """
+ Split the raw dumped data into two parts: output and input
+ Args:
+ node_data: dump data belonging to a single node
+ node_id: node name
+ """
+ input_data = {}
+ output_data = {}
+ op_parsed_list = read_op(node_data, node_id)
+ for item in op_parsed_list:
+ full_op_name = item.get('full_op_name', '')
+ if not full_op_name:
+ continue
+ if GraphConst.OUTPUT in full_op_name and GraphConst.INPUT not in full_op_name:
+ output_data[full_op_name] = item
+ else:
+ name = item.get('data_name')
+ # Prefer the name of the data saved to disk as the node parameter name
+ if isinstance(name, str) and name != '-1':
+ input_data[name.rsplit(Const.SEP, 1)[0]] = item
+ else:
+ input_data[full_op_name] = item
+ return input_data, output_data
+
+
+ def compare_data(data_dict_list1, data_dict_list2):
+ """
+ Check whether the results produced by get_input_output have the same structure; return True if they match
+ """
+ if len(data_dict_list1) != len(data_dict_list2):
+ return False
+ # Key fields used to decide whether two nodes are equal
+ tag_keys = ['type', 'shape']
+ for key1, key2 in zip(data_dict_list1, data_dict_list2):
+ dict1 = data_dict_list1[key1]
+ dict2 = data_dict_list2[key2]
+ for tag_key in tag_keys:
+ tag_value1 = dict1.get(tag_key, None)
+ tag_value2 = dict2.get(tag_key, None)
+ if tag_value1 != tag_value2:
+ return False
+ return True
+
+
+ def format_node_data(data_dict):
+ """
+ Format node data for output in batches
+ """
+ del_list = ['requires_grad', 'full_op_name']
+ for _, value in data_dict.items():
+ if not isinstance(value, dict):
+ continue
+ for item in del_list:
+ if item in value:
+ del value[item]
+ _format_data(value)
+ return data_dict
+
+
+ def compare_node(node_ids, data_dicts, stack_json_data, compare_mode):
+ """
+ Call get_accuracy in acc_compare.py to obtain the accuracy-comparison metrics
+ In real-data compare mode the metrics cannot be obtained here; the multiprocessing compare interface must be used instead
+ Returns: a list containing parameter information and comparison metrics (except in real-data compare mode)
+ """
+ merge_n = _parse_node(node_ids[0], data_dicts[0], stack_json_data, compare_mode)
+ merge_b = _parse_node(node_ids[1], data_dicts[1], stack_json_data, compare_mode)
+ result = []
+ dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
+ get_accuracy(result, merge_n, merge_b, dump_mode)
+ return result
+
+
+ def _parse_node(node_id, data_dict, stack_json_data, compare_mode):
+ """
+ Convert a node so that it can be passed to get_accuracy in acc_compare.py
+ """
+ dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
+ op_parsed_list = read_op(data_dict.get(node_id, {}), node_id)
+ if node_id in stack_json_data:
+ op_parsed_list.append(
+ {'full_op_name': node_id, 'full_info': stack_json_data[node_id]})
+ else:
+ op_parsed_list.append({'full_op_name': node_id, 'full_info': None})
+ result = merge_tensor(op_parsed_list, dump_mode)
+ if not result:
+ result['op_name'] = []
+ return result
+
+
+ def _format_decimal_string(s):
+ """
+ Use a regular expression to match strings containing digits, a decimal point, and an optional percent sign
+ """
+ pattern = re.compile(r'\d{1,20}\.\d{1,20}%?')
+ matches = pattern.findall(s)
+ for match in matches:
+ is_percent = match.endswith('%')
+ number_str = match.rstrip('%')
+ decimal_part = number_str.split('.')[1]
+ # Process the value if there are more than 6 decimal places
+ if len(decimal_part) > GraphConst.ROUND_TH:
+ number_float = float(number_str)
+ formatted_number = f"{number_float:.{GraphConst.ROUND_TH}f}"
+ # If the original value was a percentage, add the percent sign back
+ if is_percent:
+ formatted_number += '%'
+ # Replace the numeric part in the original string
+ s = s.replace(match, formatted_number)
+ return s
+
+
+ def _format_data(data_dict):
+ """
+ Format the data: keep 6 decimal places and handle some abnormal values
+ """
+ pattern = r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)$'
+ all_null = False
+ for key, value in data_dict.items():
+ if isinstance(value, str):
+ # Remove single quotes and replace None with null to avoid parsing errors in the front end
+ value = value.replace("'", "").replace(GraphConst.NONE, GraphConst.NULL)
+ value = _format_decimal_string(value)
+ elif value is None or value == ' ':
+ value = GraphConst.NULL
+ # Scientific notation such as 1.123123123123e-11 is formatted as 1.123123e-11
+ elif isinstance(value, float) and len(str(value)) < GraphConst.STR_MAX_LEN and re.match(pattern, str(value)):
+ value = "{:.6e}".format(value)
+ elif isinstance(value, float):
+ value = round(value, GraphConst.ROUND_TH)
+ # Inf values end up here and are converted to "Inf"; this also serves as a fallback for other unexpected types
+ if key != GraphConst.ERROR_KEY:
+ # Convert everything except error_key to str to avoid parsing errors in the front end
+ value = str(value)
+ # If max is null, the parameter value is null
+ if key == Const.MAX and value == GraphConst.NULL:
+ all_null = True
+ data_dict[key] = value
+ # If every value in the dict is null, keep only a single null
+ if all_null:
+ data_dict.clear()
+ data_dict[GraphConst.VALUE] = GraphConst.NULL
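Small illustration (not part of the package) of the op_patterns rules above: dumped node names are classified as module or api nodes by their prefix. The sample names below are made up for demonstration and follow the Module./Torch. naming shown elsewhere in this diff.

    import re

    module_pattern = r'^(Module.|Cell.)'
    api_pattern = r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.)'

    print(bool(re.match(module_pattern, "Module.layer1.Linear.forward.0")))  # True  -> NodeOp.module
    print(bool(re.match(api_pattern, "Torch.matmul.0")))                     # True  -> NodeOp.function_api
    print(bool(re.match(module_pattern, "Torch.matmul.0")))                  # False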
@@ -0,0 +1,14 @@
+ # Copyright (c) 2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.