mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -13,10 +13,12 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import os
17
16
  from collections import defaultdict
17
+ import os
18
+ import types
18
19
 
19
20
  import mindspore
21
+ from mindspore import nn
20
22
  from mindspore._c_expression import PyNativeExecutor_
21
23
  try:
22
24
  from mindspore.common.api import _MindsporeFunctionExecutor
@@ -25,7 +27,9 @@ except ImportError:
25
27
 
26
28
  from msprobe.core.common.log import logger
27
29
  from msprobe.core.common.const import Const
30
+ from msprobe.core.common.runtime import Runtime
28
31
  from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
32
+ from msprobe.mindspore.common.const import Const as MsConst
29
33
  from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
30
34
 
31
35
 
@@ -34,24 +38,20 @@ _api_register = get_api_register()
34
38
 
35
39
  def dump_jit(name, in_feat, out_feat, is_forward):
36
40
  pid = os.getpid()
37
- ori_args = str(name)
38
- index = ori_args.find("<")
39
- if index != 0 and index != -1:
40
- result = ori_args[0:index]
41
- elif name is not None and "<" not in str(name):
42
- result = str(name)
43
- else:
44
- result = "JitFunction"
41
+ name = name if name else "JitFunction"
45
42
  if JitDump.need_dump():
46
43
  if is_forward:
47
- JitDump.jit_count[result] += 1
48
- name_template = (Const.JIT + Const.SEP + result + Const.SEP +
49
- str(JitDump.jit_count[result]) + Const.SEP + Const.FORWARD)
44
+ if name in JitDump.jit_count:
45
+ JitDump.jit_count[name] += 1
46
+ else:
47
+ JitDump.jit_count[name] = 0
48
+ name_template = (Const.JIT + Const.SEP + name + Const.SEP +
49
+ str(JitDump.jit_count[name]) + Const.SEP + Const.FORWARD)
50
50
  JitDump.data_collector.update_api_or_module_name(name_template)
51
51
  module_input_output = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat)
52
52
  JitDump.data_collector.forward_data_collect(name_template, None, pid, module_input_output)
53
53
  else:
54
- name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \
54
+ name_template = Const.JIT + Const.SEP + name + Const.SEP + str(JitDump.jit_count[name]) + Const.SEP + \
55
55
  Const.BACKWARD
56
56
  JitDump.data_collector.update_api_or_module_name(name_template)
57
57
  module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat, grad_output=out_feat)
@@ -74,11 +74,11 @@ class JitDump(_MindsporeFunctionExecutor):
74
74
  def __call__(self, *args, **kwargs):
75
75
  _api_register.restore_all_api()
76
76
  out = super().__call__(*args, **kwargs)
77
- if JitDump.jit_dump_switch and len(args) > 0:
78
- if self.name and self.name != "construct":
77
+ if JitDump.jit_dump_switch and len(args) > 0 and self.name:
78
+ if self.name != "construct":
79
79
  dump_jit(self.name, args, out, True)
80
- else:
81
- dump_jit(args[0], args, out, True)
80
+ elif Runtime.run_mode != MsConst.PYNATIVE_GRAPH_MODE and isinstance(args[0], nn.Cell):
81
+ dump_jit(args[0].__class__.__name__, args, out, True)
82
82
  JitDump.jit_enable = True
83
83
  elif len(args) == 0:
84
84
  logger.warning(f"The jit function {self.name} has no input arguments, nothing will be dumped.")
@@ -109,6 +109,9 @@ class JitDump(_MindsporeFunctionExecutor):
109
109
  else:
110
110
  output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
111
111
  if JitDump.jit_dump_switch and JitDump.jit_enable:
112
- dump_jit(obj, args, None, False)
112
+ if isinstance(obj, types.FunctionType):
113
+ dump_jit(obj.__name__, args, None, False)
114
+ elif Runtime.run_mode != MsConst.PYNATIVE_GRAPH_MODE and isinstance(obj, nn.Cell):
115
+ dump_jit(obj.__class__.__name__, args, None, False)
113
116
  _api_register.register_all_api()
114
117
  return output
