mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -0,0 +1,140 @@
1
+ /**
2
+ * Copyright 2024 Huawei Technologies Co., Ltd
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include "hook_dynamic_loader.h"
18
+ #include <sys/stat.h>
19
+ #include <cstdlib>
20
+ #include <cstring>
21
+ #include "utils/log_adapter.h"
22
+
23
+ namespace {
24
+
25
+ // Utility function to check if a file path is valid
26
+ bool IsValidPath(const std::string &path) {
27
+ struct stat fileStat;
28
+ if (stat(path.c_str(), &fileStat) != 0) {
29
+ MS_LOG(ERROR) << "File does not exist or cannot be accessed: " << path;
30
+ return false;
31
+ }
32
+
33
+ if (S_ISLNK(fileStat.st_mode)) {
34
+ MS_LOG(ERROR) << "File is a symbolic link, which is not allowed: " << path;
35
+ return false;
36
+ }
37
+
38
+ if (!S_ISREG(fileStat.st_mode)) {
39
+ MS_LOG(ERROR) << "File is not a regular file: " << path;
40
+ return false;
41
+ }
42
+
43
+ if (path.substr(path.find_last_of(".")) != ".so") {
44
+ MS_LOG(ERROR) << "File is not a .so file: " << path;
45
+ return false;
46
+ }
47
+
48
+ return true;
49
+ }
50
+
51
+ } // namespace
52
+
53
+ HookDynamicLoader &HookDynamicLoader::GetInstance() {
54
+ static HookDynamicLoader instance;
55
+ return instance;
56
+ }
57
+
58
+ bool HookDynamicLoader::loadFunction(void *handle, const std::string &functionName) {
59
+ void *func = dlsym(handle, functionName.c_str());
60
+ if (!func) {
61
+ MS_LOG(WARNING) << "Could not load function: " << functionName << ", error: " << dlerror();
62
+ return false;
63
+ }
64
+ funcMap_[functionName] = func;
65
+ return true;
66
+ }
67
+
68
+ bool HookDynamicLoader::validateLibraryPath(const std::string &libPath) {
69
+ char *realPath = realpath(libPath.c_str(), nullptr);
70
+ if (!realPath) {
71
+ MS_LOG(WARNING) << "Failed to resolve realpath for the library: " << libPath;
72
+ return false;
73
+ }
74
+
75
+ bool isValid = IsValidPath(realPath);
76
+ free(realPath); // Free memory allocated by realpath
77
+ return isValid;
78
+ }
79
+
80
+ bool HookDynamicLoader::LoadLibrary() {
81
+ const char *libPath = std::getenv("HOOK_TOOL_PATH");
82
+ if (!libPath) {
83
+ MS_LOG(WARNING) << "HOOK_TOOL_PATH is not set!";
84
+ return false;
85
+ }
86
+
87
+ std::string resolvedLibPath(libPath);
88
+ if (!validateLibraryPath(resolvedLibPath)) {
89
+ MS_LOG(WARNING) << "Library path validation failed.";
90
+ return false;
91
+ }
92
+
93
+ std::lock_guard<std::mutex> lock(mutex_);
94
+ if (handle_) {
95
+ MS_LOG(WARNING) << "Hook library already loaded!";
96
+ return false;
97
+ }
98
+
99
+ handle_ = dlopen(resolvedLibPath.c_str(), RTLD_LAZY | RTLD_LOCAL);
100
+ if (!handle_) {
101
+ MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror();
102
+ return false;
103
+ }
104
+
105
+ for (const auto &functionName : functionList_) {
106
+ if (!loadFunction(handle_, functionName)) {
107
+ MS_LOG(WARNING) << "Failed to load function: " << functionName;
108
+ dlclose(handle_);
109
+ handle_ = nullptr;
110
+ return false;
111
+ }
112
+ }
113
+
114
+ MS_LOG(INFO) << "Hook library loaded successfully.";
115
+ return true;
116
+ }
117
+
118
+ bool HookDynamicLoader::UnloadLibrary() {
119
+ std::lock_guard<std::mutex> lock(mutex_);
120
+ if (!handle_) {
121
+ MS_LOG(WARNING) << "Hook library hasn't been loaded.";
122
+ return false;
123
+ }
124
+
125
+ dlclose(handle_);
126
+ handle_ = nullptr;
127
+ funcMap_.clear();
128
+ MS_LOG(INFO) << "Library unloaded successfully.";
129
+ return true;
130
+ }
131
+
132
+ void *HookDynamicLoader::GetHooker(const std::string &funcName) {
133
+ std::lock_guard<std::mutex> lock(mutex_);
134
+ auto iter = funcMap_.find(funcName);
135
+ if (iter == funcMap_.end()) {
136
+ MS_LOG(WARNING) << "Function not found: " << funcName;
137
+ return nullptr;
138
+ }
139
+ return iter->second;
140
+ }
@@ -0,0 +1,53 @@
1
+ /**
2
+ * Copyright 2024 Huawei Technologies Co., Ltd
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef HOOK_DYNAMIC_LOADER_H
18
+ #define HOOK_DYNAMIC_LOADER_H
19
+
20
+ #include <dlfcn.h>
21
+ #include <string>
22
+ #include <vector>
23
+ #include <map>
24
+ #include <mutex>
25
+
26
// Symbol names that the library named by HOOK_TOOL_PATH must export.
constexpr auto kHookBegin = "MS_DbgOnStepBegin";
constexpr auto kHookEnd = "MS_DbgOnStepEnd";

/**
 * Process-wide loader for the hook library named by the HOOK_TOOL_PATH environment variable.
 *
 * After a successful LoadLibrary(), the addresses of the symbols in functionList_ can be
 * fetched with GetHooker(). LoadLibrary(), UnloadLibrary() and GetHooker() take mutex_.
 */
