mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,12 +1,27 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import copy
2
- import csv
3
17
  import glob
4
18
  import os
19
+ import re
5
20
 
6
21
  import numpy as np
7
22
  import pandas as pd
8
- from msprobe.core.common.const import CompareConst, GraphMode, Const, FileCheckConst
9
- from msprobe.core.common.file_utils import FileOpen, check_path_before_create, change_mode, load_npy
23
+ from msprobe.core.common.const import CompareConst, GraphMode, Const
24
+ from msprobe.core.common.file_utils import load_npy, read_csv, save_excel
10
25
  from msprobe.core.common.log import logger
11
26
  from msprobe.core.common.utils import add_time_with_xlsx, CompareException
12
27
  from msprobe.core.compare.multiprocessing_compute import _ms_graph_handle_multi_process, check_accuracy
@@ -14,7 +29,7 @@ from msprobe.core.compare.npy_compare import npy_data_check, statistics_data_che
14
29
  from msprobe.mindspore.common.utils import convert_to_int, list_lowest_level_directories
15
30
 
16
31
 
17
- class row_data:
32
+ class RowData:
18
33
  def __init__(self, mode):
19
34
  self.basic_data = copy.deepcopy(CompareConst.MS_GRAPH_BASE)
20
35
  self.npy_data = copy.deepcopy(CompareConst.MS_GRAPH_NPY)
@@ -28,17 +43,34 @@ class row_data:
28
43
  return self.data
29
44
 
30
45
 
46
+ def get_name_dict(name: str) -> dict:
47
+ compare_pattern = re.compile(r'^([^.]+)\.([^.]+)\.([^.]+)\.([^.]+)\.(\d+(?:\.\d+)*)\.'
48
+ r'((?:in|out)put(?:\.\d+)*)\.([^.]+)\.([^.]+)\.npy$')
49
+ match = compare_pattern.match(name)
50
+ if match:
51
+ return {'op_type': match.group(1),
52
+ 'op_name': match.group(2),
53
+ 'task_id': match.group(3),
54
+ 'stream_id': match.group(4),
55
+ 'timestamp': match.group(5).split(Const.SEP)[0],
56
+ 'input_output_index': match.group(6),
57
+ 'slot': match.group(7),
58
+ 'format': match.group(8)}
59
+ return {}
60
+
61
+
31
62
  def npy_data_read(data_path, npy_file_list, mapping_dict):
32
63
  data_list = []
64
+ compare_key_elements = ['op_name', 'task_id', 'input_output_index', 'slot']
33
65
  for data in npy_file_list:
34
66
  if data in mapping_dict:
35
- split_list = mapping_dict[data].split(Const.SEP)
67
+ name_dict = get_name_dict(mapping_dict[data])
36
68
  else:
37
- split_list = data.split(Const.SEP)
38
- if len(split_list) < 7:
69
+ name_dict = get_name_dict(data)
70
+ if not name_dict:
39
71
  continue
40
- compare_key = f"{split_list[1]}.{split_list[2]}.{split_list[3]}.{split_list[5]}.{split_list[6]}"
41
- timestamp = convert_to_int(split_list[4])
72
+ compare_key = Const.SEP.join([name_dict.get(element) for element in compare_key_elements])
73
+ timestamp = convert_to_int(name_dict.get('timestamp'))
42
74
 
43
75
  data_list.append([os.path.join(data_path, data), compare_key, timestamp])
44
76
  return data_list
@@ -47,17 +79,18 @@ def npy_data_read(data_path, npy_file_list, mapping_dict):
47
79
  def statistic_data_read(statistic_file_list, statistic_file_path):
48
80
  data_list = []
49
81
  statistic_data_list = []
50
- header_index = {'Data Type': None, 'Shape': None, 'Max Value': None, 'Min Value': None,
51
- 'Avg Value': None, 'L2Norm Value': None}
82
+ header_index = {
83
+ 'Data Type': None, 'Shape': None, 'Max Value': None,
84
+ 'Min Value': None, 'Avg Value': None, 'L2Norm Value': None
85
+ }
52
86
  for statistic_file in statistic_file_list:
53
- with FileOpen(statistic_file, "r") as f:
54
- csv_reader = csv.reader(f, delimiter=",")
55
- header = next(csv_reader)
56
- for key in header_index.keys():
57
- for index, value in enumerate(header):
58
- if key == value:
59
- header_index[key] = index
60
- statistic_data_list.extend([row for row in csv_reader])
87
+ content = read_csv(statistic_file, as_pd=False)
88
+ header = content[0]
89
+ for key in header_index.keys():
90
+ for index, value in enumerate(header):
91
+ if key == value:
92
+ header_index[key] = index
93
+ statistic_data_list.extend(content[1:])
61
94
 