@@ -39,9 +39,12 @@ class KernelKbykDump:
39
39
  common_set["input_output"] = 0
40
40
  common_set["kernels"] = []
41
41
  common_set["support_device"] = [0, 1, 2, 3, 4, 5, 6, 7]
42
- e2e_set = dict()
43
- e2e_set["enable"] = True
44
- e2e_set["trans_flag"] = True
42
+ e2e_set = {
43
+ "enable": not config.async_dump,
44
+ "trans_flag": True,
45
+ "stat_calc_mode": config.stat_cal_mode,
46
+ "device_stat_precision_mode": config.device_stat_precision_mode,
47
+ }
45
48
 
46
49
  if config.list:
47
50
  common_set["dump_mode"] = 1
@@ -0,0 +1,110 @@
1
+ /*
2
+ * Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include "hook_dynamic_loader.h"
18
+ #include <sys/stat.h>
19
+ #include <cstdlib>
20
+ #include <cstring>
21
+ #include <pybind11/embed.h>
22
+ #include "utils/log_adapter.h"
23
+
24
+ namespace py = pybind11;
25
+
26
+ HookDynamicLoader &HookDynamicLoader::GetInstance()
27
+ {
28
+ static HookDynamicLoader instance;
29
+ return instance;
30
+ }
31
+
32
+ bool HookDynamicLoader::LoadFunction(void *handle, const std::string &functionName) {
33
+ void *func = dlsym(handle, functionName.c_str());
34
+ if (!func) {
35
+ MS_LOG(WARNING) << "Could not load function: " << functionName << ", error: " << dlerror();
36
+ return false;
37
+ }
38
+ funcMap_[functionName] = func;
39
+ return true;
40
+ }
41
+
42
+ bool HookDynamicLoader::LoadLibrary()
43
+ {
44
+ std::string msprobePath = "";
45
+ // Acquire the Python GIL
46
+ py::gil_scoped_acquire acquire;
47
+ try {
48
+ py::module msprobeMod = py::module::import("msprobe.lib._msprobe_c");
49
+ if (!py::hasattr(msprobeMod, "__file__")) {
50
+ MS_LOG(WARNING) << "Adump mod not found";
51
+ return false;
52
+ }
53
+ msprobePath = msprobeMod.attr("__file__").cast<std::string>();
54
+ } catch (const std::exception& e) {
55
+ MS_LOG(WARNING) << "Adump mod path unable to get: " << e.what();
56
+ return false;
57
+ }
58
+ std::lock_guard<std::mutex> lock(mutex_);
59
+ if (handle_) {
60
+ MS_LOG(WARNING) << "Hook library already loaded!";
61
+ return false;
62
+ }
63
+ if (msprobePath == "") {
64
+ MS_LOG(WARNING) << "Adump path not loaded";
65
+ return false;
66
+ }
67
+ handle_ = dlopen(msprobePath.c_str(), RTLD_LAZY | RTLD_LOCAL);
68
+ if (!handle_) {
69
+ MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror();
70
+ return false;
71
+ }
72
+
73
+ for (const auto &functionName : functionList_) {
74
+ if (!LoadFunction(handle_, functionName)) {
75
+ MS_LOG(WARNING) << "Failed to load adump function";
76
+ dlclose(handle_);
77
+ handle_ = nullptr;
78
+ return false;
79
+ }
80
+ }
81
+
82
+ MS_LOG(INFO) << "Hook library loaded successfully.";
83
+ return true;
84
+ }
85
+
86
+ bool HookDynamicLoader::UnloadLibrary()
87
+ {
88
+ std::lock_guard<std::mutex> lock(mutex_);
89
+ if (!handle_) {
90
+ MS_LOG(WARNING) << "Hook library hasn't been loaded.";
91
+ return false;
92
+ }
93
+
94
+ dlclose(handle_);
95
+ handle_ = nullptr;
96
+ funcMap_.clear();
97
+ MS_LOG(INFO) << "Library unloaded successfully.";
98
+ return true;
99
+ }
100
+
101
+ void *HookDynamicLoader::GetHooker(const std::string &funcName)
102
+ {
103
+ std::lock_guard<std::mutex> lock(mutex_);
104
+ auto iter = funcMap_.find(funcName);
105
+ if (iter == funcMap_.end()) {
106
+ MS_LOG(WARNING) << "Function not found: " << funcName;
107
+ return nullptr;
108
+ }
109
+ return iter->second;
110
+ }
@@ -27,26 +27,26 @@ constexpr auto kHookBegin = "MS_DbgOnStepBegin";
27
27
  constexpr auto kHookEnd = "MS_DbgOnStepEnd";
28
28
 
29
29
  class HookDynamicLoader {
30
- public:
31
- static HookDynamicLoader &GetInstance();
30
+ public:
31
+ static HookDynamicLoader &GetInstance();
32
32
 
33
- HookDynamicLoader(const HookDynamicLoader &) = delete;
34
- HookDynamicLoader &operator=(const HookDynamicLoader &) = delete;
33
+ HookDynamicLoader(const HookDynamicLoader &) = delete;
34
+ HookDynamicLoader &operator=(const HookDynamicLoader &) = delete;
35
35
 
36
- bool LoadLibrary();
37
- bool UnloadLibrary();
38
- void *GetHooker(const std::string &funcName);
36
+ bool LoadLibrary();
37
+ bool UnloadLibrary();
38
+ void *GetHooker(const std::string &funcName);
39
39
 
40
- private:
41
- // Helper functions
42
- bool loadFunction(void *handle, const std::string &functionName);
40
+ private:
41
+ // Helper functions
42
+ bool LoadFunction(void *handle, const std::string &functionName);
43
43
 
44
- HookDynamicLoader() = default;
44
+ HookDynamicLoader() = default;
45
45
 
46
- void *handle_ = nullptr;
47
- std::vector<std::string> functionList_ = {kHookBegin, kHookEnd};
48
- std::map<std::string, void *> funcMap_;
49
- std::mutex mutex_;
46
+ void *handle_ = nullptr;
47
+ std::vector<std::string> functionList_ = {kHookBegin, kHookEnd};
48
+ std::map<std::string, void *> funcMap_;
49
+ std::mutex mutex_;
50
50
  };
51
51
 
52
52
  #endif // HOOK_DYNAMIC_LOADER_H
@@ -23,6 +23,8 @@ import mindspore as ms
23
23
  from msprobe.core.common.const import Const
24
24
  from msprobe.core.common.exceptions import DistributedNotInitializedError
25
25
  from msprobe.core.common.file_utils import check_path_length, load_yaml
26
+ from msprobe.core.common.runtime import Runtime
27
+ from msprobe.core.hook_manager import HookSet
26
28
  from msprobe.mindspore.common.const import Const as MsConst
27
29
  from msprobe.mindspore.common.const import FreeBenchmarkConst
28
30
  from msprobe.mindspore.common.log import logger
@@ -35,7 +37,6 @@ from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
35
37
  from msprobe.mindspore.free_benchmark.common.utils import Tools
36
38
  from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
37
39
  from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
38
- from msprobe.mindspore.runtime import Runtime
39
40
 
40
41
 
41
42
  _api_register = get_api_register()
@@ -75,7 +76,7 @@ class ApiPyNativeSelfCheck:
75
76
  ret = None
76
77
 
77
78
  if not need_wrapper_func():
78
- del cell.input_kwargs
79
+ del cell.msprobe_input_kwargs
79
80
  return ret
80
81
 
81
82
  api_name_with_id = api_name_with_id[:-1]
@@ -84,9 +85,9 @@ class ApiPyNativeSelfCheck:
84
85
  api_name_with_id[api_name_with_id.find(Const.SEP) + 1:api_name_with_id.rfind(Const.SEP)])
85
86
  if api_name in self.api_list:
86
87
  ret = check_self(api_name_with_id, output_data, self.ori_func.get(api_name),
87
- *input_data, **cell.input_kwargs)
88
+ *input_data, **cell.msprobe_input_kwargs)
88
89
 
89
- del cell.input_kwargs
90
+ del cell.msprobe_input_kwargs
90
91
  return ret
91
92
 
92
93
  def backward_hook(cell, grad_input, grad_output):
@@ -105,8 +106,13 @@ class ApiPyNativeSelfCheck:
105
106
 
106
107
  def pre_backward_hook(cell, grad_input):
107
108
  return None
108
-
109
- return pre_hook, wrap_forward_hook, wrap_backward_hook, pre_backward_hook
109
+
110
+ return HookSet(
111
+ forward_hook=wrap_forward_hook,
112
+ forward_pre_hook=pre_hook,
113
+ backward_hook=wrap_backward_hook,
114
+ backward_pre_hook=pre_backward_hook
115
+ )
110
116
 
111
117
  def store_original_func(self):
112
118
  for api_name in self.api_list:
@@ -19,10 +19,10 @@ from typing import Any, Optional
19
19
  import mindspore as ms
20
20
  from mindspore import Tensor, ops
21
21
 
22
+ from msprobe.core.common.runtime import Runtime
22
23
  from msprobe.mindspore.common.const import FreeBenchmarkConst
23
24
  from msprobe.mindspore.free_benchmark.common.config import Config
24
25
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
25
- from msprobe.mindspore.runtime import Runtime
26
26
 
27
27
 
28
28
  class Tools:
@@ -41,8 +41,12 @@ class GlobalContext:
41
41
  def __new__(cls, *args, **kwargs):
42
42
  if cls._instance is None:
43
43
  cls._instance_lock.acquire()
44
- cls._instance = object.__new__(cls)
45
- cls._instance_lock.release()
44
+ try:
45
+ cls._instance = object.__new__(cls)
46
+ except Exception as e:
47
+ raise RuntimeError("grad_probe global context init failed") from e
48
+ finally:
49
+ cls._instance_lock.release()
46
50
  return cls._instance
47
51
 
48
52
  def init_context(self, config_dict: Dict):
@@ -69,6 +73,7 @@ class GlobalContext:
69
73
  create_directory(self._setting.get(GradConst.OUTPUT_PATH))
70
74
  else:
71
75
  logger.warning("The output_path exists, the data will be covered.")
76
+
72
77
  self._setting[GradConst.TIME_STAMP] = str(int(time.time()))
73
78
 
74
79
  def get_context(self, key: str):
@@ -15,6 +15,7 @@
15
15
 
16
16
  import hashlib
17
17
  from abc import ABC, abstractmethod
18
+ import zlib
18
19
 
19
20
  import mindspore
20
21
  from mindspore import ops
@@ -76,8 +77,8 @@ class CsvMd5(CsvItem):
76
77
  def generate_csv_content(csv_input):
77
78
  grad = csv_input.grad
78
79
  tensor_bytes = grad.float().numpy().tobytes()
79
- md5_hash = hashlib.md5(tensor_bytes)
80
- return [md5_hash.hexdigest()]
80
+ md5_hash = f"{zlib.crc32(tensor_bytes):08x}"
81
+ return [md5_hash]
81
82
 
82
83
 
83
84
  @register_csv_item(GradConst.DISTRIBUTION)
@@ -0,0 +1,114 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import defaultdict
17
+ import mindspore as ms
18
+ from mindspore.ops.primitive import Primitive
19
+
20
+ from msprobe.core.common.utils import Const
21
+ from msprobe.core.service import BaseService
22
+ from msprobe.mindspore.cell_processor import CellProcessor
23
+ from msprobe.mindspore.common.log import logger
24
+ from msprobe.mindspore.common.utils import (
25
+ get_rank_if_initialized,
26
+ is_mindtorch,
27
+ get_cells_and_names_with_index
28
+ )
29
+ from msprobe.mindspore.dump.hook_cell.api_register import get_api_register, ApiTemplate
30
+ from msprobe.mindspore.dump.hook_cell.ms_hook_manager import MindsproeHookManager
31
+ from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
32
+ from msprobe.mindspore.dump.jit_dump import JitDump
33
+
34
+ try:
35
+ from mindspore.common._pijit_context import PIJitCaptureContext
36
+ except ImportError:
37
+ pijit_label = False
38
+ else:
39
+ pijit_label = True
40
+
41
+
42
+ class MindsporeService(BaseService):
43
+ @property
44
+ def _get_framework_type(self):
45
+ return Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
46
+
47
+ @staticmethod
48
+ def _get_current_rank():
49
+ return get_rank_if_initialized()
50
+
51
+ def empty(self, *args, **kwargs):
52
+ pass
53
+
54
+ def reset_status(self):
55
+ self._reset_status()
56
+
57
+ def _init_specific_components(self):
58
+ self.logger = logger
59
+ self.api_register = get_api_register()
60
+ self.primitive_hook_service = PrimitiveHookService(self)
61
+ self.cell_processor = CellProcessor(self.data_collector.scope)
62
+ self.hook_manager = MindsproeHookManager(self.data_collector, self.config)
63
+ self._setup_jit_context()
64
+ self.api_template = ApiTemplate
65
+
66
+ def _setup_jit_context(self):
67
+ if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
68
+ JitDump.set_config(self.config)
69
+ JitDump.set_data_collector(self.data_collector)
70
+ if hasattr(ms.common.api, "_MindsporeFunctionExecutor"):
71
+ ms.common.api._MindsporeFunctionExecutor = JitDump
72
+ else:
73
+ ms.common.api._JitExecutor = JitDump
74
+ ms.common.api._PyNativeExecutor.grad = JitDump.grad
75
+ if pijit_label:
76
+ PIJitCaptureContext.__enter__ = self.empty
77
+ PIJitCaptureContext.__exit__ = self.empty
78
+
79
+ def _register_module_hook(self):
80
+ self.cell_processor.register_cell_hook(self.model, self.build_hook, self.config)
81
+ self.logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.")
82
+
83
+ def _register_hook(self):
84
+ self._register_primitive_hook()
85
+
86
+ def _register_primitive_hook(self):
87
+ if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
88
+ return
89
+ if not self.model or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST:
90
+ return
91
+
92
+ primitive_set = set()
93
+ cells_and_names_with_index, _ = get_cells_and_names_with_index(self.model)
94
+ for cells_and_names in cells_and_names_with_index.values():
95
+ for _, cell in cells_and_names:
96
+ for attribute, value in vars(cell).items():
97
+ if isinstance(value, Primitive):
98
+ primitive_set.add((attribute, value))
99
+
100
+ for pname, primitive in primitive_set:
101
+ primitive_class_name = primitive.__class__.__name__
102
+ primitive_combined_name = pname + Const.SEP + primitive_class_name
103
+ new_primitive = type('NewPrimitive', (primitive.__class__,),
104
+ {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
105
+ primitive_combined_name)})
106
+ primitive.__class__ = new_primitive
107
+
108
+ def _reset_status(self):
109
+ super()._reset_status()
110
+ self.primitive_hook_service.primitive_counters.clear()
111
+ JitDump.jit_count = defaultdict(int)
112
+
113
+ def _change_jit_switch(self, status):
114
+ JitDump.jit_dump_switch = status
@@ -0,0 +1,52 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from mindspore import nn
18
+ from mindspore import communication
19
+ from msprobe.mindspore.monitor.utils import logger
20
+ from msprobe.mindspore.common.utils import is_mindtorch
21
+ if is_mindtorch():
22
+ import torch
23
+
24
+
25
+ def is_valid_instance(model):
26
+ return isinstance(model, torch.nn.Module) if is_mindtorch() else isinstance(model, nn.Cell)
27
+
28
+
29
+ def get_submodules(model):
30
+ if not is_valid_instance(model):
31
+ logger.info("Counter invalid model, nothing to hook")
32
+ return {}
33
+ return model.named_modules() if is_mindtorch() else model.cells_and_names()
34
+
35
+
36
+ def get_parameters(model):
37
+ if not is_valid_instance(model):
38
+ return {}
39
+ if is_mindtorch():
40
+ return model.named_parameters()
41
+ else:
42
+ return model.parameters_and_names()
43
+
44
+
45
+ def get_rank():
46
+ if comm_is_initialized():
47
+ return communication.get_rank()
48
+ return 0
49
+
50
+
51
+ def comm_is_initialized():
52
+ return communication.GlobalComm.INITED