class HookDynamicLoader {
 public:
  // Returns the single shared instance.
  static HookDynamicLoader &GetInstance();

  // Non-copyable; instances are only obtained through GetInstance().
  HookDynamicLoader(const HookDynamicLoader &) = delete;
  HookDynamicLoader &operator=(const HookDynamicLoader &) = delete;

  // Opens the library and resolves every hook symbol; returns false on any failure.
  bool LoadLibrary();
  // Closes the library and forgets resolved symbols; returns false if nothing was loaded.
  bool UnloadLibrary();
  // Returns the resolved address of funcName, or nullptr if it is not known.
  void *GetHooker(const std::string &funcName);

 private:
  HookDynamicLoader() = default;

  // Resolves functionName in handle and records it in funcMap_.
  bool loadFunction(void *handle, const std::string &functionName);
  // Checks that libPath canonicalises to an existing, non-symlink, regular ".so" file.
  bool validateLibraryPath(const std::string &libPath);

  void *handle_ = nullptr;                                          // dlopen handle; nullptr when unloaded
  std::vector<std::string> functionList_ = {kHookBegin, kHookEnd};  // symbols required from the library
  std::map<std::string, void *> funcMap_;                           // symbol name -> resolved address
  std::mutex mutex_;                                                // serialises load/unload/lookup
};
52
+
53
+ #endif // HOOK_DYNAMIC_LOADER_H
@@ -1,21 +1,43 @@
1
- import os
2
- import inspect
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
3
17
  import importlib
18
+ import os
19
+ import traceback
4
20
 
5
21
  import mindspore as ms
6
- from mindspore.communication import comm_func
7
22
 
8
- from msprobe.core.common.file_utils import load_yaml, check_path_length
9
23
  from msprobe.core.common.const import Const
24
+ from msprobe.core.common.exceptions import DistributedNotInitializedError
25
+ from msprobe.core.common.file_utils import check_path_length, load_yaml
10
26
  from msprobe.mindspore.common.const import Const as MsConst
11
27
  from msprobe.mindspore.common.const import FreeBenchmarkConst
12
- from msprobe.mindspore.free_benchmark.common.config import Config
13
28
  from msprobe.mindspore.common.log import logger
29
+ from msprobe.mindspore.common.utils import get_rank_if_initialized
14
30
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
15
- from msprobe.mindspore.free_benchmark.decorator.decorator_factory import decorate_forward_function
31
+ from msprobe.mindspore.dump.hook_cell.api_registry import api_register
32
+ from msprobe.mindspore.free_benchmark.common.config import Config
33
+ from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
34
+ from msprobe.mindspore.free_benchmark.common.utils import Tools
35
+ from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
36
+ from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
37
+ from msprobe.mindspore.runtime import Runtime
16
38
 
17
39
 
18
- class ApiPyNativeSelFCheck:
40
+ class ApiPyNativeSelfCheck:
19
41
  def __init__(self, config: DebuggerConfig):
20
42
  Config.is_enable = True
21
43
  Config.handler_type = config.handler_type
@@ -24,29 +46,68 @@ class ApiPyNativeSelFCheck:
24
46
  Config.dump_level = config.dump_level
25
47
  Config.steps = config.step
26
48
  Config.ranks = config.rank
27
- Config.dump_path = os.path.join(config.dump_path, "free_benchmark.csv")
49
+ Config.dump_path = os.path.join(config.dump_path, FreeBenchmarkConst.CHECK_RESULT_FILE)
28
50
  check_path_length(Config.dump_path)