62
95
  for key in header_index.keys():
63
96
  if header_index[key] is None:
@@ -65,8 +98,9 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
65
98
 
66
99
  for data in statistic_data_list:
67
100
  compare_key = f"{data[1]}.{data[2]}.{data[3]}.{data[5]}"
101
+ op_name = f"{compare_key} {statistic_file_path}"
68
102
  timestamp = int(data[4])
69
- result_data = [statistic_file_path, compare_key, timestamp]
103
+ result_data = [op_name, compare_key, timestamp]
70
104
  for key in header_index.keys():
71
105
  if header_index[key] is None:
72
106
  result_data.append(np.nan)
@@ -94,11 +128,9 @@ def generate_data_name(data_path):
94
128
  mapping_dict = {}
95
129
  if mapping_exist:
96
130
  for mapping_file in mapping_file_list:
97
- with FileOpen(mapping_file, "r") as f:
98
- csv_reader = csv.reader(f, delimiter=",")
99
- header = next(csv_reader)
100
- for row in csv_reader:
101
- mapping_dict[row[0]] = row[1]
131
+ content = read_csv(mapping_file, False)
132
+ for row in content[1:]:
133
+ mapping_dict[row[0]] = row[1]
102
134
 
103
135
  if npy_exist:
104
136
  data_list = npy_data_read(data_path, npy_file_list, mapping_dict)
@@ -133,7 +165,7 @@ class GraphMSComparator:
133
165
  def compare_ops(compare_result_db, mode):
134
166
 
135
167
  def npy_mode_compute(row):
136
- result_dict = row_data(GraphMode.NPY_MODE)()
168
+ result_dict = RowData(GraphMode.NPY_MODE)()
137
169
 
138
170
  def process_npy_file(file_path, name_prefix, result):
139
171
  if os.path.exists(file_path):
@@ -168,7 +200,7 @@ class GraphMSComparator:
168
200
  return pd.Series(result_dict)
169
201
 
170
202
  def statistic_mode_compute(row):
171
- result_dict = row_data('STATISTIC')()
203
+ result_dict = RowData('STATISTIC')()
172
204
 
173
205
  def update_result_dict(result, rows, prefix):
174
206
  result[f'{prefix} Name'] = rows[f'{prefix} Name']
@@ -195,24 +227,30 @@ class GraphMSComparator:
195
227
  result_dict[CompareConst.NPU_NORM] - result_dict[CompareConst.BENCH_NORM])
196
228
  result_dict[CompareConst.MAX_RELATIVE_ERR] = result_dict[CompareConst.MAX_DIFF] / result_dict[
197
229
  CompareConst.BENCH_MAX] if result_dict[CompareConst.BENCH_MAX] > 0 else 0
198
- result_dict[CompareConst.MAX_RELATIVE_ERR] = str(result_dict[CompareConst.MAX_RELATIVE_ERR] * 100) + "%"
230
+ if not np.isnan(result_dict[CompareConst.MAX_RELATIVE_ERR]):
231
+ result_dict[CompareConst.MAX_RELATIVE_ERR] = str(
232
+ result_dict[CompareConst.MAX_RELATIVE_ERR] * 100) + "%"
199
233
  result_dict[CompareConst.MIN_RELATIVE_ERR] = result_dict[CompareConst.MIN_DIFF] / result_dict[
200
234
  CompareConst.BENCH_MIN] if result_dict[CompareConst.BENCH_MIN] > 0 else 0
201
- result_dict[CompareConst.MIN_RELATIVE_ERR] = str(result_dict[CompareConst.MIN_RELATIVE_ERR] * 100) + "%"
235
+ if not np.isnan(result_dict[CompareConst.MIN_RELATIVE_ERR]):
236
+ result_dict[CompareConst.MIN_RELATIVE_ERR] = \
237
+ str(result_dict[CompareConst.MIN_RELATIVE_ERR] * 100) + "%"
202
238
  result_dict[CompareConst.MEAN_RELATIVE_ERR] = result_dict[CompareConst.MEAN_DIFF] / result_dict[
203
239
  CompareConst.BENCH_MEAN] if result_dict[CompareConst.BENCH_MEAN] > 0 else 0
