mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py
@@ -0,0 +1,206 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Standard library imports
+ import multiprocessing
+ from multiprocessing import Manager
+ import os
+ import signal
+ import sys
+ import time
+
+ # Third-party imports
+ from mindspore import context
+ import numpy as np
+ from tqdm import tqdm
+
+ # Local application/library-specific imports
+ from msprobe.core.common.const import Const, CompareConst, MsCompareConst
+ from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker, BasicInfoAndStatus
+ from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataManager
+ from msprobe.mindspore.common.log import logger
+
+
+ class MultiApiAccuracyChecker(ApiAccuracyChecker):
+     def __init__(self, args):
+         # Attributes or methods specific to MultiApiAccuracyChecker can be added here
+         self.api_infos = dict()
+
+         # Use a Manager to create a shared variable so the processes stay in sync
+         self.manager = Manager()
+         self.is_first_write = self.manager.Value('b', True)  # create the shared variable
+
+         # Pass the shared is_first_write when initializing the DataManager
+         self.multi_data_manager = MultiDataManager(args.out_path, args.result_csv_path, self.is_first_write)
+
+         self.args = args  # keep args as an instance attribute
+
+         # Attribute holding the current device ID (used in log messages)
+         self.current_device_id = None
+
+     def process_on_device(self, device_id, api_infos, progress_queue):
+         """
+         Process a subset of APIs on a specific device.
+
+         Parameters:
+             device_id (int): ID of the device to use.
+             api_infos (list): list of (api_name, api_info) tuples.
+             progress_queue (multiprocessing.Queue): queue used to report progress updates.
+         """
+
+         # Record the current device ID
+         self.current_device_id = device_id
+
+         # Set the device_id in the MindSpore context
+         context.set_context(device_id=device_id)
+
+         # Iterate over the tasks assigned to this process
+         for _, (api_name_str, api_info) in enumerate(api_infos):
+             logger.debug(f"Processing API: {api_name_str}, Device: {device_id}")
+
+             if not self.multi_data_manager.is_unique_api(api_name_str):
+                 logger.debug(f"API {api_name_str} is not unique, skipping.")
+                 progress_queue.put(1)
+                 continue
+
+             # Forward pass
+             forward_output_list = self.process_forward(api_name_str, api_info)
+             if forward_output_list is not Const.EXCEPTION_NONE:
+                 self.multi_data_manager.record(forward_output_list)
+
+             # Backward pass
+             backward_output_list = self.process_backward(api_name_str, api_info)
+             if backward_output_list is not Const.EXCEPTION_NONE:
+                 self.multi_data_manager.record(backward_output_list)
+
+             # Save the results
+             self.multi_data_manager.save_results(api_name_str)
+             progress_queue.put(1)  # report progress
+
+     def run_and_compare(self):
+         # List of device IDs to use
+         device_ids = self.args.device_id
+
+         # API entries to be split across the devices
+         partitioned_api_infos = list(self.api_infos.items())
+
+         # Interleave the tasks in the main process (modulo-based assignment)
+         partitioned_api_infos_split = [[] for _ in range(len(device_ids))]
+         for idx, api_info in enumerate(partitioned_api_infos):
+             device_index = idx % len(device_ids)  # assign tasks by index modulo
+             partitioned_api_infos_split[device_index].append(api_info)
+
+         # Create a shared progress queue
+         progress_queue = multiprocessing.Queue()
+
+         # Progress bar
+         total_tasks = len(partitioned_api_infos)  # total number of tasks
+         with tqdm(total=total_tasks, desc="Total Progress", ncols=100) as pbar:
+             # Spawn the worker processes
+             processes = []
+             for index, device_id in enumerate(device_ids):
+                 process = multiprocessing.Process(target=self.process_on_device,
+                                                   args=(device_id, partitioned_api_infos_split[index], progress_queue))
+                 processes.append(process)
+                 process.start()
+
+             # The main process updates the progress bar
+             completed_tasks = 0
+             while completed_tasks < total_tasks:
+                 try:
+                     completed_tasks += progress_queue.get(timeout=Const.PROGRESS_TIMEOUT)  # timeout in seconds
+                     pbar.update(1)
+                 except multiprocessing.queues.Empty:
+                     logger.error("Timeout while waiting for progress updates. Skipping remaining tasks.")
+                     break
+
+                 # Check the state of the child processes
+                 for process in processes:
+                     if not process.is_alive():
+                         if process.exitcode != 0:
+                             logger.error(f"Process {process.pid} exited with code {process.exitcode}.")
+                             total_tasks -= len(partitioned_api_infos_split[processes.index(process)])
+                         processes.remove(process)
+
+             # Make sure every child process has finished or is terminated
+             for process in processes:
+                 process.join(timeout=Const.PROGRESS_TIMEOUT)
+                 if process.is_alive():
+                     logger.error(f"Process {process.pid} did not terminate. Forcing termination.")
+                     process.terminate()
+
+     def process_forward(self, api_name_str, api_info):
+         """
+         Overrides the parent class's process_forward method to log the device ID when exceptions occur.
+
+         Parameters:
+             api_name_str (str): The name of the API.
+             api_info (object): The API information object.
+
+         Returns:
+             list or None: The forward output list or None if an error occurs.
+         """
+         if not api_info.check_forward_info():
+             logger.debug(
+                 f"[Device {self.current_device_id}] API: {api_name_str} lacks forward information, skipping forward check.")
+             return Const.EXCEPTION_NONE
+
+         try:
+             forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
+         except Exception as e:
+             logger.warning(
+                 f"[Device {self.current_device_id}] Exception occurred while getting forward API inputs for {api_name_str}. Skipping forward check. Detailed exception information: {e}.")
+             return Const.EXCEPTION_NONE
+
+         forward_output_list = None
+         try:
+             forward_output_list = self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation,
+                                                               Const.FORWARD)
+         except Exception as e:
+             logger.warning(
+                 f"[Device {self.current_device_id}] Exception occurred while running and comparing {api_name_str} forward API. Detailed exception information: {e}.")
+         return forward_output_list
+
+     def process_backward(self, api_name_str, api_info):
+         """
+         Overrides the parent class's process_backward method to log the device ID when exceptions occur.
+
+         Parameters:
+             api_name_str (str): The name of the API.
+             api_info (object): The API information object.
+
+         Returns:
+             list or None: The backward output list or None if an error occurs.
+         """
+         if not api_info.check_backward_info():
+             logger.debug(
+                 f"[Device {self.current_device_id}] API: {api_name_str} lacks backward information, skipping backward check.")
+             return Const.EXCEPTION_NONE
+
+         try:
+             backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
+         except Exception as e:
+             logger.warning(
+                 f"[Device {self.current_device_id}] Exception occurred while getting backward API inputs for {api_name_str}. Skipping backward check. Detailed exception information: {e}.")
+             return Const.EXCEPTION_NONE
+
+         backward_output_list = None
+         try:
+             backward_output_list = self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation,
+                                                                Const.BACKWARD)
+         except Exception as e:
+             logger.warning(
+                 f"[Device {self.current_device_id}] Exception occurred while running and comparing {api_name_str} backward API. Detailed exception information: {e}.")
+         return backward_output_list
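The new MultiApiAccuracyChecker splits the dumped APIs across several MindSpore devices by index modulo and funnels all results through a shared MultiDataManager. A minimal, hypothetical driver sketch (not part of the released package): the Namespace fields out_path, result_csv_path and device_id are the ones __init__ and run_and_compare read above; the concrete values and the way api_infos gets populated are assumptions.

```python
# Hypothetical usage sketch of the new multi-device checker; attribute names come
# from the diff above, the concrete values are illustrative only.
from argparse import Namespace

from msprobe.mindspore.api_accuracy_checker.multi_api_accuracy_checker import MultiApiAccuracyChecker

args = Namespace(
    out_path="./api_checker_output",  # directory receiving the detail/result CSVs
    result_csv_path="",               # existing result CSV to resume from, otherwise empty
    device_id=[0, 1],                 # APIs are interleaved across these devices (idx % len(device_id))
)

checker = MultiApiAccuracyChecker(args)
# In the real tool, checker.api_infos is presumably filled from the dumped API
# information (via the parent ApiAccuracyChecker) before the comparison starts.
checker.run_and_compare()
```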
msprobe/mindspore/api_accuracy_checker/multi_data_manager.py
@@ -0,0 +1,58 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import multiprocessing
+ import os
+
+ from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager, ResultCsvEntry, write_csv_header, get_result_csv_header, get_detail_csv_header, check_csv_header
+ from msprobe.mindspore.common.log import logger
+
+
+ class MultiDataManager(DataManager):
+     def __init__(self, csv_dir, result_csv_path, shared_is_first_write):
+         super().__init__(csv_dir, result_csv_path)
+
+         # Use the shared is_first_write variable to control header writing
+         self.shared_is_first_write = shared_is_first_write
+         # Create a lock object to keep the operation safe across processes
+         self.lock = multiprocessing.Lock()
+
+     def save_results(self, api_name_str):
+         """Save the results; a thread-safe operation."""
+
+         with self.lock:  # make sure saving is never performed by several processes at once
+             if self.is_first_write and self.shared_is_first_write.value:
+                 self.shared_is_first_write.value = False
+                 self.is_first_write = False  # mark as False after writing to avoid duplicating the header
+                 # Write the headers directly
+                 logger.info("Writing CSV headers for the first time.")
+                 write_csv_header(self.detail_out_path, get_detail_csv_header)
+                 write_csv_header(self.result_out_path, get_result_csv_header)
+
+             """Write the detailed output and the result summary, then clear the results."""
+             self.to_detail_csv(self.detail_out_path)
+             logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.")
+
+             self.to_result_csv(self.result_out_path)
+             logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.")
+
+             # Clear the records to get ready for the next call
+             self.clear_results()
+
+     def clear_results(self):
+         """Clear the self.results data; a thread-safe operation."""
+         logger.debug("Clearing results data.")
+         self.results.clear()
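MultiDataManager coordinates several worker processes writing to the same CSV files: a Manager().Value('b', True) flag decides who writes the header and a multiprocessing.Lock serializes the writes. A standalone sketch of that synchronization pattern (plain csv instead of the msprobe writers; file name and row layout are illustrative only):

```python
# Standalone illustration (not from the package) of the shared-flag-plus-lock
# pattern MultiDataManager relies on: only the first writer emits the CSV header.
import csv
import multiprocessing
from multiprocessing import Manager


def save_rows(path, rows, is_first_write, lock):
    with lock:  # serialize file access across workers
        with open(path, "a", newline="") as f:
            writer = csv.writer(f)
            if is_first_write.value:  # only the first writer emits the header
                writer.writerow(["api_name", "status"])
                is_first_write.value = False
            writer.writerows(rows)


if __name__ == "__main__":
    manager = Manager()
    first = manager.Value('b', True)
    lock = multiprocessing.Lock()
    workers = [
        multiprocessing.Process(target=save_rows,
                                args=("result.csv", [[f"api_{i}", "pass"]], first, lock))
        for i in range(4)
    ]
    for p in workers:
        p.start()
    for p in workers:
        p.join()
```

Because both the flag check and the header write happen under the lock, exactly one worker writes the header regardless of start order.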
msprobe/mindspore/api_accuracy_checker/type_mapping.py
@@ -1,7 +1,23 @@
- from mindspore.common import dtype as mstype
- import numpy as np
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import mindspore
+ import numpy as np
  import torch
+ from mindspore._c_expression import typing
+ from mindspore.common import dtype as mstype

  INT8 = "Int8"
  UINT8 = "UInt8"
@@ -18,7 +34,6 @@ BOOL = "Bool"
  BFLOAT16 = "BFloat16"
  INT4 = "Int4"

-
  dtype_str_to_ms_dtype = {
      INT8: mstype.int8,
      UINT8: mstype.uint8,
@@ -37,7 +52,6 @@ dtype_str_to_ms_dtype = {
  }
  ms_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_ms_dtype.items()}

-
  dtype_str_to_np_dtype = {
      INT8: np.int8,
      UINT8: np.uint8,
@@ -75,6 +89,8 @@ FLOAT_TYPE_STR = "float"
  SLICE_TYPE_STR = "slice"
  TUPLE_TYPE_STR = "tuple"
  STR_TYPE_STR = "str"
+ MINDSPORE_DTYPE_TYPE_STR = "mindspore.dtype"
+ TORCH_DTYPE_TYPE_STR = "torch.dtype"

  api_info_type_str_to_type = {
      MINDSPORE_TENSOR_TYPE_STR: mindspore.Tensor,
@@ -83,6 +99,7 @@ api_info_type_str_to_type = {
      FLOAT_TYPE_STR: float,
      SLICE_TYPE_STR: slice,
      STR_TYPE_STR: str,
+     MINDSPORE_DTYPE_TYPE_STR: typing.Type,
  }
  type_to_api_info_type_str = {value: key for key, value in api_info_type_str_to_type.items()}

@@ -111,4 +128,4 @@ uint_dtype_str_list = [
      UINT16,
      UINT32,
      UINT64,
- ]
+ ]
msprobe/mindspore/api_accuracy_checker/utils.py
@@ -1,8 +1,24 @@
- from msprobe.core.common.exceptions import ApiAccuracyCheckerException
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from msprobe.core.common.const import Const
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
  from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list
  from msprobe.mindspore.common.log import logger

+
  def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_type=None, accepted_value=None):
      '''
      Args:
@@ -22,30 +38,30 @@ def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_t
          3. value is not accepted type
          4. value is not accepted value
      '''
-     parse_failed_exception = ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed)
      if not isinstance(dict_instance, dict):
-         logger.error_log_with_exp("check_and_get_from_json_dict failed: input is not a dict", parse_failed_exception)
+         error_info = "check_and_get_from_json_dict failed: input is not a dict"
+         raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
      value = dict_instance.get(key)
      if value is None:
-         logger.error_log_with_exp(f"check_and_get_from_json_dict failed: {key_description} is missing",
-                                   parse_failed_exception)
+         error_info = f"check_and_get_from_json_dict failed: {key_description} is missing"
+         raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
      elif accepted_type is not None and not isinstance(value, accepted_type):
-         logger.error_log_with_exp(
-             f"check_and_get_from_json_dict failed: {key_description} is not accepted type: {accepted_type}",
-             parse_failed_exception)
+         error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted type: {accepted_type}"
+         raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
      elif accepted_value is not None and value not in accepted_value:
-         logger.error_log_with_exp(
-             f"check_and_get_from_json_dict failed: {key_description} is not accepted value: {accepted_value}",
-             parse_failed_exception)
+         error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted value: {accepted_value}"
+         raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
      return value

- def convert_to_tuple(input):
-     if isinstance(input, (tuple, list)):
-         return tuple(input)
+
+ def convert_to_tuple(args):
+     if isinstance(args, (tuple, list)):
+         return tuple(args)
      else:
-         input_list = [input]
+         input_list = [args]
          return tuple(input_list)

+
  def trim_output_compute_element_list(compute_element_list, forward_or_backward):
      '''
      Args:
@@ -55,12 +71,13 @@ def trim_output_compute_element_list(compute_element_list, forward_or_backward):
      trimmed_list = []
      for compute_element in compute_element_list:
          if compute_element.get_parameter() is None or \
-             (forward_or_backward == Const.BACKWARD and compute_element.get_dtype() not in float_dtype_str_list):
+                 (forward_or_backward == Const.BACKWARD and compute_element.get_dtype() not in float_dtype_str_list):
              # trim case: 1. parameter is None. 2. backward output has non float parameter
              continue
          trimmed_list.append(compute_element)
      return trimmed_list

+
  class GlobalContext:
      def __init__(self):
          self.is_constructed = True
@@ -77,4 +94,4 @@ class GlobalContext:
          return self.is_constructed


- global_context = GlobalContext()
+ global_context = GlobalContext()
msprobe/mindspore/cell_processor.py
@@ -1,4 +1,19 @@
- from msprobe.core.data_dump.scope import ModuleRangeScope
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope
  from msprobe.core.common.const import Const


@@ -9,10 +24,7 @@ class CellProcessor:
      module_node = {}

      def __init__(self, scope):
-         if isinstance(scope, ModuleRangeScope):
-             self.scope = scope
-         else:
-             self.scope = None
+         self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None

      @staticmethod
      def set_cell_count(cell_name):
@@ -21,30 +33,29 @@ class CellProcessor:
          else:
              CellProcessor.cell_count[cell_name] += 1
          return CellProcessor.cell_count[cell_name]
-
+
      @classmethod
      def reset_cell_stats(cls):
          cls.cell_count = {}
          cls.cell_stack = []
          cls.api_parent_node = ""
          cls.module_node = {}
-
+
      def node_hook(self, name_prefix, start_or_stop, **kwargs):
-         def begin_hook(cell, input):
-             index = self.set_cell_count(name_prefix)
-             cell.mindstudio_reserved_name = full_name = name_prefix + Const.SEP + str(index)
+         def begin_hook(cell, input_data):
+             full_name = self.set_and_get_reserved_name(cell, name_prefix, is_called_by_pre_hook=True)
              if CellProcessor.cell_stack:
                  CellProcessor.module_node[full_name] = CellProcessor.cell_stack[-1]
              else:
                  CellProcessor.module_node[full_name] = None
-
+
              CellProcessor.cell_stack.append(full_name)
              CellProcessor.api_parent_node = full_name

              if self.scope:
                  self.scope.begin_module(full_name)

-         def end_hook(cell, input, output):
+         def end_hook(cell, input_data, output_data):
              if CellProcessor.cell_stack:
                  CellProcessor.cell_stack.pop()
              if CellProcessor.cell_stack:
@@ -56,3 +67,13 @@ class CellProcessor:
                  self.scope.end_module(cell.mindstudio_reserved_name)

          return begin_hook if Const.START == start_or_stop else end_hook
+
+     def set_and_get_reserved_name(self, cell, cell_name, is_called_by_pre_hook=False):
+         if not is_called_by_pre_hook and hasattr(cell, 'has_pre_hook_called') and cell.has_pre_hook_called:
+             cell.has_pre_hook_called = False
+         else:
+             if is_called_by_pre_hook:
+                 cell.has_pre_hook_called = True
+             index = self.set_cell_count(cell_name)
+             cell.mindstudio_reserved_name = cell_name + Const.SEP + str(index)
+         return cell.mindstudio_reserved_name
msprobe/mindspore/common/const.py
@@ -1,3 +1,18 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import numpy as np
  import mindspore as ms

@@ -23,17 +38,10 @@ class Const:
      ASCEND_910A = "ascend910"

      OPS_PREFIX = "mindspore.ops."
-     Tensor_PREFIX = "mindspore.Tensor."
+     TENSOR_PREFIX = "mindspore.Tensor."
      MINT_PREFIX = "mindspore.mint."
      MINT_NN_FUNC_PREFIX = "mindspore.mint.nn.functional."
-     COMM_PREFIX = "mindspore.communication.comm_func."
-     COMMUNICATION_API_LIST = [
-         "mindspore.communication.comm_func.all_gather_into_tensor",
-         "mindspore.communication.comm_func.gather_into_tensor",
-         "mindspore.communication.comm_func.all_reduce",
-         "mindspore.communication.comm_func.reduce",
-         "mindspore.communication.comm_func.reduce_scatter_tensor"
-     ]
+
      TENSOR_DATA_PREFIX = "Tensor."
      STUB_TENSOR_DATA_PREFIX = "Tensor."
      OPS_DATA_PREFIX = "Functional."
@@ -50,6 +58,15 @@

      DROPOUT_API_NAME_PREFIX = "dropout"

+     GRAPH_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.INPUT, CoreConst.OUTPUT]
+
+     HOOK_MS_PREFIX_DICT = {
+         OPS_DATA_PREFIX: OPS_PREFIX,
+         TENSOR_DATA_PREFIX: TENSOR_PREFIX,
+         MINT_DATA_PREFIX: MINT_PREFIX,
+         MINT_NN_FUNC_DATA_PREFIX: MINT_NN_FUNC_PREFIX
+     }
+

  class FreeBenchmarkConst:
      ADD_NOISE = "add_noise"
@@ -65,19 +82,21 @@ class FreeBenchmarkConst:
      DEFAULT_PERT_TYPE = IMPROVE_PRECISION
      DEFAULT_HANDLER_TYPE = CHECK
      DEVICE_LIST = [DEFAULT_DEVICE]
-     STAGE_LIST = [CoreConst.FORWARD]
+     STAGE_LIST = [CoreConst.FORWARD, CoreConst.BACKWARD]
      DUMP_LEVEL_LIST = [DEFAULT_DUMP_LEVEL]
      PERT_TYPE_LIST = [IMPROVE_PRECISION, ADD_NOISE, BIT_NOISE, NO_CHANGE, EXCHANGE_VALUE]
      HANDLER_TYPE_LIST = [CHECK, FIX]
      NO_CHANGE_ERROR_THRESHOLD = 1.0
      SYMBOL_FLIPPING_RATIO = 8.0

+     SUPPORTED_CHECK_API_FILE = "support_wrap_ops.yaml"
+     CHECK_RESULT_FILE = "free_benchmark.csv"
+
      API_PREFIX_DICT = {
          "ops": Const.OPS_PREFIX,
-         "Tensor": Const.Tensor_PREFIX,
+         "Tensor": Const.TENSOR_PREFIX,
          "mint": Const.MINT_PREFIX,
-         "mint.nn.functional": Const.MINT_NN_FUNC_PREFIX,
-         "communication": Const.COMM_PREFIX
+         "mint.nn.functional": Const.MINT_NN_FUNC_PREFIX
      }

      PERT_VALUE_DICT = {
@@ -88,6 +107,7 @@
      }

      ERROR_THRESHOLD = {
+         ms.bfloat16: 1.004,
          ms.float16: 1.002,
          ms.float32: 1.0002
      }
msprobe/mindspore/common/log.py
@@ -1,4 +1,5 @@
- # Copyright 2024 Huawei Technologies Co., Ltd
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -11,15 +12,10 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- # ============================================================================

- import os
- import time
- import sys
-
- from msprobe.mindspore.common.utils import get_rank_if_initialized
- from msprobe.core.common.log import BaseLogger
  from msprobe.core.common.exceptions import DistributedNotInitializedError
+ from msprobe.core.common.log import BaseLogger
+ from msprobe.mindspore.common.utils import get_rank_if_initialized


  class MindsporeLogger(BaseLogger):
@@ -35,4 +31,4 @@ class MindsporeLogger(BaseLogger):
          return current_rank


- logger = MindsporeLogger()
+ logger = MindsporeLogger()