29
51
 
52
+ self.ori_func = {}
53
+
30
54
  self.api_list = config.list
31
55
  all_api = get_supported_ops()
32
56
  if not self.api_list:
33
57
  self.api_list = all_api
34
58
  else:
35
59
  self.api_list = set(self.api_list) & all_api
60
+ self.store_original_func()
36
61
 
37
62
  def handle(self):
63
+ api_register.initialize_hook(self.build_hook)
64
+ api_register.api_set_hook_func()
65
+
66
+ def build_hook(self, api_name_with_id):
67
+ def forward_hook(api_name_with_id, cell, input_data, output_data):
68
+ ret = None
69
+
70
+ if not need_wrapper_func():
71
+ del cell.input_kwargs
72
+ return ret
73
+
74
+ api_name_with_id = api_name_with_id[:-1]
75
+ hook_prefix = api_name_with_id[:api_name_with_id.find(Const.SEP) + 1]
76
+ api_name = (MsConst.HOOK_MS_PREFIX_DICT.get(hook_prefix, "") +
77
+ api_name_with_id[api_name_with_id.find(Const.SEP) + 1:api_name_with_id.rfind(Const.SEP)])
78
+ if api_name in self.api_list:
79
+ ret = check_self(api_name_with_id, output_data, self.ori_func.get(api_name),
80
+ *input_data, **cell.input_kwargs)
81
+
82
+ del cell.input_kwargs
83
+ return ret
84
+
85
+ def backward_hook(cell, grad_input, grad_output):
86
+ pass
87
+
88
+ forward_hook = functools.partial(forward_hook, api_name_with_id)
89
+
90
+ def wrap_forward_hook(cell, input_data, output_data):
91
+ return forward_hook(cell, input_data, output_data)
92
+
93
+ def wrap_backward_hook(cell, grad_input, grad_output):
94
+ return backward_hook(cell, grad_input, grad_output)
95
+
96
+ return wrap_forward_hook, wrap_backward_hook
97
+
98
+ def store_original_func(self):
38
99
  for api_name in self.api_list:
39
- hijack(api_name)
100
+ self.ori_func[api_name] = get_module(api_name)[1]
40
101
 
41
102
 
42
103
  def get_supported_ops():
43
104
  supported_ops = []
44
105
  cur_path = os.path.dirname(os.path.realpath(__file__))
45
- yaml_path = os.path.join(cur_path, "data", "support_wrap_ops.yaml")
106
+ yaml_path = os.path.join(cur_path, "data", FreeBenchmarkConst.SUPPORTED_CHECK_API_FILE)
46
107
 
47
- yaml_data = load_yaml(yaml_path)
108
+ supported_ops_list = load_yaml(yaml_path)
48
109
  for k, v in FreeBenchmarkConst.API_PREFIX_DICT.items():
49
- ops = yaml_data.get(k)
110
+ ops = supported_ops_list.get(k)
50
111
  if ops:
51
112
  ops = [v + i for i in ops]
52
113
  supported_ops += ops
@@ -57,7 +118,7 @@ def get_supported_ops():
57
118
  _all_functional_ops += ms_ops
58
119
 
59
120
  ms_tensor = dir(ms.Tensor)
60
- ms_tensor = [MsConst.Tensor_PREFIX + i for i in ms_tensor]
121
+ ms_tensor = [MsConst.TENSOR_PREFIX + i for i in ms_tensor]
61
122
  _all_functional_ops += ms_tensor
62
123
 
63
124
  ms_mint = dir(ms.mint)
@@ -68,29 +129,9 @@ def get_supported_ops():
68
129
  ms_mint_nn_func = [MsConst.MINT_NN_FUNC_PREFIX + i for i in ms_mint_nn_func]
69
130
  _all_functional_ops += ms_mint_nn_func
70
131
 
71
- ms_communication = dir(comm_func)
72
- ms_communication = [MsConst.COMM_PREFIX + i for i in ms_communication]
73
- _all_functional_ops += ms_communication
74
-
75
132
  return set(supported_ops) & set(_all_functional_ops)
76
133
 
77
134
 
78
- def get_decorate_func():
79
- return decorate_forward_function
80
-
81
-
82
- def is_func_support_decorate(orig_func):
83
- return not inspect.isclass(orig_func) and callable(orig_func)
84
-
85
-
86
- def get_wrapper_obj(orig_func, api_name):
87
- if is_func_support_decorate(orig_func):
88
- wrapped_obj = get_decorate_func()(orig_func, api_name)
89
- else:
90
- wrapped_obj = orig_func
91
- return wrapped_obj
92
-
93
-
94
135
  def get_module(api_name):