204
- result_dict[CompareConst.MEAN_RELATIVE_ERR] = str(
205
- result_dict[CompareConst.MEAN_RELATIVE_ERR] * 100) + "%"
240
+ if not np.isnan(result_dict[CompareConst.MEAN_RELATIVE_ERR]):
241
+ result_dict[CompareConst.MEAN_RELATIVE_ERR] = str(
242
+ result_dict[CompareConst.MEAN_RELATIVE_ERR] * 100) + "%"
206
243
  result_dict[CompareConst.NORM_RELATIVE_ERR] = result_dict[CompareConst.NORM_DIFF] / result_dict[
207
244
  CompareConst.BENCH_NORM] if result_dict[CompareConst.BENCH_NORM] > 0 else 0
208
- result_dict[CompareConst.NORM_RELATIVE_ERR] = str(
209
- result_dict[CompareConst.NORM_RELATIVE_ERR] * 100) + "%"
245
+ if not np.isnan(result_dict[CompareConst.NORM_RELATIVE_ERR]):
246
+ result_dict[CompareConst.NORM_RELATIVE_ERR] = str(
247
+ result_dict[CompareConst.NORM_RELATIVE_ERR] * 100) + "%"
210
248
  magnitude_diff = result_dict[CompareConst.MAX_DIFF] / (
211
249
  max(result_dict[CompareConst.NPU_MAX], result_dict[CompareConst.BENCH_MAX]) + 1e-10)
212
- if magnitude_diff > CompareConst.MAGNITUDE:
213
- result_dict[CompareConst.ACCURACY] = 'No'
214
- else:
215
- result_dict[CompareConst.ACCURACY] = 'Yes'
250
+ if np.isnan(result_dict[CompareConst.NPU_MAX]) and np.isnan(result_dict[CompareConst.BENCH_MAX]):
251
+ magnitude_diff = 0
252
+ result_dict[CompareConst.ACCURACY] = CompareConst.YES if \
253
+ magnitude_diff <= CompareConst.MAGNITUDE else CompareConst.NO
216
254
 
217
255
  return pd.Series(result_dict)
218
256
 
@@ -235,14 +273,24 @@ class GraphMSComparator:
235
273
  is_empty = True
236
274
  if is_empty or not mode:
237
275
  continue
238
- compare_result_df = self._do_multi_process(compare_result_df, mode)
276
+ compare_result_df = self.do_multi_process(compare_result_df, mode)
239
277
  compare_result_name = add_time_with_xlsx(f"compare_result_{str(rank_id)}_{str(step_id)}")
240
278
  compare_result_path = os.path.join(os.path.realpath(self.output_path), f"{compare_result_name}")
241
- check_path_before_create(compare_result_path)
242
- compare_result_df.to_excel(compare_result_path, index=False)
243
- change_mode(compare_result_path, FileCheckConst.DATA_FILE_AUTHORITY)
279
+ self.to_excel(compare_result_df, compare_result_path)
244
280
  logger.info(f"Compare rank: {rank_id} step: {step_id} finish. Compare result: {compare_result_path}.")
245
281
 
