mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -0,0 +1,130 @@
1
+ # Copyright (c) 2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data
17
+ from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df
18
+ from msprobe.visualization.graph.graph import Graph, NodeOp
19
+ from msprobe.visualization.graph.node_colors import NodeColors
20
+ from msprobe.visualization.compare.mode_adapter import ModeAdapter
21
+ from msprobe.core.common.const import Const
22
+
23
+
24
class GraphComparator:
    """Compare an NPU graph with a bench graph and write the results onto the NPU graph's nodes."""

    def __init__(self, graphs, dump_path_param, output_path, framework=Const.PT_FRAMEWORK, mapping_dict=None):
        """
        Args:
            graphs: sequence whose item 0 is the NPU graph and item 1 is the bench graph.
            dump_path_param: dict providing 'npu_json_path', 'bench_json_path' and 'stack_json_path'
                (it is also passed to get_compare_mode / run_real_data).
            output_path: output path, stored on the instance.
            framework: framework identifier forwarded to run_real_data.
            mapping_dict: optional name-mapping dict; when truthy, nodes are matched with
                Graph.mapping_match instead of Graph.match.
        """
        self.graph_n = graphs[0]
        self.graph_b = graphs[1]
        self._parse_param(dump_path_param, output_path)
        self.framework = framework
        self.mapping_dict = mapping_dict

    def compare(self):
        """
        Comparison entry point, called separately after initialization. Results are written into graph_n.
        """
        self._compare_nodes(self.graph_n.root)
        self._postcompare()

    def add_compare_result_to_node(self, node, compare_result_list):
        """
        Attach compare results to the node's input/output data.
        Args:
            node: the node to annotate
            compare_result_list: list of rows holding parameter info and compare metrics
                (metrics are absent in real-data compare mode)
        """
        # Real-data mode: only stash the node here; its metrics are filled in by _postcompare
        # once the multiprocess comparison has produced them.
        if self.ma.prepare_real_data(node):
            return
        compare_in_dict = {}
        compare_out_dict = {}
        # Split rows into input and output data, keyed by the row's first column (the data name).
        for item in compare_result_list:
            if not isinstance(item, (list, tuple)) or not item:
                continue
            if '.output.' in item[0]:
                compare_out_dict[item[0]] = item
            else:
                compare_in_dict[item[0]] = item
        precision_index, other_dict = (
            self.ma.parse_result(node, [compare_in_dict, compare_out_dict]))
        node.data[GraphConst.JSON_INDEX_KEY] = precision_index
        node.data.update(other_dict)
        if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index):
            node.get_suggestions()

    def _parse_param(self, dump_path_param, output_path):
        """Store the paths, build the mode adapter and load the NPU/bench/stack json data."""
        self.dump_path_param = dump_path_param
        self.output_path = output_path
        compare_mode = get_compare_mode(self.dump_path_param)
        self.ma = ModeAdapter(compare_mode)
        self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path'))
        self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path'))
        self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path'))

    def _postcompare(self):
        """Fill API-collection indexes and, in real-data mode, run the deferred comparison."""
        self._handle_api_collection_index()
        if self.ma.compare_mode != GraphConst.REAL_DATA_COMPARE:
            return
        df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode)
        df = run_real_data(self.dump_path_param, df, self.framework, bool(self.mapping_dict))
        # Key each row by its FIRST column. Use positional .iloc explicitly: integer indexing via
        # row[0] on a label-indexed Series is a deprecated positional fallback in pandas 2.x.
        compare_data_dict = {row.iloc[0]: row.tolist() for _, row in df.iterrows()}
        for node in self.ma.compare_nodes:
            precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
            node.data[GraphConst.JSON_INDEX_KEY] = precision_index
            if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index):
                node.get_suggestions()

    def _handle_api_collection_index(self):
        """
        Compute the index of each API collection from its member APIs.

        md5 mode takes the minimum index over all APIs in the collection (0 is worst);
        statistics and tensor modes take the maximum (1 is worst).
        """
        is_md5 = self.ma.compare_mode == GraphConst.MD5_COMPARE
        for node in self.graph_n.root.subnodes:
            if node.op != NodeOp.api_collection:
                continue
            if is_md5:
                # Seed with the best value so an empty collection keeps MAX_INDEX_KEY.
                precision_index = min([
                    GraphConst.MAX_INDEX_KEY,
                    *(api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY) for api in node.subnodes),
                ])
            else:
                # Seed with the best value so an empty collection keeps MIN_INDEX_KEY.
                precision_index = max([
                    GraphConst.MIN_INDEX_KEY,
                    *(api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY) for api in node.subnodes),
                ])
            node.data[GraphConst.JSON_INDEX_KEY] = precision_index

    def _compare_nodes(self, node_n):
        """
        Recursively walk the NPU tree. When a node with the same name is found in the bench graph,
        check their ancestors and parameter info; if consistent, run the precision comparison.
        Pre-order traversal is used so that when a node is compared its ancestors have already been
        matched, which provides important information for later fuzzy matching.
        """
        if self.mapping_dict:
            node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_dict)
            if node_b:
                ancestors_n.append(node_n.id)
                ancestors_b.append(node_b.id)
                node_n.matched_node_link = ancestors_b
                node_b.matched_node_link = ancestors_n
        else:
            node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b)
            if node_b:
                ancestors.append(node_b.id)
                node_n.add_link(node_b, ancestors)
        if node_b:
            # Real-data compare only yields basic info here, without precision metrics; those come
            # later from the multiprocess compare interface (see _postcompare).
            compare_result_list = compare_node([node_n.id, node_b.id],
                                               [self.data_n_dict, self.data_b_dict],
                                               self.stack_json_data, self.ma.compare_mode)
            if compare_result_list:
                self.ma.add_csv_data(compare_result_list)
                self.add_compare_result_to_node(node_n, compare_result_list)
        for subnode in node_n.subnodes:
            self._compare_nodes(subnode)