95
136
  func_name_list = api_name.split(Const.SEP)
96
137
  func_name = func_name_list[-1]
@@ -104,13 +145,93 @@ def get_module(api_name):
104
145
  return module_obj, orig_func
105
146
 
106
147
 
107
- def hijack(api_name):
108
- if not api_name.strip():
109
- return
148
def check_self(api_name_with_id, output, ori_func, *args, **kwargs):
    """Run the free-benchmark self-check for one API call.

    Args:
        api_name_with_id: API name (with call id) used for logging and to
            select the perturbation / handler via their factories.
        output: The result already produced by the real call; used as the
            original result in the forward stage.
        ori_func: The original (unwrapped) callable.
        *args, **kwargs: The arguments the API was called with.

    Returns:
        The value produced by ``deal_fuzzed_and_original_result``, or None
        when the call is skipped (backward stage with non-tensor
        inputs/outputs, no eligible input, perturbation returned False) or
        an error occurred.
    """
    ret = None

    # Backward checking goes through vjp, which only makes sense for tensors.
    if Config.stage == Const.BACKWARD and not (check_all_tensor(args) and check_all_tensor(output)):
        logger.warning(f"{api_name_with_id} has non-tensor input or output.")
        return ret

    params = data_pre_deal(api_name_with_id, ori_func, *args, **kwargs)
    if params.index == -1:
        return ret

    logger.info(f"[{api_name_with_id}] is {Config.handler_type}ing.")
    # NOTE(review): presumably swaps the hooks out so the perturbation's own
    # calls are not intercepted again; they are re-installed in `finally`.
    api_register.api_set_ori_func()

    try:
        perturbation = PerturbationFactory.create(api_name_with_id)
        params.fuzzed_result = perturbation.handle(params)
        if params.fuzzed_result is False:
            return ret
        if Config.stage == Const.BACKWARD:
            params.original_result = Tools.get_grad(params.original_func, *params.args, **params.kwargs)
        else:
            params.original_result = output
        ret = deal_fuzzed_and_original_result(api_name_with_id, params)
    except Exception as e:
        # Best-effort checker: report and fall through to the None result.
        logger.error(f"[{api_name_with_id}] Error: {str(e)}")
        logger.error(f"[{api_name_with_id}] Error detail: {traceback.format_exc()}")
    finally:
        # Re-install hooks exactly once on every exit path (normal return,
        # early return, handled error, or propagating BaseException).
        api_register.api_set_hook_func()
    return ret
179
+
180
+
181
def check_all_tensor(input_output):
    """Return True if ``input_output`` is a Tensor or a (nested) tuple/list of Tensors.

    Any other leaf value makes the result False. An empty tuple/list yields
    True, because ``all`` of an empty iterable is True.
    """
    if isinstance(input_output, ms.Tensor):
        return True
    if isinstance(input_output, (tuple, list)):
        # Generator (not a list) lets all() stop at the first non-tensor item.
        return all(check_all_tensor(item) for item in input_output)
    return False
187
+
188
+
189
def get_target_arg_index(args) -> int:
    """Locate the first positional argument eligible for perturbation (type check).

    An argument is eligible when it is a floating-point tensor, or when it is
    a list, tuple or dict. Non-floating tensors and all other values are
    skipped.

    Returns:
        Zero-based index of the first eligible argument, or -1 if none exists.
    """
    for position, candidate in enumerate(args):
        if ms.ops.is_tensor(candidate):
            if ms.ops.is_floating_point(candidate):
                return position
        elif isinstance(candidate, (list, tuple, dict)):
            return position
    return -1
202
+
203
+
204
def data_pre_deal(api_name_with_id, func, *args, **kwargs):
    """Bundle one API call into a ``HandlerParams`` instance.

    Stores the call arguments and the original callable, and records the
    index of the first argument eligible for perturbation (-1 when there is
    none, in which case a warning is logged).
    """
    handler_params = HandlerParams()
    handler_params.original_func = func
    handler_params.args = args
    handler_params.kwargs = kwargs

    target_index = get_target_arg_index(args)
    if target_index == -1:
        logger.warning(f"{api_name_with_id} has no supported input type.")
    handler_params.index = target_index
    return handler_params