282
+ def to_excel(self, compare_result_df: pd.DataFrame, compare_result_path: str, slice_num=0, need_slice=False) -> int:
283
+ size = len(compare_result_df)
284
+ # sheet size cannot be larger than 1048576
285
+ if size < CompareConst.MAX_EXCEL_LENGTH:
286
+ compare_result_path = compare_result_path.replace('.xlsx', f'_slice_{slice_num}.xlsx') if \
287
+ need_slice else compare_result_path
288
+ save_excel(compare_result_path, compare_result_df)
289
+ return slice_num + 1
290
+ else:
291
+ slice_num = self.to_excel(compare_result_df.iloc[0: size // 2], compare_result_path, slice_num, True)
292
+ return self.to_excel(compare_result_df.iloc[size // 2:], compare_result_path, slice_num, True)
293
+
246
294
  def compare_process(self, rank_id, step_id):
247
295
  # generate data_path
248
296
  npu_data_path_list = self.npu_rank_step_dict.get((rank_id, step_id))
@@ -251,8 +299,8 @@ class GraphMSComparator:
251
299
  return [], ''
252
300
 
253
301
  # generate file name
254
- npu_mode = 'ERROR_MODE'
255
- bench_mode = 'ERROR_MODE'
302
+ npu_mode = GraphMode.ERROR_MODE
303
+ bench_mode = GraphMode.ERROR_MODE
256
304
  npu_data_list = []
257
305
  bench_data_list = []
258
306
  for npu_data_path in npu_data_path_list:
@@ -262,7 +310,7 @@ class GraphMSComparator:
262
310
  bench_mode, data_list = generate_data_name(bench_data_path)
263
311
  bench_data_list.extend(data_list)
264
312
 
265
- if npu_mode == "ERROR_MODE" or bench_mode == "ERROR_MODE":
313
+ if npu_mode == GraphMode.ERROR_MODE or bench_mode == GraphMode.ERROR_MODE:
266
314
  logger.warning(f"Data_path {npu_data_path} or {bench_data_path} is not exist.")
267
315
  return [], ''
268
316
  if npu_mode != bench_mode:
@@ -286,11 +334,13 @@ class GraphMSComparator:
286
334
  CompareConst.BENCH_NORM])
287
335
 
288
336
  npu_float_type = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
289
- npu_data_df[npu_float_type] = npu_data_df[npu_float_type].astype(np.float32)
337
+ npu_data_df[npu_float_type] = npu_data_df[npu_float_type].astype(float)
290
338
 
291
- bench_float_type = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
292
- CompareConst.BENCH_NORM]
293
- bench_data_df[bench_float_type] = bench_data_df[bench_float_type].astype(np.float32)
339
+ bench_float_type = [
340
+ CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
341
+ CompareConst.BENCH_MEAN, CompareConst.BENCH_NORM
342
+ ]
343
+ bench_data_df[bench_float_type] = bench_data_df[bench_float_type].astype(float)
294
344
 
295
345
  npu_data_df['Local Index'] = npu_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()
296
346
  bench_data_df['Local Index'] = bench_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()
@@ -339,7 +389,7 @@ class GraphMSComparator:
339
389
  rank_step_path_dict[rank_step_key] = [dir_path]
340
390
  return dict(sorted(rank_step_path_dict.items()))
341
391
 
342
- def _do_multi_process(self, result_df, mode):
392
+ def do_multi_process(self, result_df, mode):
343
393
  try:
344
394
  result_df = _ms_graph_handle_multi_process(self.compare_ops, result_df, mode)
345
395
  except ValueError as e:
@@ -1,9 +1,24 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
 
3
18
  from msprobe.core.common.const import Const
19
+ from msprobe.core.common.file_utils import create_directory
4
20
  from msprobe.mindspore.common.const import Const as MsConst
5
21
  from msprobe.mindspore.common.const import FreeBenchmarkConst
6
- from msprobe.core.common.file_utils import create_directory
7
22
 
8
23
 
9
24
  class DebuggerConfig:
@@ -18,7 +33,7 @@ class DebuggerConfig:
18
33
  self.level_ori = common_config.level
19
34
  self.list = [] if not task_config.list else task_config.list
20
35
  self.scope = [] if not task_config.scope else task_config.scope
21
- self.data_mode = [] if not task_config.data_mode else task_config.data_mode
36
+ self.data_mode = [Const.ALL] if not task_config.data_mode else task_config.data_mode
22
37
  self.file_format = task_config.file_format
23
38
  self.overflow_nums = 1 if not task_config.overflow_nums else task_config.overflow_nums
24
39
  self.check_mode = task_config.check_mode
@@ -37,6 +52,9 @@ class DebuggerConfig:
37
52
  self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE:
38
53
  raise ValueError("pert_mode must be improve_precision or empty when handler_type is fix, "
39
54
  f"but got {self.pert_type}.")
55
+ if self.stage == Const.BACKWARD and self.handler_type == FreeBenchmarkConst.FIX:
56
+ raise ValueError("handler_type must be check or empty when fuzz_stage is backward, "
57
+ f"but got {self.handler_type}.")
40
58
  self.dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
41
59
 
42
60
  def check(self):
@@ -51,16 +69,4 @@ class DebuggerConfig:
51
69
  self.file_format = "npy"
52
70
  if not self.check_mode:
53
71
  self.check_mode = "all"
54
- self._check_rank()
55
- self._check_step()
56
72
  return True
57
-
58
- def _check_rank(self):
59
- for rank_id in self.rank:
60
- if not isinstance(rank_id, int) or rank_id < 0:
61
- raise ValueError(f"rank {self.rank} must be a positive integer.")
62
-
63
- def _check_step(self):
64
- for s in self.step:
65
- if not isinstance(s, int) or s < 0:
66
- raise ValueError(f"step element {s} must be a positive integer.")
@@ -1,17 +1,34 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
17
+ from collections import defaultdict
2
18
 
3
19
  import mindspore as ms
4
20
  from mindspore._c_expression import MSContext
5
21
 
6
- from msprobe.mindspore.service import Service
7
- from msprobe.mindspore.ms_config import parse_json_config
8
- from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
9
- from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
10
- from msprobe.core.common.const import Const
22
+ from msprobe.core.common.const import Const, MsgConst
23
+ from msprobe.mindspore.cell_processor import CellProcessor
11
24
  from msprobe.mindspore.common.const import Const as MsConst
12
- from msprobe.mindspore.runtime import Runtime
13
-
25
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
26
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
14
27
  from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor
28
+ from msprobe.mindspore.ms_config import parse_json_config
29
+ from msprobe.mindspore.runtime import Runtime
30
+ from msprobe.mindspore.service import Service
31
+ from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
15
32
 
16
33
 
17
34
  class PrecisionDebugger:
@@ -65,11 +82,11 @@ class PrecisionDebugger:
65
82
  def start(cls, model=None):
66
83
  instance = cls._instance
67
84
  if not instance:
68
- raise Exception("No instance of PrecisionDebugger found.")
85
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
69
86
  if instance.task in PrecisionDebugger.task_not_need_service:
70
87
  return
71
88
 
72
- instance.config.execution_mode = instance._get_execution_mode()
89
+ instance.config.execution_mode = cls._get_execution_mode()
73
90
  if cls._need_service():
74
91
  if not instance.service:
75
92
  instance.service = Service(instance.config)
@@ -82,11 +99,21 @@ class PrecisionDebugger:
82
99
  instance.first_start = True
83
100
  Runtime.is_running = True
84
101
 
102
+ @classmethod
103
+ def forward_backward_dump_end(cls):
104
+ instance = cls._instance
105
+ if not instance:
106
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
107
+ if instance.task in PrecisionDebugger.task_not_need_service:
108
+ return
109
+ if instance.service:
110
+ instance.service.forward_backward_dump_end()
111
+
85
112
  @classmethod
86
113
  def stop(cls):
87
114
  instance = cls._instance
88
115
  if not instance:
89
- raise Exception("PrecisionDebugger instance is not created.")
116
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
90
117
  if instance.task == Const.GRAD_PROBE:
91
118
  instance.gm.stop()
92
119
  if instance.task in PrecisionDebugger.task_not_need_service:
@@ -99,18 +126,21 @@ class PrecisionDebugger:
99
126
  def step(cls):
100
127
  instance = cls._instance
101
128
  if not instance:
102
- raise Exception("PrecisionDebugger instance is not created.")
129
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
103
130
  if instance.task in PrecisionDebugger.task_not_need_service:
104
131
  return
105
132
  if instance.service:
106
133
  instance.service.step()
134
+ HOOKCell.cell_count = defaultdict(int)
135
+ CellProcessor.reset_cell_stats()
136
+
107
137
  Runtime.step_count += 1
108
138
 
109
139
  @classmethod
110
140
  def monitor(cls, opt):
111
141
  instance = cls._instance
112
142
  if not instance:
113
- raise Exception("PrecisionDebugger instance is not created.")
143
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
114
144
  if instance.task != Const.GRAD_PROBE:
115
145
  return
116
146
  instance.gm.monitor(opt)
@@ -119,7 +149,7 @@ class PrecisionDebugger:
119
149
  def _need_service(cls):
120
150
  instance = cls._instance
121
151
  if not instance:
122
- raise Exception("No instance of PrecisionDebugger found.")
152
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
123
153
  if instance.config.execution_mode != MsConst.PYNATIVE_MODE:
124
154
  return False
125
155
  else:
@@ -1,7 +1,22 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.mindspore.common.const import Const
2
17
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
3
- from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump
4
18
  from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
19
+ from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump
5
20
 
6
21
 
7
22
  class DumpToolFactory:
@@ -25,6 +40,8 @@ class DumpToolFactory:
25
40
 
26
41
  @staticmethod
27
42
  def create(config: DebuggerConfig):
43
+ if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST:
44
+ raise Exception("data_mode must be one of all, input, output.")
28
45
  tool = DumpToolFactory.tools.get(config.level)
29
46
  if not tool:
30
47
  raise Exception("Valid level is needed.")
@@ -16,13 +16,20 @@
16
16
  from mindspore import Tensor, ops, mint
17
17
  from mindspore.mint.nn import functional
18
18
  from mindspore.common._stub_tensor import StubTensor
19
+ from mindspore.communication import comm_func
19
20
 
20
21
  from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP,
21
- HOOKMintOP, HOOKMintNNFunctionalOP,
22
+ HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP,
22
23
  get_wrap_api_list, setup_hooks)