@@ -0,0 +1,211 @@
1
+ # Copyright (c) 2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import math
18
+ from msprobe.core.common.const import CompareConst, Const
19
+ from msprobe.visualization.utils import ToolTip, GraphConst, str2float
20
+
21
+
22
+ class ModeAdapter:
23
+ def __init__(self, compare_mode):
24
+ self.compare_mode = compare_mode
25
+ self.csv_data = []
26
+ self.compare_nodes = []
27
+
28
+ @staticmethod
29
+ def _add_md5_compare_data(node_data, compare_data_dict):
30
+ precision_index = GraphConst.MAX_INDEX_KEY
31
+ for key, value in node_data.items():
32
+ if not isinstance(value, dict):
33
+ continue
34
+ compare_data = compare_data_dict.get(key)
35
+ if compare_data:
36
+ headers = CompareConst.MD5_COMPARE_RESULT_HEADER
37
+ id_list = [headers.index(x) for x in GraphConst.MD5_INDEX_LIST]
38
+ ModeAdapter._match_data(value, compare_data, GraphConst.MD5_INDEX_LIST, id_list)
39
+ # md5比对是否通过
40
+ if value.get(CompareConst.RESULT) != CompareConst.PASS:
41
+ precision_index = GraphConst.MIN_INDEX_KEY
42
+ node_data[key] = value
43
+ return precision_index
44
+
45
+ @staticmethod
46
+ def _add_real_compare_data(node_data, compare_data_dict):
47
+ min_thousandth = float(1)
48
+ numbers = []
49
+ for key, value in node_data.items():
50
+ if not isinstance(value, dict):
51
+ continue
52
+ compare_data = compare_data_dict.get(key)
53
+ if compare_data:
54
+ headers = CompareConst.COMPARE_RESULT_HEADER
55
+ id_list = [headers.index(x) for x in GraphConst.REAL_DATA_INDEX_LIST]
56
+ ModeAdapter._match_data(value, compare_data, GraphConst.REAL_DATA_INDEX_LIST, id_list)
57
+ # 跳过scalar data,因为无法计算双千指标,会得到Nan
58
+ if not value.get(Const.SHAPE):
59
+ continue
60
+ # 获取一个节点所有的输入或输出最小的双千指标
61
+ thousandth = value.get(CompareConst.ONE_THOUSANDTH_ERR_RATIO)
62
+ # 可能是None,可能是非数字内容str
63
+ try:
64
+ thousandth = float(thousandth)
65
+ except (ValueError, TypeError):
66
+ thousandth = None
67
+ if thousandth is not None:
68
+ numbers.append(thousandth)
69
+ node_data[key] = value
70
+ # 双千指标都是None的异常情况
71
+ if not numbers:
72
+ min_thousandth = None
73
+ else:
74
+ min_thousandth = min(numbers + [min_thousandth])
75
+ return min_thousandth
76
+
77
+ @staticmethod
78
+ def _add_summary_compare_data(node_data, compare_data_dict):
79
+ max_relative_err = GraphConst.MIN_INDEX_KEY
80
+ # data_info: {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [2, 536320], 'Max': 9.66036224, ...}
81
+ for key, data_info in node_data.items():
82
+ if not isinstance(data_info, dict):
83
+ continue
84
+ compare_data = compare_data_dict.get(key)
85
+ if compare_data:
86
+ dtype = data_info.get(Const.DTYPE)
87
+ # 对应比对结果csv的列
88
+ key_list = GraphConst.SUMMARY_INDEX_LIST
89
+ headers = CompareConst.SUMMARY_COMPARE_RESULT_HEADER
90
+ id_list = [headers.index(x) for x in key_list]
91
+ ModeAdapter._match_data(data_info, compare_data, key_list, id_list)
92
+ for index, item in enumerate(key_list[4:]):
93
+ value = data_info.get(GraphConst.VALUE_INDEX_LIST[index])
94
+ value_diff = data_info.get(key_list[index])
95
+ relative_err = str2float(data_info.get(item))
96
+ if isinstance(value, float) and isinstance(value_diff, float) \
97
+ and dtype in GraphConst.SMALL_VALUES.keys():
98
+ small_value = GraphConst.SMALL_VALUES.get(dtype)
99
+ # 小值域
100
+ if abs(value) <= small_value:
101
+ data_info[item] = ToolTip.SMALL_VALUE_TIP.format(data_info.get(item),
102
+ GraphConst.VALUE_INDEX_LIST[index],
103
+ small_value)
104
+ relative_err = GraphConst.MIN_INDEX_KEY \
105
+ if abs(value_diff) <= GraphConst.SMALL_VALUES_ABS_ERROR.get(dtype) \
106
+ else GraphConst.MAX_INDEX_KEY
107
+ max_relative_err = max(max_relative_err, relative_err)
108
+ node_data[key] = data_info
109
+ max_relative_err = 1 if max_relative_err > 1 else max_relative_err
110
+ return max_relative_err
111
+
112
+ @staticmethod
113
+ def _match_data(data_dict, compare_data, key_list, id_list):
114
+ """
115
+ 绑定精度指标到node的input_data和output_data
116
+ """
117
+ if len(key_list) != len(id_list):
118
+ return
119
+ for id_val, key in zip(id_list, key_list):
120
+ data_dict[key] = compare_data[id_val]
121
+
122
+ @staticmethod
123
+ def _check_list_len(data_list, len_num):
124
+ if len(data_list) < len_num:
125
+ raise ValueError(f"compare_data_dict_list must contain at least {len_num} items.")
126
+
127
+ def parse_result(self, node, compare_data_dict_list):
128
+ """
129
+ 根据结果返回数据,分别是precision_index,和附加数据
130
+ """
131
+
132
+ other_dict = {}
133
+ if self.compare_mode == GraphConst.MD5_COMPARE:
134
+ ModeAdapter._check_list_len(compare_data_dict_list, 2)
135
+ precision_index_in = ModeAdapter._add_md5_compare_data(node.input_data, compare_data_dict_list[0])
136
+ precision_index_out = ModeAdapter._add_md5_compare_data(node.output_data, compare_data_dict_list[1])
137
+ # 所有输入输出md5对比通过,这个节点才算通过
138
+ precision_index = min(precision_index_in, precision_index_out)
139
+ other_result = CompareConst.PASS if precision_index == GraphConst.MAX_INDEX_KEY else CompareConst.DIFF
140
+ other_dict[CompareConst.RESULT] = other_result
141
+ elif self.compare_mode == GraphConst.SUMMARY_COMPARE:
142
+ ModeAdapter._check_list_len(compare_data_dict_list, 2)
143
+ ModeAdapter._add_summary_compare_data(node.input_data, compare_data_dict_list[0])
144
+ precision_index_out = ModeAdapter._add_summary_compare_data(node.output_data, compare_data_dict_list[1])
145
+ precision_index = precision_index_out
146
+ else:
147
+ ModeAdapter._check_list_len(compare_data_dict_list, 1)
148
+ min_thousandth_in = ModeAdapter._add_real_compare_data(node.input_data, compare_data_dict_list[0])
149
+ min_thousandth_out = ModeAdapter._add_real_compare_data(node.output_data, compare_data_dict_list[0])
150
+ if min_thousandth_in is not None and min_thousandth_out is not None:
151
+ change_percentage = min_thousandth_in - min_thousandth_out
152
+ else:
153
+ change_percentage = GraphConst.MIN_INDEX_KEY
154
+ change_percentage = GraphConst.MIN_INDEX_KEY if change_percentage < GraphConst.MIN_INDEX_KEY \
155
+ else change_percentage
156
+ precision_index = GraphConst.MAX_INDEX_KEY \
157
+ if change_percentage > GraphConst.MAX_INDEX_KEY else change_percentage
158
+ return precision_index, other_dict
159
+
160
+ def prepare_real_data(self, node):
161
+ """
162
+ 为真实数据比较模式准备节点信息
163
+ """
164
+ if self.compare_mode == GraphConst.REAL_DATA_COMPARE:
165
+ self.compare_nodes.append(node)
166
+ return True
167
+ return False
168
+
169
+ def add_csv_data(self, compare_result_list):
170
+ if self.compare_mode != GraphConst.REAL_DATA_COMPARE:
171
+ return
172
+ self.csv_data.extend(compare_result_list)
173
+
174
+ def add_error_key(self, node_data):
175
+ """
176
+ 根据不同的模式进行提供不同错误信息
177
+ """
178
+ for key, value in node_data.items():
179
+ if not isinstance(value, dict):
180
+ continue
181
+ if self.compare_mode == GraphConst.SUMMARY_COMPARE:
182
+ message = [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
183
+ CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
184
+ elif self.compare_mode == GraphConst.REAL_DATA_COMPARE:
185
+ message = [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]
186
+ else:
187
+ # 输出件优化
188
+ message = []
189
+ value[GraphConst.ERROR_KEY] = message
190
+ node_data[key] = value
191
+
192
+ def get_tool_tip(self):
193
+ """
194
+ 用于前端展示字段的具体含义
195
+ """
196
+ if self.compare_mode == GraphConst.SUMMARY_COMPARE:
197
+ tips = {
198
+ CompareConst.MAX_DIFF: ToolTip.MAX_DIFF,
199
+ CompareConst.MIN_DIFF: ToolTip.MIN_DIFF,
200
+ CompareConst.MEAN_DIFF: ToolTip.MEAN_DIFF,
201
+ CompareConst.NORM_DIFF: ToolTip.NORM_DIFF}
202
+ elif self.compare_mode == GraphConst.MD5_COMPARE:
203
+ tips = {Const.MD5: ToolTip.MD5}
204
+ else:
205
+ tips = {
206
+ CompareConst.ONE_THOUSANDTH_ERR_RATIO: ToolTip.ONE_THOUSANDTH_ERR_RATIO,
207
+ CompareConst.FIVE_THOUSANDTHS_ERR_RATIO: ToolTip.FIVE_THOUSANDTHS_ERR_RATIO,
208
+ CompareConst.COSINE: ToolTip.COSINE,
209
+ CompareConst.MAX_ABS_ERR: ToolTip.MAX_ABS_ERR,
210
+ CompareConst.MAX_RELATIVE_ERR: ToolTip.MAX_RELATIVE_ERR}
211
+ return json.dumps(tips)
@@ -0,0 +1,14 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
@@ -0,0 +1,124 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from msprobe.core.overflow_check.level import OverflowLevel
16
+ from msprobe.visualization.graph.node_op import NodeOp
17
+ from msprobe.visualization.utils import Suggestions, GraphConst
18
+ from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data
19
+
20
+
21
class BaseNode:
    """
    A node of a model graph (a module or an API call) used for visualization and graph comparison.
    """

    def __init__(self, node_op, node_id, up_node=None):
        """
        Args:
            node_op: node type; compared against NodeOp members and serialized via ``.value``.
            node_id: id of the node.
            up_node: optional parent node; when given, this node is attached to it via add_upnode.
        """
        self.op = node_op
        self.id = node_id
        self.data = {}
        self.output_data = {}
        self.input_data = {}
        self.upnode = None
        self.add_upnode(up_node)
        self.subnodes = []
        self.matched_node_link = []
        self.suggestions = {}
        self.stack_info = []
        self.micro_step_id = None
        self.overflow_level = None

    def __str__(self):
        info = f'id:\t{self.id}'
        return info

    def __eq__(self, other):
        """
        Decide whether two nodes can be matched, i.e. are structurally consistent.

        Two nodes match when both their input data and their output data compare
        equal via ``compare_data``. Comparing with anything that is not a BaseNode
        returns NotImplemented, so Python falls back to its default comparison
        instead of raising AttributeError on the missing ``input_data``.

        Note: defining __eq__ without __hash__ makes instances unhashable.
        """
        if not isinstance(other, BaseNode):
            return NotImplemented
        if not compare_data(self.input_data, other.input_data):
            return False
        if not compare_data(self.output_data, other.output_data):
            return False
        return True

    def get_suggestions(self):
        """
        Fill ``self.suggestions`` with hints for when precision is suspected to be problematic.

        Module nodes get the module suggestion plus the dump tool URL. Function API nodes
        get the API suggestion plus the API accuracy checker URL. Other node types are
        left unchanged.
        """
        if self.op == NodeOp.module:
            self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.Module
            self.suggestions[Suggestions.DUMP] = Suggestions.DUMP_URL
        elif self.op == NodeOp.function_api:
            self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.API
            self.suggestions[Suggestions.API_ACCURACY_CHECKER] = Suggestions.API_ACCURACY_CHECKER_URL

    def set_input_output(self, input_data, output_data):
        """Store the node's input and output data."""
        self.input_data = input_data
        self.output_data = output_data

    def set_overflow_level(self, level):
        """Record the overflow level; falsy values and non-OverflowLevel values are ignored."""
        if not level or not isinstance(level, OverflowLevel):
            return
        self.overflow_level = level

    def add_upnode(self, node):
        """
        Bind ``node`` as this node's parent and register this node in ``node.subnodes``.

        The call is ignored in three cases: when ``node`` is falsy, when ``node`` has the
        same id as this node, or when this node already has a parent.
        """
        if not node or node.id == self.id or self.upnode:
            return
        self.upnode = node
        node.subnodes.append(self)

    def add_link(self, node, ancestors):
        """
        Record matching data after two nodes have been matched successfully.

        Args:
            node: the node matched with self.
            ancestors: ancestor information of the counterpart node.

        Note: both nodes store the very same ``ancestors`` list object (not copies).
        """
        self.matched_node_link = ancestors
        node.matched_node_link = ancestors

    def to_dict(self):
        """
        Serialize the node for output.

        ``upnode`` is emitted as the parent's id, or the string 'None' when there is no parent.
        ``subnodes`` is emitted as a list of child ids. ``micro_step_id`` is included only
        when it is set.

        When an OverflowLevel is set, its value is first written into
        ``self.data['overflow_level']`` (this mutates ``self.data``).
        """
        result = {
            'id': self.id,
            'node_type': self.op.value,
            'output_data': format_node_data(self.output_data),
            'input_data': format_node_data(self.input_data),
            'upnode': self.upnode.id if self.upnode else 'None',
            'subnodes': [node.id for node in self.subnodes],
            'matched_node_link': self.matched_node_link,
            'suggestions': self.suggestions,
            'stack_info': self.stack_info
        }
        if self.micro_step_id is not None:
            result['micro_step_id'] = self.micro_step_id
        # Record the overflow result, if any.
        if self.overflow_level and isinstance(self.overflow_level, OverflowLevel):
            if self.data is None:
                self.data = {}
            self.data['overflow_level'] = self.overflow_level.value
        result['data'] = self.data
        return result

    def get_ancestors(self):
        """
        Return the ids of all ancestors of this node, ordered from the root down to the direct parent.
        """
        ancestors = []
        current_node = self.upnode
        while current_node:
            ancestors.append(current_node.id)
            current_node = current_node.upnode
        return list(reversed(ancestors))
@@ -0,0 +1,200 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from msprobe.core.overflow_check.checker import AnomalyDetector
16
+ from msprobe.visualization.graph.base_node import BaseNode
17
+ from msprobe.visualization.graph.node_op import NodeOp
18
+ from msprobe.visualization.utils import GraphConst
19
+ from msprobe.core.common.log import logger
20
+ from msprobe.core.common.const import Const
21
+
22
+
23
# Not referenced within this module chunk; presumably a recursion-depth cap used elsewhere — TODO confirm.
MAX_RECUR_LEVEL: int = 100
24
+
25
+
26
class Graph:
    """
    A model graph: BaseNode objects linked into a tree under a root module node named after the model.

    ``node_map`` indexes nodes by id. ``node_id_map`` holds per-id counters used by
    ``add_node`` when repeated ids are made unique with numeric suffixes.
    """

    def __init__(self, model_name, data_path='', dump_data=None):
        """
        Args:
            model_name: id of the root module node created for this graph.
            data_path: emitted under GraphConst.JSON_DATA_KEY by to_dict.
            dump_data: handed to AnomalyDetector by overflow_check.
        """
        self.node_map = {}
        self.node_id_map = {}
        self.add_node(NodeOp.module, model_name)
        self.root = self.get_node(model_name)
        self.data_path = data_path
        self.dump_data = dump_data

    def __str__(self):
        return "\n".join(str(node) for node in self.node_map.values())

    @staticmethod
    def match(graph_n, node_n, graph_b):
        """
        Find the node in graph_b that corresponds to node_n.

        Precondition: matching of node_n's parent has already been done. Matching is
        currently exact; fuzzy matching may be added here later. The exact match requires:
        - node_n's id exists in graph_b;
        - the two nodes compare equal (BaseNode.__eq__);
        - both nodes have identical ancestor id chains.

        Args:
            graph_n: graph that node_n belongs to (unused by the current exact matching).
            node_n: node to match; may be None.
            graph_b: graph to search in.

        Returns:
            (matched node, ancestor ids of node_n) on success, otherwise (None, []).
        """
        if not node_n or node_n.id not in graph_b.node_map:
            return None, []
        node_b = graph_b.node_map.get(node_n.id)
        if node_n != node_b:
            return None, []
        ancestors_n = node_n.get_ancestors()
        ancestors_b = node_b.get_ancestors()
        if ancestors_n != ancestors_b:
            return None, []
        return node_b, ancestors_n

    @staticmethod
    def mapping_match(node_n, graph_b, mapping_dict):
        """
        Match node_n in graph_b by id, translating the id through mapping_dict first.

        Ids missing from mapping_dict are looked up unchanged.

        Returns:
            (matched node, ancestor ids of node_n, ancestor ids of the matched node),
            or (None, [], []) when node_n is None or no node is found.
        """
        if not node_n:
            # Tolerate a missing node the same way Graph.match does.
            return None, [], []
        node_b = graph_b.node_map.get(mapping_dict.get(node_n.id, node_n.id))
        if not node_b:
            return None, [], []
        ancestors_n = node_n.get_ancestors()
        ancestors_b = node_b.get_ancestors()
        return node_b, ancestors_n, ancestors_b

    @staticmethod
    def dfs(node, result):
        """
        Pre-order traversal from ``node``, storing each ``to_dict()`` in ``result`` keyed by node id.

        Uses an explicit stack instead of recursion so very deep graphs cannot exceed
        Python's recursion limit. Children are pushed in reverse, which keeps the same
        visiting order as a recursive pre-order walk: a node, then each subnode's subtree
        in list order.
        """
        stack = [node]
        while stack:
            current = stack.pop()
            result[current.id] = current.to_dict()
            stack.extend(reversed(current.subnodes))

    @staticmethod
    def split_nodes_by_micro_step(nodes):
        """
        Split the nodes of one step into micro steps based on module names.

        A micro step must be one complete forward + backward pass.
        Example::
            =============== micro step0
            Module.forward
            Module.forward
            ...
            Module.backward
            Module.backward
            =============== micro step1
            Module.forward
            Module.forward
            ...
            Module.backward
            Module.backward
            =============== micro step2
            Module.forward
            Module.forward
            ...
            Module.backward
            Module.backward

        A non-module node is assigned to the micro step of the preceding module node.

        Returns:
            dict mapping micro step index -> list of nodes in that micro step.
        """
        result = {}
        micro_step = 0
        result[micro_step] = []
        backward_flag = False

        for node in nodes:
            if node.op == NodeOp.module:
                if f'{Const.SEP}{Const.FORWARD}{Const.SEP}' in node.id:
                    # The first forward module after a backward one starts a new micro step.
                    if backward_flag:
                        micro_step += 1
                        result[micro_step] = []
                        backward_flag = False
                else:
                    backward_flag = True
            result[micro_step].append(node)
        return result

    def add_node(self, node_op, node_id, up_node=None, id_accumulation=False):
        """
        Add a node to the graph.

        Args:
            node_op: type of the node to add.
            node_id: id of the node to add.
            up_node: parent node of the new node.
            id_accumulation: when True, repeated ids are suffixed with a per-id counter
                ("<id>.<n>") so each call creates a distinct node.

        Returns:
            The id under which the node is stored. When id_accumulation is False and the
            id already exists, nothing is added and that id is returned.
        """
        if node_id in self.node_map:
            if not id_accumulation:
                return node_id
            # The bare id is already used by a plain node, so its suffixes start at 1.
            # Only seed the counter when it does not exist yet: resetting an existing
            # counter would regenerate "<id>.1" and overwrite the node stored there.
            self.node_id_map.setdefault(node_id, 0)
        if id_accumulation:
            if node_id in self.node_id_map:
                self.node_id_map[node_id] += 1
            else:
                self.node_id_map[node_id] = 0
            node_id = f'{node_id}.{self.node_id_map[node_id]}'
        node = BaseNode(node_op, node_id, up_node)
        self.node_map[node_id] = node
        return node_id

    def get_node(self, node_id):
        """
        Return the node with the given id, or None if it does not exist.
        """
        return self.node_map.get(node_id, None)

    def to_dict(self):
        """
        Serialize the graph for output.

        The result holds three entries:
        - the root id (or 'None' when there is no root);
        - the data path;
        - every node's ``to_dict()``, keyed by node id.
        """
        return {
            GraphConst.JSON_ROOT_KEY: self.root.id if self.root else 'None',
            GraphConst.JSON_DATA_KEY: self.data_path,
            GraphConst.JSON_NODE_KEY: {node_id: node.to_dict() for node_id, node in self.node_map.items()},
        }

    def paging_by_micro_step(self, graph_other=None):
        """
        Tag the first-level nodes of this graph with a micro step id for front-end pagination.

        Pagination helps when handling large graphs. In comparison scenarios, the matched
        counterpart in graph_other gets the same micro step id. Still-untagged first-level
        nodes of graph_other are tagged with the integer after the last separator in their
        id, or 0 when that part is not an integer.

        Args:
            graph_other: optional, the other graph of a comparison.

        Returns:
            The number of micro step batches.
        """
        batches_n = Graph.split_nodes_by_micro_step(self.root.subnodes)
        for batch_number, nodes in batches_n.items():
            for node in nodes:
                node.micro_step_id = batch_number
                # Update the micro_step_id of the matched node in graph_other.
                if graph_other and node.matched_node_link:
                    node_other = graph_other.get_node(node.matched_node_link[-1])
                    if node_other:
                        node_other.micro_step_id = batch_number
        # Make sure unmatched first-level nodes of graph_other also get a micro_step_id.
        if graph_other:
            for node in graph_other.root.subnodes:
                if node.micro_step_id is None:
                    try:
                        micro_step_id = int(node.id.split(Const.SEP)[-1])
                    except ValueError:
                        micro_step_id = 0
                    node.micro_step_id = micro_step_id
        return len(batches_n)

    def overflow_check(self):
        """
        Run the anomaly detector on ``self.dump_data`` and set the overflow level on every flagged node.
        """
        detector = AnomalyDetector(self.dump_data)
        detector.analyze().filter()

        for node_id, node in self.node_map.items():
            if detector.has_overflow(node_id):
                node.set_overflow_level(detector.get_overflow_level(node_id))