214
+
215
+
216
def need_wrapper_func():
    """Decide whether the current API call should go through the self-check.

    Returns False when any of these holds:
      * ``Runtime.is_running`` or ``Config.is_enable`` is falsy;
      * ``Config.steps`` is non-empty and ``Runtime.step_count`` is not in it;
      * ``Config.ranks`` is non-empty, the rank is known (not -1), and it is
        not in ``Config.ranks``.

    ``Runtime.rank_id`` is resolved lazily while it is -1. It stays -1 if
    ``get_rank_if_initialized`` raises ``DistributedNotInitializedError``,
    and then no rank filtering is applied.
    """
    if not Runtime.is_running or not Config.is_enable:
        return False

    if Config.steps and Runtime.step_count not in Config.steps:
        return False

    if Runtime.rank_id == -1:
        try:
            Runtime.rank_id = get_rank_if_initialized()
        except DistributedNotInitializedError:
            Runtime.rank_id = -1

    # No rank filter configured, or rank still unknown: accept the call.
    if not Config.ranks or Runtime.rank_id == -1:
        return True
    return Runtime.rank_id in Config.ranks
232
+
233
+
234
def deal_fuzzed_and_original_result(api_name_with_id, params: HandlerParams):
    """Pass ``params`` to the handler selected for ``api_name_with_id``.

    The handler comes from ``HandlerFactory.create``. Its ``handle`` result
    is returned unchanged.
    """
    return HandlerFactory.create(api_name_with_id).handle(params)
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.mindspore.common.const import FreeBenchmarkConst
2
17
 
3
18
 
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Optional, Any, Tuple, Dict, Callable
2
17
 
3
18
 
@@ -12,6 +27,5 @@ class HandlerParams:
12
27
  original_result: Optional[Any] = None
13
28
  fuzzed_result: Optional[Any] = None
14
29
  is_consistent: Optional[bool] = True
15
- save_flag: Optional[bool] = True
16
30
  fuzzed_value: Optional[Any] = None
17
31
  original_func: Optional[Callable] = None
@@ -1,14 +1,28 @@
1
- from typing import Any
2
- from typing import Optional
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
3
16
  from dataclasses import dataclass
17
+ from typing import Any, Optional
4
18
 
5
19
  import mindspore as ms
6
- from mindspore import Tensor
20
+ from mindspore import Tensor, ops
7
21
 
8
- from msprobe.mindspore.runtime import Runtime
9
22
  from msprobe.mindspore.common.const import FreeBenchmarkConst
10
- from .config import Config
11
- from .handler_params import HandlerParams
23
+ from msprobe.mindspore.free_benchmark.common.config import Config
24
+ from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
25
+ from msprobe.mindspore.runtime import Runtime
12
26
 
13
27
 
14
28
  class Tools:
@@ -29,6 +43,23 @@ class Tools:
29
43
  return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD
30
44
  return FreeBenchmarkConst.ERROR_THRESHOLD.get(dtype, FreeBenchmarkConst.ERROR_THRESHOLD.get(ms.float32))
31
45
 
46
@staticmethod
def get_grad_out(outputs):
    """Build an all-ones cotangent that mirrors the structure of ``outputs``.

    Each Tensor is replaced by ``ops.ones_like`` of itself. Tuples and lists
    are rebuilt recursively with the same container type. Any other value is
    returned unchanged.
    """
    if isinstance(outputs, Tensor):
        return ops.ones_like(outputs)
    if not isinstance(outputs, (tuple, list)):
        return outputs
    container_type = type(outputs)
    converted = [Tools.get_grad_out(item) for item in outputs]
    return container_type(converted)
53
+
54
@staticmethod
def get_grad(func, *args, **kwargs):
    """Return the result of ``ms.vjp``'s pullback for ``func`` evaluated at ``args``.

    ``kwargs`` are bound inside a closure, so only the positional ``args``
    are passed to ``ms.vjp``. The pullback is applied to an all-ones
    cotangent shaped like the forward output (see ``Tools.get_grad_out``).
    """
    def bound_func(*inputs):
        return func(*inputs, **kwargs)

    forward_out, pullback = ms.vjp(bound_func, *args)
    return pullback(Tools.get_grad_out(forward_out))
62
+
32
63
 
33
64
  @dataclass
34
65
  class UnequalRow:
@@ -59,10 +90,8 @@ def make_unequal_row(
59
90
  if isinstance(ratio, float):
60
91
  row.max_rel = ratio - 1
61
92
  original_tensor = params.original_result
62
- fuzzed_tensor = params.fuzzed_result
63
93
  if index is not None:
64
94
  original_tensor = original_tensor[index]
65
- fuzzed_tensor = fuzzed_tensor[index]
66
95
  row.output_index = index
67
96
  if isinstance(original_tensor, Tensor):
68
97
  row.dtype = original_tensor.dtype