23
24
  from msprobe.core.common.utils import Const
24
25
 
25
26
 
27
+ def stub_method(method):
28
+ def wrapped_method(*args, **kwargs):
29
+ return method(*args, **kwargs)
30
+ return wrapped_method
31
+
32
+
26
33
  class ApiRegistry:
27
34
  def __init__(self):
28
35
  self.tensor_ori_attr = {}
@@ -30,6 +37,7 @@ class ApiRegistry:
30
37
  self.functional_ori_attr = {}
31
38
  self.mint_ops_ori_attr = {}
32
39
  self.mint_func_ops_ori_attr = {}
40
+ self.distributed_ori_attr = {}
33
41
  self.norm_inner_ops_ori_attr = {}
34
42
 
35
43
  self.tensor_hook_attr = {}
@@ -37,6 +45,7 @@ class ApiRegistry:
37
45
  self.functional_hook_attr = {}
38
46
  self.mint_ops_hook_attr = {}
39
47
  self.mint_func_ops_hook_attr = {}
48
+ self.distibuted_hook_attr = {}
40
49
  self.norm_inner_ops_hook_attr = {}
41
50
 
42
51
  self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
@@ -47,9 +56,13 @@ class ApiRegistry:
47
56
  if Const.SEP in api:
48
57
  sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
