mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -0,0 +1,140 @@
1
+ /**
2
+ * Copyright 2024 Huawei Technologies Co., Ltd
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include "hook_dynamic_loader.h"
18
+ #include <sys/stat.h>
19
+ #include <cstdlib>
20
+ #include <cstring>
21
+ #include "utils/log_adapter.h"
22
+
23
+ namespace {
24
+
25
+ // Utility function to check if a file path is valid
26
+ bool IsValidPath(const std::string &path) {
27
+ struct stat fileStat;
28
+ if (stat(path.c_str(), &fileStat) != 0) {
29
+ MS_LOG(ERROR) << "File does not exist or cannot be accessed: " << path;
30
+ return false;
31
+ }
32
+
33
+ if (S_ISLNK(fileStat.st_mode)) {
34
+ MS_LOG(ERROR) << "File is a symbolic link, which is not allowed: " << path;
35
+ return false;
36
+ }
37
+
38
+ if (!S_ISREG(fileStat.st_mode)) {
39
+ MS_LOG(ERROR) << "File is not a regular file: " << path;
40
+ return false;
41
+ }
42
+
43
+ if (path.substr(path.find_last_of(".")) != ".so") {
44
+ MS_LOG(ERROR) << "File is not a .so file: " << path;
45
+ return false;
46
+ }
47
+
48
+ return true;
49
+ }
50
+
51
+ } // namespace
52
+
53
+ HookDynamicLoader &HookDynamicLoader::GetInstance() {
54
+ static HookDynamicLoader instance;
55
+ return instance;
56
+ }
57
+
58
+ bool HookDynamicLoader::loadFunction(void *handle, const std::string &functionName) {
59
+ void *func = dlsym(handle, functionName.c_str());
60
+ if (!func) {
61
+ MS_LOG(WARNING) << "Could not load function: " << functionName << ", error: " << dlerror();
62
+ return false;
63
+ }
64
+ funcMap_[functionName] = func;
65
+ return true;
66
+ }
67
+
68
+ bool HookDynamicLoader::validateLibraryPath(const std::string &libPath) {
69
+ char *realPath = realpath(libPath.c_str(), nullptr);
70
+ if (!realPath) {
71
+ MS_LOG(WARNING) << "Failed to resolve realpath for the library: " << libPath;
72
+ return false;
73
+ }
74
+
75
+ bool isValid = IsValidPath(realPath);
76
+ free(realPath); // Free memory allocated by realpath
77
+ return isValid;
78
+ }
79
+
80
+ bool HookDynamicLoader::LoadLibrary() {
81
+ const char *libPath = std::getenv("HOOK_TOOL_PATH");
82
+ if (!libPath) {
83
+ MS_LOG(WARNING) << "HOOK_TOOL_PATH is not set!";
84
+ return false;
85
+ }
86
+
87
+ std::string resolvedLibPath(libPath);
88
+ if (!validateLibraryPath(resolvedLibPath)) {
89
+ MS_LOG(WARNING) << "Library path validation failed.";
90
+ return false;
91
+ }
92
+
93
+ std::lock_guard<std::mutex> lock(mutex_);
94
+ if (handle_) {
95
+ MS_LOG(WARNING) << "Hook library already loaded!";
96
+ return false;
97
+ }
98
+
99
+ handle_ = dlopen(resolvedLibPath.c_str(), RTLD_LAZY | RTLD_LOCAL);
100
+ if (!handle_) {
101
+ MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror();
102
+ return false;
103
+ }
104
+
105
+ for (const auto &functionName : functionList_) {
106
+ if (!loadFunction(handle_, functionName)) {
107
+ MS_LOG(WARNING) << "Failed to load function: " << functionName;
108
+ dlclose(handle_);
109
+ handle_ = nullptr;
110
+ return false;
111
+ }
112
+ }
113
+
114
+ MS_LOG(INFO) << "Hook library loaded successfully.";
115
+ return true;
116
+ }
117
+
118
+ bool HookDynamicLoader::UnloadLibrary() {
119
+ std::lock_guard<std::mutex> lock(mutex_);
120
+ if (!handle_) {
121
+ MS_LOG(WARNING) << "Hook library hasn't been loaded.";
122
+ return false;
123
+ }
124
+
125
+ dlclose(handle_);
126
+ handle_ = nullptr;
127
+ funcMap_.clear();
128
+ MS_LOG(INFO) << "Library unloaded successfully.";
129
+ return true;
130
+ }
131
+
132
+ void *HookDynamicLoader::GetHooker(const std::string &funcName) {
133
+ std::lock_guard<std::mutex> lock(mutex_);
134
+ auto iter = funcMap_.find(funcName);
135
+ if (iter == funcMap_.end()) {
136
+ MS_LOG(WARNING) << "Function not found: " << funcName;
137
+ return nullptr;
138
+ }
139
+ return iter->second;
140
+ }
@@ -0,0 +1,53 @@
1
+ /**
2
+ * Copyright 2024 Huawei Technologies Co., Ltd
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef HOOK_DYNAMIC_LOADER_H
18
+ #define HOOK_DYNAMIC_LOADER_H
19
+
20
+ #include <dlfcn.h>
21
+ #include <string>
22
+ #include <vector>
23
+ #include <map>
24
+ #include <mutex>
25
+
26
// Symbols that LoadLibrary() requires the hook library to export.
constexpr const char *kHookBegin = "MS_DbgOnStepBegin";
constexpr const char *kHookEnd = "MS_DbgOnStepEnd";

// Singleton that dlopen()s the hook library named by the HOOK_TOOL_PATH
// environment variable and hands out the addresses of its resolved symbols.
class HookDynamicLoader {
 public:
  // Returns the single shared instance.
  static HookDynamicLoader &GetInstance();

  // Non-copyable: exactly one loader exists.
  HookDynamicLoader(const HookDynamicLoader &) = delete;
  HookDynamicLoader &operator=(const HookDynamicLoader &) = delete;

  // Loads the library and resolves all required symbols; false on any failure.
  bool LoadLibrary();
  // Closes the library and drops cached symbols; false if nothing was loaded.
  bool UnloadLibrary();
  // Returns the resolved address of `funcName`, or nullptr if it is unknown.
  void *GetHooker(const std::string &funcName);

 private:
  HookDynamicLoader() = default;

  // Resolves one symbol from `handle` into funcMap_.
  bool loadFunction(void *handle, const std::string &functionName);
  // Checks that `libPath` refers to an acceptable shared library.
  bool validateLibraryPath(const std::string &libPath);

  void *handle_ = nullptr;                                       // dlopen() handle; nullptr when unloaded
  std::vector<std::string> functionList_{kHookBegin, kHookEnd};  // symbols that must resolve
  std::map<std::string, void *> funcMap_;                        // symbol name -> resolved address
  std::mutex mutex_;                                             // held while handle_/funcMap_ are modified or read
};
52
+
53
+ #endif // HOOK_DYNAMIC_LOADER_H
@@ -1,7 +1,7 @@
1
1
  # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
6
6
  # You may obtain a copy of the License at
7
7
  #
@@ -13,24 +13,31 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import functools
16
17
  import importlib
17
- import inspect
18
18
  import os
19
+ import traceback
19
20
 
20
21
  import mindspore as ms
21
- from mindspore.communication import comm_func
22
22
 
23
23
  from msprobe.core.common.const import Const
24
+ from msprobe.core.common.exceptions import DistributedNotInitializedError
24
25
  from msprobe.core.common.file_utils import check_path_length, load_yaml
25
26
  from msprobe.mindspore.common.const import Const as MsConst
26
27
  from msprobe.mindspore.common.const import FreeBenchmarkConst
27
28
  from msprobe.mindspore.common.log import logger
29
+ from msprobe.mindspore.common.utils import get_rank_if_initialized
28
30
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
31
+ from msprobe.mindspore.dump.hook_cell.api_registry import api_register
29
32
  from msprobe.mindspore.free_benchmark.common.config import Config
30
- from msprobe.mindspore.free_benchmark.decorator.decorator_factory import decorate_forward_function
33
+ from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
34
+ from msprobe.mindspore.free_benchmark.common.utils import Tools
35
+ from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
36
+ from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
37
+ from msprobe.mindspore.runtime import Runtime
31
38
 
32
39
 
33
- class ApiPyNativeSelFCheck:
40
+ class ApiPyNativeSelfCheck:
34
41
  def __init__(self, config: DebuggerConfig):
35
42
  Config.is_enable = True
36
43
  Config.handler_type = config.handler_type
@@ -39,29 +46,68 @@ class ApiPyNativeSelFCheck:
39
46
  Config.dump_level = config.dump_level
40
47
  Config.steps = config.step
41
48
  Config.ranks = config.rank
42
- Config.dump_path = os.path.join(config.dump_path, "free_benchmark.csv")
49
+ Config.dump_path = os.path.join(config.dump_path, FreeBenchmarkConst.CHECK_RESULT_FILE)
43
50
  check_path_length(Config.dump_path)
44
51
 
52
+ self.ori_func = {}
53
+
45
54
  self.api_list = config.list
46
55
  all_api = get_supported_ops()
47
56
  if not self.api_list:
48
57
  self.api_list = all_api
49
58
  else:
50
59
  self.api_list = set(self.api_list) & all_api
60
+ self.store_original_func()
51
61
 
52
62
  def handle(self):
63
+ api_register.initialize_hook(self.build_hook)
64
+ api_register.api_set_hook_func()
65
+
66
def build_hook(self, api_name_with_id):
    """Create the (forward, backward) hook pair for one API call site.

    Args:
        api_name_with_id: registry name of the API call. Its last character is
            dropped before use (assumed to be a trailing separator -- TODO confirm).

    Returns:
        tuple: ``(wrap_forward_hook, wrap_backward_hook)``. The forward hook runs
        ``check_self`` for APIs in ``self.api_list`` and always deletes
        ``cell.input_kwargs``; the backward hook does nothing and returns None.
    """
    def wrap_forward_hook(cell, input_data, output_data):
        # Outside the configured run/step/rank window: only drop the stashed kwargs.
        if not need_wrapper_func():
            del cell.input_kwargs
            return None

        trimmed_name = api_name_with_id[:-1]
        sep_pos = trimmed_name.find(Const.SEP)
        hook_prefix = trimmed_name[:sep_pos + 1]
        # Map the registry prefix to the MindSpore API prefix and strip the trailing call id.
        api_name = (MsConst.HOOK_MS_PREFIX_DICT.get(hook_prefix, "") +
                    trimmed_name[sep_pos + 1:trimmed_name.rfind(Const.SEP)])

        result = None
        if api_name in self.api_list:
            result = check_self(trimmed_name, output_data, self.ori_func.get(api_name),
                                *input_data, **cell.input_kwargs)

        del cell.input_kwargs
        return result

    def wrap_backward_hook(cell, grad_input, grad_output):
        return None

    return wrap_forward_hook, wrap_backward_hook
97
+
98
def store_original_func(self):
    """Record, for every API name in ``self.api_list``, the callable returned as the
    second element of ``get_module(name)`` into ``self.ori_func``."""
    for name in self.api_list:
        original = get_module(name)[1]
        self.ori_func[name] = original
55
101
 
56
102
 
57
103
  def get_supported_ops():
58
104
  supported_ops = []
59
105
  cur_path = os.path.dirname(os.path.realpath(__file__))
60
- yaml_path = os.path.join(cur_path, "data", "support_wrap_ops.yaml")
106
+ yaml_path = os.path.join(cur_path, "data", FreeBenchmarkConst.SUPPORTED_CHECK_API_FILE)
61
107
 
62
- yaml_data = load_yaml(yaml_path)
108
+ supported_ops_list = load_yaml(yaml_path)
63
109
  for k, v in FreeBenchmarkConst.API_PREFIX_DICT.items():
64
- ops = yaml_data.get(k)
110
+ ops = supported_ops_list.get(k)
65
111
  if ops:
66
112
  ops = [v + i for i in ops]
67
113
  supported_ops += ops
@@ -72,7 +118,7 @@ def get_supported_ops():
72
118
  _all_functional_ops += ms_ops
73
119
 
74
120
  ms_tensor = dir(ms.Tensor)
75
- ms_tensor = [MsConst.Tensor_PREFIX + i for i in ms_tensor]
121
+ ms_tensor = [MsConst.TENSOR_PREFIX + i for i in ms_tensor]
76
122
  _all_functional_ops += ms_tensor
77
123
 
78
124
  ms_mint = dir(ms.mint)
@@ -83,29 +129,9 @@ def get_supported_ops():
83
129
  ms_mint_nn_func = [MsConst.MINT_NN_FUNC_PREFIX + i for i in ms_mint_nn_func]
84
130
  _all_functional_ops += ms_mint_nn_func
85
131
 
86
- ms_communication = dir(comm_func)
87
- ms_communication = [MsConst.COMM_PREFIX + i for i in ms_communication]
88
- _all_functional_ops += ms_communication
89
-
90
132
  return set(supported_ops) & set(_all_functional_ops)
91
133
 
92
134
 
93
- def get_decorate_func():
94
- return decorate_forward_function
95
-
96
-
97
- def is_func_support_decorate(orig_func):
98
- return not inspect.isclass(orig_func) and callable(orig_func)
99
-
100
-
101
- def get_wrapper_obj(orig_func, api_name):
102
- if is_func_support_decorate(orig_func):
103
- wrapped_obj = get_decorate_func()(orig_func, api_name)
104
- else:
105
- wrapped_obj = orig_func
106
- return wrapped_obj
107
-
108
-
109
135
  def get_module(api_name):
110
136
  func_name_list = api_name.split(Const.SEP)
111
137
  func_name = func_name_list[-1]
@@ -119,13 +145,93 @@ def get_module(api_name):
119
145
  return module_obj, orig_func
120
146
 
121
147
 
122
- def hijack(api_name):
123
- if not api_name.strip():
124
- return
148
def check_self(api_name_with_id, output, ori_func, *args, **kwargs):
    """Run the free-benchmark self check for one API call.

    Args:
        api_name_with_id: API name including its call id (used for logging and
            for creating the perturbation and handler).
        output: forward output of the real call; used as the reference result
            outside the backward stage.
        ori_func: the unhooked API callable.
        *args, **kwargs: the call's inputs.

    Returns:
        The handler result from ``deal_fuzzed_and_original_result``, or None when
        the call is skipped (non-tensor input/output in the backward stage, no
        supported input, perturbation returned False) or an exception was logged.
    """
    if Config.stage == Const.BACKWARD and not (check_all_tensor(args) and check_all_tensor(output)):
        logger.warning(f"{api_name_with_id} has non-tensor input or output.")
        return None

    params = data_pre_deal(api_name_with_id, ori_func, *args, **kwargs)
    if params.index == -1:
        return None

    logger.info(f"[{api_name_with_id}] is {Config.handler_type}ing.")
    # Switch the registry to the original APIs while the perturbed/reference
    # computations run (presumably so they are not re-hooked). The `finally`
    # clause restores the hooks on every exit path, including exceptions that
    # are not caught below.
    api_register.api_set_ori_func()
    ret = None
    try:
        perturbation = PerturbationFactory.create(api_name_with_id)
        params.fuzzed_result = perturbation.handle(params)
        if params.fuzzed_result is False:
            return ret
        if Config.stage == Const.BACKWARD:
            params.original_result = Tools.get_grad(params.original_func, *params.args, **params.kwargs)
        else:
            params.original_result = output
        ret = deal_fuzzed_and_original_result(api_name_with_id, params)
    except Exception as e:
        logger.error(f"[{api_name_with_id}] Error: {str(e)}")
        logger.error(f"[{api_name_with_id}] Error detail: {traceback.format_exc()}")
    finally:
        api_register.api_set_hook_func()
    return ret
179
+
180
+
181
def check_all_tensor(input_output):
    """Return True if *input_output* is an ``ms.Tensor``, or a (possibly nested)
    tuple/list whose elements are all tensors. Empty tuples/lists count as True;
    any other type yields False."""
    if isinstance(input_output, ms.Tensor):
        return True
    if isinstance(input_output, (tuple, list)):
        # A generator lets all() stop at the first non-tensor instead of
        # recursing through every element before checking any of them.
        return all(check_all_tensor(item) for item in input_output)
    return False
187
+
188
+
189
def get_target_arg_index(args) -> int:
    """Type check: choose which positional argument to perturb.

    Returns the index of the first argument that is either a floating-point
    tensor or a list/tuple/dict. Non-floating tensors and all other types are
    skipped. Returns -1 when no argument qualifies.
    """
    for position, candidate in enumerate(args):
        if ms.ops.is_tensor(candidate):
            if ms.ops.is_floating_point(candidate):
                return position
        elif isinstance(candidate, (list, tuple, dict)):
            return position
    return -1
202
+
203
+
204
def data_pre_deal(api_name_with_id, func, *args, **kwargs):
    """Package one API call into a ``HandlerParams``.

    ``params.index`` is the position returned by ``get_target_arg_index``; when it
    is -1 a warning is logged because no argument has a supported type.
    """
    target_index = get_target_arg_index(args)
    if target_index == -1:
        logger.warning(f"{api_name_with_id} has no supported input type.")

    params = HandlerParams()
    params.args = args
    params.kwargs = kwargs
    params.original_func = func
    params.index = target_index
    return params
214
+
215
+
216
def need_wrapper_func():
    """Decide whether the self check should run for the current call.

    Returns False when the runtime is not running or the tool is disabled, when
    ``Config.steps`` is non-empty and excludes the current step, or when the rank
    is known, ``Config.ranks`` is non-empty, and that rank is not listed.
    Returns True otherwise.
    """
    if not Runtime.is_running or not Config.is_enable:
        return False

    if Config.steps and Runtime.step_count not in Config.steps:
        return False

    # Resolve the rank lazily; it stays -1 while distributed is not initialized.
    if Runtime.rank_id == -1:
        try:
            Runtime.rank_id = get_rank_if_initialized()
        except DistributedNotInitializedError:
            Runtime.rank_id = -1

    rank_filtered_out = bool(Config.ranks) and Runtime.rank_id != -1 and Runtime.rank_id not in Config.ranks
    return not rank_filtered_out
232
+
233
+
234
def deal_fuzzed_and_original_result(api_name_with_id, params: HandlerParams):
    """Pass *params* to the handler that ``HandlerFactory`` creates for this API and return its result."""
    return HandlerFactory.create(api_name_with_id).handle(params)
@@ -1,7 +1,7 @@
1
1
  # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
6
6
  # You may obtain a copy of the License at
7
7
  #
@@ -27,6 +27,5 @@ class HandlerParams:
27
27
  original_result: Optional[Any] = None
28
28
  fuzzed_result: Optional[Any] = None
29
29
  is_consistent: Optional[bool] = True
30
- save_flag: Optional[bool] = True
31
30
  fuzzed_value: Optional[Any] = None
32
31
  original_func: Optional[Callable] = None
@@ -1,7 +1,7 @@
1
1
  # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
6
6
  # You may obtain a copy of the License at
7
7
  #
@@ -17,7 +17,7 @@ from dataclasses import dataclass
17
17
  from typing import Any, Optional
18
18
 
19
19
  import mindspore as ms
20
- from mindspore import Tensor
20
+ from mindspore import Tensor, ops
21
21
 
22
22
  from msprobe.mindspore.common.const import FreeBenchmarkConst
23
23
  from msprobe.mindspore.free_benchmark.common.config import Config
@@ -43,6 +43,23 @@ class Tools:
43
43
  return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD
44
44
  return FreeBenchmarkConst.ERROR_THRESHOLD.get(dtype, FreeBenchmarkConst.ERROR_THRESHOLD.get(ms.float32))
45
45
 
46
@staticmethod
def get_grad_out(outputs):
    """Build an all-ones cotangent with the same structure as *outputs*.

    Tensors become ``ops.ones_like(tensor)``. Tuples and lists are rebuilt
    element-wise with the same container type. Anything else is returned
    unchanged. Named tuples are rebuilt positionally, because their constructor
    does not accept a single iterable (``type(nt)([...])`` raises TypeError for
    multi-field named tuples).
    """
    if isinstance(outputs, Tensor):
        return ops.ones_like(outputs)
    if isinstance(outputs, (tuple, list)):
        items = [Tools.get_grad_out(v) for v in outputs]
        if isinstance(outputs, tuple) and hasattr(outputs, "_fields"):
            return type(outputs)(*items)
        return type(outputs)(items)
    return outputs
53
+
54
@staticmethod
def get_grad(func, *args, **kwargs):
    """Return ``vjp_fn(cotangent)`` for *func* evaluated at *args*.

    Keyword arguments are closed over and held fixed, so ``ms.vjp`` differentiates
    only with respect to the positional arguments. The cotangent is built by
    ``Tools.get_grad_out`` from the forward output.
    """
    def positional_only(*inputs):
        return func(*inputs, **kwargs)

    forward_out, vjp_fn = ms.vjp(positional_only, *args)
    cotangent = Tools.get_grad_out(forward_out)
    return vjp_fn(cotangent)
62
+
46
63
 
47
64
  @dataclass
48
65
  class UnequalRow:
@@ -73,10 +90,8 @@ def make_unequal_row(
73
90
  if isinstance(ratio, float):
74
91
  row.max_rel = ratio - 1
75
92
  original_tensor = params.original_result
76
- fuzzed_tensor = params.fuzzed_result
77
93
  if index is not None:
78
94
  original_tensor = original_tensor[index]
79
- fuzzed_tensor = fuzzed_tensor[index]
80
95
  row.output_index = index
81
96
  if isinstance(original_tensor, Tensor):
82
97
  row.dtype = original_tensor.dtype