49
58
  sub_module = getattr(ori_api_group, sub_module_name)
50
- api_ori_attr[api] = getattr(sub_module, sub_op)
59
+ ori_api_func = getattr(sub_module, sub_op)
51
60
  else:
52
- api_ori_attr[api] = getattr(ori_api_group, api)
61
+ ori_api_func = getattr(ori_api_group, api)
62
+ if ori_api_group == StubTensor:
63
+ api_ori_attr[api] = stub_method(ori_api_func)
64
+ continue
65
+ api_ori_attr[api] = ori_api_func
53
66
 
54
67
  @staticmethod
55
68
  def set_api_attr(api_group, attr_dict):
@@ -74,6 +87,7 @@ class ApiRegistry:
74
87
  self.set_api_attr(ops, self.functional_hook_attr)
75
88
  self.set_api_attr(mint, self.mint_ops_hook_attr)
76
89
  self.set_api_attr(functional, self.mint_func_ops_hook_attr)
90
+ self.set_api_attr(comm_func, self.distibuted_hook_attr)
77
91
 
78
92
  def api_set_ori_func(self):
79
93
  self.set_api_attr(Tensor, self.tensor_ori_attr)
@@ -81,6 +95,7 @@ class ApiRegistry:
81
95
  self.set_api_attr(ops, self.functional_ori_attr)
82
96
  self.set_api_attr(mint, self.mint_ops_ori_attr)
83
97
  self.set_api_attr(functional, self.mint_func_ops_ori_attr)
98
+ self.set_api_attr(comm_func, self.distributed_ori_attr)
84
99
 
85
100
  def initialize_hook(self, hook):
86
101
  wrap_api_name = get_wrap_api_list()
@@ -89,6 +104,7 @@ class ApiRegistry:
89
104
  self.store_ori_attr(ops, wrap_api_name.ops_api_names, self.functional_ori_attr)
90
105
  self.store_ori_attr(mint, wrap_api_name.mint_api_names, self.mint_ops_ori_attr)
91
106
  self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr)
107
+ self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr)
92
108
  self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
93
109
  setup_hooks(hook)
94
110
  for attr_name in dir(HOOKTensor):
@@ -113,6 +129,10 @@ class ApiRegistry:
113
129
  if attr_name.startswith(Const.ATTR_NAME_PREFIX):
114
130
  api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
115
131
  self.mint_func_ops_hook_attr[api_name] = getattr(HOOKMintNNFunctionalOP, attr_name)
132
+ for attr_name in dir(HOOKDistributedOP):
133
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
134
+ api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
135
+ self.distibuted_hook_attr[api_name] = getattr(HOOKDistributedOP, attr_name)
116
136
 
117
137
 
118
138
  api_register = ApiRegistry()